mirror of
				https://github.com/stevenrobertson/cuburn.git
				synced 2025-11-03 18:00:55 -05:00 
			
		
		
		
	Formatter improvements
This commit is contained in:
		@ -43,6 +43,10 @@ class LaunchContext(object):
 | 
				
			|||||||
    def threads(self):
 | 
					    def threads(self):
 | 
				
			||||||
        return reduce(lambda a, b: a*b, self.block + self.grid)
 | 
					        return reduce(lambda a, b: a*b, self.block + self.grid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def print_source(self):
 | 
				
			||||||
 | 
					        print '\n'.join(["%03d %s" % (i+1, l) for (i, l) in
 | 
				
			||||||
 | 
					                        enumerate(self.ptx.source.split('\n'))])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def compile(self, to_inject={}, verbose=False):
 | 
					    def compile(self, to_inject={}, verbose=False):
 | 
				
			||||||
        inj = dict(to_inject)
 | 
					        inj = dict(to_inject)
 | 
				
			||||||
        inj['ctx'] = self
 | 
					        inj['ctx'] = self
 | 
				
			||||||
@ -51,10 +55,11 @@ class LaunchContext(object):
 | 
				
			|||||||
            self.mod = cuda.module_from_buffer(self.ptx.source)
 | 
					            self.mod = cuda.module_from_buffer(self.ptx.source)
 | 
				
			||||||
        except (cuda.CompileError, cuda.RuntimeError), e:
 | 
					        except (cuda.CompileError, cuda.RuntimeError), e:
 | 
				
			||||||
            print "Aww, dang, compile error. Here's the source:"
 | 
					            print "Aww, dang, compile error. Here's the source:"
 | 
				
			||||||
            print '\n'.join(["%03d %s" % (i+1, l) for (i, l) in
 | 
					            self.print_source()
 | 
				
			||||||
                            enumerate(self.ptx.source.split('\n'))])
 | 
					 | 
				
			||||||
            raise e
 | 
					            raise e
 | 
				
			||||||
        if verbose:
 | 
					        if verbose:
 | 
				
			||||||
 | 
					            if verbose >= 3:
 | 
				
			||||||
 | 
					                self.print_source()
 | 
				
			||||||
            for entry in self.ptx.entries:
 | 
					            for entry in self.ptx.entries:
 | 
				
			||||||
                func = self.mod.get_function(entry.entry_name)
 | 
					                func = self.mod.get_function(entry.entry_name)
 | 
				
			||||||
                print "Compiled %s: used %d regs, %d sm, %d local" % (
 | 
					                print "Compiled %s: used %d regs, %d sm, %d local" % (
 | 
				
			||||||
 | 
				
			|||||||
@ -74,7 +74,7 @@ from collections import namedtuple
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
def _softjoin(args, sep):
 | 
					def _softjoin(args, sep):
 | 
				
			||||||
    """Intersperses 'sep' between 'args' without coercing to string."""
 | 
					    """Intersperses 'sep' between 'args' without coercing to string."""
 | 
				
			||||||
    return [(arg, sep) for arg in args[:-1]] + args[-1:]
 | 
					    return [[arg, sep] for arg in args[:-1]] + list(args[-1:])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
BlockCtx = namedtuple('BlockCtx', 'locals code injectors')
 | 
					BlockCtx = namedtuple('BlockCtx', 'locals code injectors')
 | 
				
			||||||
PTXStmt = namedtuple('PTXStmt', 'prefix op vars semi indent')
 | 
					PTXStmt = namedtuple('PTXStmt', 'prefix op vars semi indent')
 | 
				
			||||||
@ -292,12 +292,14 @@ class Block(object):
 | 
				
			|||||||
        return self
 | 
					        return self
 | 
				
			||||||
    def __enter__(self):
 | 
					    def __enter__(self):
 | 
				
			||||||
        self.block.push_ctx()
 | 
					        self.block.push_ctx()
 | 
				
			||||||
        self.block.code(op='{', indent=1, semi=False)
 | 
					        self.block.code(op='{', semi=False)
 | 
				
			||||||
 | 
					        self.block.code(indent=1)
 | 
				
			||||||
        if self.comment:
 | 
					        if self.comment:
 | 
				
			||||||
            self.block.code(op=['// ', self.comment], semi=False)
 | 
					            self.block.code(op=['// ', self.comment], semi=False)
 | 
				
			||||||
        self.comment = None
 | 
					        self.comment = None
 | 
				
			||||||
    def __exit__(self, exc_type, exc_value, tb):
 | 
					    def __exit__(self, exc_type, exc_value, tb):
 | 
				
			||||||
        self.block.code(op='}', indent=-1, semi=False)
 | 
					        self.block.code(indent=-1)
 | 
				
			||||||
 | 
					        self.block.code(op='}', semi=False)
 | 
				
			||||||
        self.block.pop_ctx()
 | 
					        self.block.pop_ctx()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class _CallChain(object):
 | 
					class _CallChain(object):
 | 
				
			||||||
@ -343,7 +345,7 @@ class _RegFactory(_CallChain):
 | 
				
			|||||||
        type = type[0]
 | 
					        type = type[0]
 | 
				
			||||||
        names = names.split()
 | 
					        names = names.split()
 | 
				
			||||||
        regs = map(lambda n: Reg(type, n), names)
 | 
					        regs = map(lambda n: Reg(type, n), names)
 | 
				
			||||||
        self.block.code(op='.reg .' + type, vars=_softjoin(names, ', '))
 | 
					        self.block.code(op='.reg .' + type, vars=_softjoin(names, ','))
 | 
				
			||||||
        [self.block.inject(r.name, r) for r in regs]
 | 
					        [self.block.inject(r.name, r) for r in regs]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Op(_CallChain):
 | 
					class Op(_CallChain):
 | 
				
			||||||
@ -374,7 +376,7 @@ class Op(_CallChain):
 | 
				
			|||||||
            pred = ['@', kwargs['ifp']]
 | 
					            pred = ['@', kwargs['ifp']]
 | 
				
			||||||
        if 'ifnotp' in kwargs:
 | 
					        if 'ifnotp' in kwargs:
 | 
				
			||||||
            pred = ['@!', kwargs['ifnotp']]
 | 
					            pred = ['@!', kwargs['ifnotp']]
 | 
				
			||||||
        self.block.code(pred, '.'.join(op), _softjoin(args, ', '))
 | 
					        self.block.code(pred, '.'.join(op), _softjoin(args, ','))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Mem(object):
 | 
					class Mem(object):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
@ -408,11 +410,8 @@ class Mem(object):
 | 
				
			|||||||
        >>> op.ld.global.v2.u32(vec(reg1, reg2), addr(areg))
 | 
					        >>> op.ld.global.v2.u32(vec(reg1, reg2), addr(areg))
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        assert len(args) >= 2, "vector loads make no sense with < 2 args"
 | 
					        assert len(args) >= 2, "vector loads make no sense with < 2 args"
 | 
				
			||||||
        joined = _softjoin(args, ', ')
 | 
					        # TODO: fix the way this looks (not so easy)
 | 
				
			||||||
        # This makes a list like [('{', ('arg1', ', ')), ('argf', '}')], which
 | 
					        return ['{', _softjoin(args, ','), '}']
 | 
				
			||||||
        # when compacted comes out like ['{arg1,', 'argf}'], which lets the
 | 
					 | 
				
			||||||
        # formatter work properly. There's no semantic value to this.
 | 
					 | 
				
			||||||
        return [('{', joined[0])] + joined[1:-1] + [(joined[-1], '}')]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def addr(areg, aoffset=''):
 | 
					    def addr(areg, aoffset=''):
 | 
				
			||||||
@ -786,33 +785,49 @@ class PTXModule(object):
 | 
				
			|||||||
                raise ValueError("Too many recompiles scheduled!")
 | 
					                raise ValueError("Too many recompiles scheduled!")
 | 
				
			||||||
            self.__needs_recompilation = True
 | 
					            self.__needs_recompilation = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _flatten(val):
 | 
				
			||||||
 | 
					    if isinstance(val, (list, tuple)):
 | 
				
			||||||
 | 
					        return ''.join(map(_flatten, val))
 | 
				
			||||||
 | 
					    return str(val)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class PTXFormatter(object):
 | 
					class PTXFormatter(object):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Formats PTXStmt items into beautiful code. Well, the beautiful part is
 | 
					    Formats PTXStmt items into beautiful code. Well, the beautiful part is
 | 
				
			||||||
    postponed for now.
 | 
					    postponed for now.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    def __init__(self, indent=4):
 | 
					    def __init__(self, indent_amt=2, oplen_max=20, varlen_max=12):
 | 
				
			||||||
        self.indent_amt = 4
 | 
					        self.idamt, self.opm, self.vm = indent_amt, oplen_max, varlen_max
 | 
				
			||||||
    def _flatten(self, val):
 | 
					 | 
				
			||||||
        if isinstance(val, (list, tuple)):
 | 
					 | 
				
			||||||
            return ''.join(map(self._flatten, val))
 | 
					 | 
				
			||||||
        return str(val)
 | 
					 | 
				
			||||||
    def format(self, code):
 | 
					    def format(self, code):
 | 
				
			||||||
        out = []
 | 
					        out = []
 | 
				
			||||||
        indent = 0
 | 
					        indent = 0
 | 
				
			||||||
        for (pfx, op, vars, semi, indent_change) in code:
 | 
					        idx = 0
 | 
				
			||||||
            pfx = self._flatten(pfx)
 | 
					        while idx < len(code):
 | 
				
			||||||
            op = self._flatten(op)
 | 
					            opl, vl = 0, 0
 | 
				
			||||||
            vars = map(self._flatten, vars)
 | 
					            flat = []
 | 
				
			||||||
            if indent_change < 0:
 | 
					            while idx < len(code):
 | 
				
			||||||
                indent = max(0, indent + self.indent_amt * indent_change)
 | 
					                pfx, op, vars, semi, indent_change = code[idx]
 | 
				
			||||||
            # TODO: make this a lot prettier
 | 
					                idx += 1
 | 
				
			||||||
            line = ((('%%-%ds' % indent) % pfx) + op + ' ' + ''.join(vars))
 | 
					                if indent_change: break
 | 
				
			||||||
            if semi:
 | 
					                pfx, op = _flatten(pfx), _flatten(op)
 | 
				
			||||||
                line = line.rstrip() + ';'
 | 
					                vars = map(_flatten, vars)
 | 
				
			||||||
            out.append(line)
 | 
					                if len(op) <= self.opm:
 | 
				
			||||||
            if indent_change > 0:
 | 
					                    opl = max(opl, len(op)+2)
 | 
				
			||||||
                indent += self.indent_amt * indent_change
 | 
					                for var in vars:
 | 
				
			||||||
 | 
					                    if len(var) <= self.vm:
 | 
				
			||||||
 | 
					                        vl = max(vl, len(var)+1)
 | 
				
			||||||
 | 
					                flat.append((pfx, op, vars, semi))
 | 
				
			||||||
 | 
					            for pfx, op, vars, semi in flat:
 | 
				
			||||||
 | 
					                if pfx:
 | 
				
			||||||
 | 
					                    line = ('%%-%ds ' % (indent-1)) % pfx
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    line = ' '*indent
 | 
				
			||||||
 | 
					                line = ('%%-%ds ' % (indent+opl)) % (line+op)
 | 
				
			||||||
 | 
					                for i, var in enumerate(vars):
 | 
				
			||||||
 | 
					                    line = ('%%-%ds ' % (indent+opl+vl*(i+1))) % (line+var)
 | 
				
			||||||
 | 
					                if semi:
 | 
				
			||||||
 | 
					                    line = line.rstrip() + ';'
 | 
				
			||||||
 | 
					                out.append(line)
 | 
				
			||||||
 | 
					            indent = max(0, indent + self.idamt * indent_change)
 | 
				
			||||||
        return '\n'.join(out)
 | 
					        return '\n'.join(out)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
_TExp = namedtuple('_TExp', 'type exprlist')
 | 
					_TExp = namedtuple('_TExp', 'type exprlist')
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										15
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								main.py
									
									
									
									
									
								
							@ -21,12 +21,15 @@ from fr0stlib.pyflam3 import *
 | 
				
			|||||||
from fr0stlib.pyflam3._flam3 import *
 | 
					from fr0stlib.pyflam3._flam3 import *
 | 
				
			||||||
from cuburnlib.render import *
 | 
					from cuburnlib.render import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def main(genome_path):
 | 
					def main(args):
 | 
				
			||||||
 | 
					    verbose = 1
 | 
				
			||||||
 | 
					    if '-d' in args:
 | 
				
			||||||
 | 
					        verbose = 3
 | 
				
			||||||
    ctx = LaunchContext([MWCRNGTest], block=(256,1,1), grid=(64,1), tests=True)
 | 
					    ctx = LaunchContext([MWCRNGTest], block=(256,1,1), grid=(64,1), tests=True)
 | 
				
			||||||
    ctx.compile(verbose=True)
 | 
					    ctx.compile(verbose=verbose)
 | 
				
			||||||
    ctx.run_tests()
 | 
					    ctx.run_tests()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    with open(genome_path) as fp:
 | 
					    with open(args[-1]) as fp:
 | 
				
			||||||
        genomes = Genome.from_string(fp.read())
 | 
					        genomes = Genome.from_string(fp.read())
 | 
				
			||||||
    render = Render(genomes)
 | 
					    render = Render(genomes)
 | 
				
			||||||
    render.render_frame()
 | 
					    render.render_frame()
 | 
				
			||||||
@ -52,8 +55,8 @@ def main(genome_path):
 | 
				
			|||||||
    #pyglet.app.run()
 | 
					    #pyglet.app.run()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    if len(sys.argv) < 2 or not os.path.isfile(sys.argv[1]):
 | 
					    if len(sys.argv) < 2 or not os.path.isfile(sys.argv[-1]):
 | 
				
			||||||
        print "First argument must be a path to a genome file"
 | 
					        print "Last argument must be a path to a genome file"
 | 
				
			||||||
        sys.exit(1)
 | 
					        sys.exit(1)
 | 
				
			||||||
    main(sys.argv[1])
 | 
					    main(sys.argv)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user