Formatter improvements

This commit is contained in:
Steven Robertson 2010-09-02 16:12:22 -04:00
parent 731c637f80
commit a23a493d68
3 changed files with 60 additions and 37 deletions

View File

@ -43,6 +43,10 @@ class LaunchContext(object):
def threads(self): def threads(self):
return reduce(lambda a, b: a*b, self.block + self.grid) return reduce(lambda a, b: a*b, self.block + self.grid)
def print_source(self):
print '\n'.join(["%03d %s" % (i+1, l) for (i, l) in
enumerate(self.ptx.source.split('\n'))])
def compile(self, to_inject={}, verbose=False): def compile(self, to_inject={}, verbose=False):
inj = dict(to_inject) inj = dict(to_inject)
inj['ctx'] = self inj['ctx'] = self
@ -51,10 +55,11 @@ class LaunchContext(object):
self.mod = cuda.module_from_buffer(self.ptx.source) self.mod = cuda.module_from_buffer(self.ptx.source)
except (cuda.CompileError, cuda.RuntimeError), e: except (cuda.CompileError, cuda.RuntimeError), e:
print "Aww, dang, compile error. Here's the source:" print "Aww, dang, compile error. Here's the source:"
print '\n'.join(["%03d %s" % (i+1, l) for (i, l) in self.print_source()
enumerate(self.ptx.source.split('\n'))])
raise e raise e
if verbose: if verbose:
if verbose >= 3:
self.print_source()
for entry in self.ptx.entries: for entry in self.ptx.entries:
func = self.mod.get_function(entry.entry_name) func = self.mod.get_function(entry.entry_name)
print "Compiled %s: used %d regs, %d sm, %d local" % ( print "Compiled %s: used %d regs, %d sm, %d local" % (

View File

@ -74,7 +74,7 @@ from collections import namedtuple
def _softjoin(args, sep): def _softjoin(args, sep):
"""Intersperses 'sep' between 'args' without coercing to string.""" """Intersperses 'sep' between 'args' without coercing to string."""
return [(arg, sep) for arg in args[:-1]] + args[-1:] return [[arg, sep] for arg in args[:-1]] + list(args[-1:])
BlockCtx = namedtuple('BlockCtx', 'locals code injectors') BlockCtx = namedtuple('BlockCtx', 'locals code injectors')
PTXStmt = namedtuple('PTXStmt', 'prefix op vars semi indent') PTXStmt = namedtuple('PTXStmt', 'prefix op vars semi indent')
@ -292,12 +292,14 @@ class Block(object):
return self return self
def __enter__(self): def __enter__(self):
self.block.push_ctx() self.block.push_ctx()
self.block.code(op='{', indent=1, semi=False) self.block.code(op='{', semi=False)
self.block.code(indent=1)
if self.comment: if self.comment:
self.block.code(op=['// ', self.comment], semi=False) self.block.code(op=['// ', self.comment], semi=False)
self.comment = None self.comment = None
def __exit__(self, exc_type, exc_value, tb): def __exit__(self, exc_type, exc_value, tb):
self.block.code(op='}', indent=-1, semi=False) self.block.code(indent=-1)
self.block.code(op='}', semi=False)
self.block.pop_ctx() self.block.pop_ctx()
class _CallChain(object): class _CallChain(object):
@ -343,7 +345,7 @@ class _RegFactory(_CallChain):
type = type[0] type = type[0]
names = names.split() names = names.split()
regs = map(lambda n: Reg(type, n), names) regs = map(lambda n: Reg(type, n), names)
self.block.code(op='.reg .' + type, vars=_softjoin(names, ', ')) self.block.code(op='.reg .' + type, vars=_softjoin(names, ','))
[self.block.inject(r.name, r) for r in regs] [self.block.inject(r.name, r) for r in regs]
class Op(_CallChain): class Op(_CallChain):
@ -374,7 +376,7 @@ class Op(_CallChain):
pred = ['@', kwargs['ifp']] pred = ['@', kwargs['ifp']]
if 'ifnotp' in kwargs: if 'ifnotp' in kwargs:
pred = ['@!', kwargs['ifnotp']] pred = ['@!', kwargs['ifnotp']]
self.block.code(pred, '.'.join(op), _softjoin(args, ', ')) self.block.code(pred, '.'.join(op), _softjoin(args, ','))
class Mem(object): class Mem(object):
""" """
@ -408,11 +410,8 @@ class Mem(object):
>>> op.ld.global.v2.u32(vec(reg1, reg2), addr(areg)) >>> op.ld.global.v2.u32(vec(reg1, reg2), addr(areg))
""" """
assert len(args) >= 2, "vector loads make no sense with < 2 args" assert len(args) >= 2, "vector loads make no sense with < 2 args"
joined = _softjoin(args, ', ') # TODO: fix the way this looks (not so easy)
# This makes a list like [('{', ('arg1', ', ')), ('argf', '}')], which return ['{', _softjoin(args, ','), '}']
# when compacted comes out like ['{arg1,', 'argf}'], which lets the
# formatter work properly. There's no semantic value to this.
return [('{', joined[0])] + joined[1:-1] + [(joined[-1], '}')]
@staticmethod @staticmethod
def addr(areg, aoffset=''): def addr(areg, aoffset=''):
@ -786,33 +785,49 @@ class PTXModule(object):
raise ValueError("Too many recompiles scheduled!") raise ValueError("Too many recompiles scheduled!")
self.__needs_recompilation = True self.__needs_recompilation = True
def _flatten(val):
if isinstance(val, (list, tuple)):
return ''.join(map(_flatten, val))
return str(val)
class PTXFormatter(object): class PTXFormatter(object):
""" """
Formats PTXStmt items into beautiful code. Well, the beautiful part is Formats PTXStmt items into beautiful code. Well, the beautiful part is
postponed for now. postponed for now.
""" """
def __init__(self, indent=4): def __init__(self, indent_amt=2, oplen_max=20, varlen_max=12):
self.indent_amt = 4 self.idamt, self.opm, self.vm = indent_amt, oplen_max, varlen_max
def _flatten(self, val):
if isinstance(val, (list, tuple)):
return ''.join(map(self._flatten, val))
return str(val)
def format(self, code): def format(self, code):
out = [] out = []
indent = 0 indent = 0
for (pfx, op, vars, semi, indent_change) in code: idx = 0
pfx = self._flatten(pfx) while idx < len(code):
op = self._flatten(op) opl, vl = 0, 0
vars = map(self._flatten, vars) flat = []
if indent_change < 0: while idx < len(code):
indent = max(0, indent + self.indent_amt * indent_change) pfx, op, vars, semi, indent_change = code[idx]
# TODO: make this a lot prettier idx += 1
line = ((('%%-%ds' % indent) % pfx) + op + ' ' + ''.join(vars)) if indent_change: break
if semi: pfx, op = _flatten(pfx), _flatten(op)
line = line.rstrip() + ';' vars = map(_flatten, vars)
out.append(line) if len(op) <= self.opm:
if indent_change > 0: opl = max(opl, len(op)+2)
indent += self.indent_amt * indent_change for var in vars:
if len(var) <= self.vm:
vl = max(vl, len(var)+1)
flat.append((pfx, op, vars, semi))
for pfx, op, vars, semi in flat:
if pfx:
line = ('%%-%ds ' % (indent-1)) % pfx
else:
line = ' '*indent
line = ('%%-%ds ' % (indent+opl)) % (line+op)
for i, var in enumerate(vars):
line = ('%%-%ds ' % (indent+opl+vl*(i+1))) % (line+var)
if semi:
line = line.rstrip() + ';'
out.append(line)
indent = max(0, indent + self.idamt * indent_change)
return '\n'.join(out) return '\n'.join(out)
_TExp = namedtuple('_TExp', 'type exprlist') _TExp = namedtuple('_TExp', 'type exprlist')

15
main.py
View File

@ -21,12 +21,15 @@ from fr0stlib.pyflam3 import *
from fr0stlib.pyflam3._flam3 import * from fr0stlib.pyflam3._flam3 import *
from cuburnlib.render import * from cuburnlib.render import *
def main(genome_path): def main(args):
verbose = 1
if '-d' in args:
verbose = 3
ctx = LaunchContext([MWCRNGTest], block=(256,1,1), grid=(64,1), tests=True) ctx = LaunchContext([MWCRNGTest], block=(256,1,1), grid=(64,1), tests=True)
ctx.compile(verbose=True) ctx.compile(verbose=verbose)
ctx.run_tests() ctx.run_tests()
with open(genome_path) as fp: with open(args[-1]) as fp:
genomes = Genome.from_string(fp.read()) genomes = Genome.from_string(fp.read())
render = Render(genomes) render = Render(genomes)
render.render_frame() render.render_frame()
@ -52,8 +55,8 @@ def main(genome_path):
#pyglet.app.run() #pyglet.app.run()
if __name__ == "__main__": if __name__ == "__main__":
if len(sys.argv) < 2 or not os.path.isfile(sys.argv[1]): if len(sys.argv) < 2 or not os.path.isfile(sys.argv[-1]):
print "First argument must be a path to a genome file" print "Last argument must be a path to a genome file"
sys.exit(1) sys.exit(1)
main(sys.argv[1]) main(sys.argv)