diff --git a/cuburnlib/cuda.py b/cuburnlib/cuda.py index d367a60..2aedc43 100644 --- a/cuburnlib/cuda.py +++ b/cuburnlib/cuda.py @@ -43,6 +43,10 @@ class LaunchContext(object): def threads(self): return reduce(lambda a, b: a*b, self.block + self.grid) + def print_source(self): + print '\n'.join(["%03d %s" % (i+1, l) for (i, l) in + enumerate(self.ptx.source.split('\n'))]) + def compile(self, to_inject={}, verbose=False): inj = dict(to_inject) inj['ctx'] = self @@ -51,10 +55,11 @@ class LaunchContext(object): self.mod = cuda.module_from_buffer(self.ptx.source) except (cuda.CompileError, cuda.RuntimeError), e: print "Aww, dang, compile error. Here's the source:" - print '\n'.join(["%03d %s" % (i+1, l) for (i, l) in - enumerate(self.ptx.source.split('\n'))]) + self.print_source() raise e if verbose: + if verbose >= 3: + self.print_source() for entry in self.ptx.entries: func = self.mod.get_function(entry.entry_name) print "Compiled %s: used %d regs, %d sm, %d local" % ( diff --git a/cuburnlib/ptx.py b/cuburnlib/ptx.py index 6367c40..97c5b56 100644 --- a/cuburnlib/ptx.py +++ b/cuburnlib/ptx.py @@ -74,7 +74,7 @@ from collections import namedtuple def _softjoin(args, sep): """Intersperses 'sep' between 'args' without coercing to string.""" - return [(arg, sep) for arg in args[:-1]] + args[-1:] + return [[arg, sep] for arg in args[:-1]] + list(args[-1:]) BlockCtx = namedtuple('BlockCtx', 'locals code injectors') PTXStmt = namedtuple('PTXStmt', 'prefix op vars semi indent') @@ -292,12 +292,14 @@ class Block(object): return self def __enter__(self): self.block.push_ctx() - self.block.code(op='{', indent=1, semi=False) + self.block.code(op='{', semi=False) + self.block.code(indent=1) if self.comment: self.block.code(op=['// ', self.comment], semi=False) self.comment = None def __exit__(self, exc_type, exc_value, tb): - self.block.code(op='}', indent=-1, semi=False) + self.block.code(indent=-1) + self.block.code(op='}', semi=False) self.block.pop_ctx() class _CallChain(object): @@ -343,7 +345,7 @@ class _RegFactory(_CallChain): type = type[0] names = names.split() regs = map(lambda n: Reg(type, n), names) - self.block.code(op='.reg .' + type, vars=_softjoin(names, ', ')) + self.block.code(op='.reg .' + type, vars=_softjoin(names, ',')) [self.block.inject(r.name, r) for r in regs] class Op(_CallChain): @@ -374,7 +376,7 @@ class Op(_CallChain): pred = ['@', kwargs['ifp']] if 'ifnotp' in kwargs: pred = ['@!', kwargs['ifnotp']] - self.block.code(pred, '.'.join(op), _softjoin(args, ', ')) + self.block.code(pred, '.'.join(op), _softjoin(args, ',')) class Mem(object): """ @@ -408,11 +410,8 @@ class Mem(object): >>> op.ld.global.v2.u32(vec(reg1, reg2), addr(areg)) """ assert len(args) >= 2, "vector loads make no sense with < 2 args" - joined = _softjoin(args, ', ') - # This makes a list like [('{', ('arg1', ', ')), ('argf', '}')], which - # when compacted comes out like ['{arg1,', 'argf}'], which lets the - # formatter work properly. There's no semantic value to this. - return [('{', joined[0])] + joined[1:-1] + [(joined[-1], '}')] + # TODO: fix the way this looks (not so easy) + return ['{', _softjoin(args, ','), '}'] @staticmethod def addr(areg, aoffset=''): @@ -786,33 +785,49 @@ class PTXModule(object): raise ValueError("Too many recompiles scheduled!") self.__needs_recompilation = True +def _flatten(val): + if isinstance(val, (list, tuple)): + return ''.join(map(_flatten, val)) + return str(val) + class PTXFormatter(object): """ Formats PTXStmt items into beautiful code. Well, the beautiful part is postponed for now. """ - def __init__(self, indent=4): - self.indent_amt = 4 - def _flatten(self, val): - if isinstance(val, (list, tuple)): - return ''.join(map(self._flatten, val)) - return str(val) + def __init__(self, indent_amt=2, oplen_max=20, varlen_max=12): + self.idamt, self.opm, self.vm = indent_amt, oplen_max, varlen_max def format(self, code): out = [] indent = 0 - for (pfx, op, vars, semi, indent_change) in code: - pfx = self._flatten(pfx) - op = self._flatten(op) - vars = map(self._flatten, vars) - if indent_change < 0: - indent = max(0, indent + self.indent_amt * indent_change) - # TODO: make this a lot prettier - line = ((('%%-%ds' % indent) % pfx) + op + ' ' + ''.join(vars)) - if semi: - line = line.rstrip() + ';' - out.append(line) - if indent_change > 0: - indent += self.indent_amt * indent_change + idx = 0 + while idx < len(code): + opl, vl = 0, 0 + flat = [] + while idx < len(code): + pfx, op, vars, semi, indent_change = code[idx] + idx += 1 + if indent_change: break + pfx, op = _flatten(pfx), _flatten(op) + vars = map(_flatten, vars) + if len(op) <= self.opm: + opl = max(opl, len(op)+2) + for var in vars: + if len(var) <= self.vm: + vl = max(vl, len(var)+1) + flat.append((pfx, op, vars, semi)) + for pfx, op, vars, semi in flat: + if pfx: + line = ('%%-%ds ' % (indent-1)) % pfx + else: + line = ' '*indent + line = ('%%-%ds ' % (indent+opl)) % (line+op) + for i, var in enumerate(vars): + line = ('%%-%ds ' % (indent+opl+vl*(i+1))) % (line+var) + if semi: + line = line.rstrip() + ';' + out.append(line) + indent = max(0, indent + self.idamt * indent_change) return '\n'.join(out) _TExp = namedtuple('_TExp', 'type exprlist') diff --git a/main.py b/main.py index 9ec52f1..99aed1a 100644 --- a/main.py +++ b/main.py @@ -21,12 +21,15 @@ from fr0stlib.pyflam3 import * from fr0stlib.pyflam3._flam3 import * from cuburnlib.render import * -def main(genome_path): +def main(args): + verbose = 1 + if '-d' in args: + verbose = 3 ctx = LaunchContext([MWCRNGTest], block=(256,1,1), grid=(64,1), tests=True) - ctx.compile(verbose=True) + ctx.compile(verbose=verbose) ctx.run_tests() - with open(genome_path) as fp: + with open(args[-1]) as fp: genomes = Genome.from_string(fp.read()) render = Render(genomes) render.render_frame() @@ -52,8 +55,8 @@ def main(genome_path): #pyglet.app.run() if __name__ == "__main__": - if len(sys.argv) < 2 or not os.path.isfile(sys.argv[1]): - print "First argument must be a path to a genome file" + if len(sys.argv) < 2 or not os.path.isfile(sys.argv[-1]): + print "Last argument must be a path to a genome file" sys.exit(1) - main(sys.argv[1]) + main(sys.argv)