mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 11:40:04 -05:00
Formatter improvements
This commit is contained in:
parent
731c637f80
commit
a23a493d68
@ -43,6 +43,10 @@ class LaunchContext(object):
|
|||||||
def threads(self):
|
def threads(self):
|
||||||
return reduce(lambda a, b: a*b, self.block + self.grid)
|
return reduce(lambda a, b: a*b, self.block + self.grid)
|
||||||
|
|
||||||
|
def print_source(self):
|
||||||
|
print '\n'.join(["%03d %s" % (i+1, l) for (i, l) in
|
||||||
|
enumerate(self.ptx.source.split('\n'))])
|
||||||
|
|
||||||
def compile(self, to_inject={}, verbose=False):
|
def compile(self, to_inject={}, verbose=False):
|
||||||
inj = dict(to_inject)
|
inj = dict(to_inject)
|
||||||
inj['ctx'] = self
|
inj['ctx'] = self
|
||||||
@ -51,10 +55,11 @@ class LaunchContext(object):
|
|||||||
self.mod = cuda.module_from_buffer(self.ptx.source)
|
self.mod = cuda.module_from_buffer(self.ptx.source)
|
||||||
except (cuda.CompileError, cuda.RuntimeError), e:
|
except (cuda.CompileError, cuda.RuntimeError), e:
|
||||||
print "Aww, dang, compile error. Here's the source:"
|
print "Aww, dang, compile error. Here's the source:"
|
||||||
print '\n'.join(["%03d %s" % (i+1, l) for (i, l) in
|
self.print_source()
|
||||||
enumerate(self.ptx.source.split('\n'))])
|
|
||||||
raise e
|
raise e
|
||||||
if verbose:
|
if verbose:
|
||||||
|
if verbose >= 3:
|
||||||
|
self.print_source()
|
||||||
for entry in self.ptx.entries:
|
for entry in self.ptx.entries:
|
||||||
func = self.mod.get_function(entry.entry_name)
|
func = self.mod.get_function(entry.entry_name)
|
||||||
print "Compiled %s: used %d regs, %d sm, %d local" % (
|
print "Compiled %s: used %d regs, %d sm, %d local" % (
|
||||||
|
@ -74,7 +74,7 @@ from collections import namedtuple
|
|||||||
|
|
||||||
def _softjoin(args, sep):
|
def _softjoin(args, sep):
|
||||||
"""Intersperses 'sep' between 'args' without coercing to string."""
|
"""Intersperses 'sep' between 'args' without coercing to string."""
|
||||||
return [(arg, sep) for arg in args[:-1]] + args[-1:]
|
return [[arg, sep] for arg in args[:-1]] + list(args[-1:])
|
||||||
|
|
||||||
BlockCtx = namedtuple('BlockCtx', 'locals code injectors')
|
BlockCtx = namedtuple('BlockCtx', 'locals code injectors')
|
||||||
PTXStmt = namedtuple('PTXStmt', 'prefix op vars semi indent')
|
PTXStmt = namedtuple('PTXStmt', 'prefix op vars semi indent')
|
||||||
@ -292,12 +292,14 @@ class Block(object):
|
|||||||
return self
|
return self
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.block.push_ctx()
|
self.block.push_ctx()
|
||||||
self.block.code(op='{', indent=1, semi=False)
|
self.block.code(op='{', semi=False)
|
||||||
|
self.block.code(indent=1)
|
||||||
if self.comment:
|
if self.comment:
|
||||||
self.block.code(op=['// ', self.comment], semi=False)
|
self.block.code(op=['// ', self.comment], semi=False)
|
||||||
self.comment = None
|
self.comment = None
|
||||||
def __exit__(self, exc_type, exc_value, tb):
|
def __exit__(self, exc_type, exc_value, tb):
|
||||||
self.block.code(op='}', indent=-1, semi=False)
|
self.block.code(indent=-1)
|
||||||
|
self.block.code(op='}', semi=False)
|
||||||
self.block.pop_ctx()
|
self.block.pop_ctx()
|
||||||
|
|
||||||
class _CallChain(object):
|
class _CallChain(object):
|
||||||
@ -343,7 +345,7 @@ class _RegFactory(_CallChain):
|
|||||||
type = type[0]
|
type = type[0]
|
||||||
names = names.split()
|
names = names.split()
|
||||||
regs = map(lambda n: Reg(type, n), names)
|
regs = map(lambda n: Reg(type, n), names)
|
||||||
self.block.code(op='.reg .' + type, vars=_softjoin(names, ', '))
|
self.block.code(op='.reg .' + type, vars=_softjoin(names, ','))
|
||||||
[self.block.inject(r.name, r) for r in regs]
|
[self.block.inject(r.name, r) for r in regs]
|
||||||
|
|
||||||
class Op(_CallChain):
|
class Op(_CallChain):
|
||||||
@ -374,7 +376,7 @@ class Op(_CallChain):
|
|||||||
pred = ['@', kwargs['ifp']]
|
pred = ['@', kwargs['ifp']]
|
||||||
if 'ifnotp' in kwargs:
|
if 'ifnotp' in kwargs:
|
||||||
pred = ['@!', kwargs['ifnotp']]
|
pred = ['@!', kwargs['ifnotp']]
|
||||||
self.block.code(pred, '.'.join(op), _softjoin(args, ', '))
|
self.block.code(pred, '.'.join(op), _softjoin(args, ','))
|
||||||
|
|
||||||
class Mem(object):
|
class Mem(object):
|
||||||
"""
|
"""
|
||||||
@ -408,11 +410,8 @@ class Mem(object):
|
|||||||
>>> op.ld.global.v2.u32(vec(reg1, reg2), addr(areg))
|
>>> op.ld.global.v2.u32(vec(reg1, reg2), addr(areg))
|
||||||
"""
|
"""
|
||||||
assert len(args) >= 2, "vector loads make no sense with < 2 args"
|
assert len(args) >= 2, "vector loads make no sense with < 2 args"
|
||||||
joined = _softjoin(args, ', ')
|
# TODO: fix the way this looks (not so easy)
|
||||||
# This makes a list like [('{', ('arg1', ', ')), ('argf', '}')], which
|
return ['{', _softjoin(args, ','), '}']
|
||||||
# when compacted comes out like ['{arg1,', 'argf}'], which lets the
|
|
||||||
# formatter work properly. There's no semantic value to this.
|
|
||||||
return [('{', joined[0])] + joined[1:-1] + [(joined[-1], '}')]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def addr(areg, aoffset=''):
|
def addr(areg, aoffset=''):
|
||||||
@ -786,33 +785,49 @@ class PTXModule(object):
|
|||||||
raise ValueError("Too many recompiles scheduled!")
|
raise ValueError("Too many recompiles scheduled!")
|
||||||
self.__needs_recompilation = True
|
self.__needs_recompilation = True
|
||||||
|
|
||||||
|
def _flatten(val):
|
||||||
|
if isinstance(val, (list, tuple)):
|
||||||
|
return ''.join(map(_flatten, val))
|
||||||
|
return str(val)
|
||||||
|
|
||||||
class PTXFormatter(object):
|
class PTXFormatter(object):
|
||||||
"""
|
"""
|
||||||
Formats PTXStmt items into beautiful code. Well, the beautiful part is
|
Formats PTXStmt items into beautiful code. Well, the beautiful part is
|
||||||
postponed for now.
|
postponed for now.
|
||||||
"""
|
"""
|
||||||
def __init__(self, indent=4):
|
def __init__(self, indent_amt=2, oplen_max=20, varlen_max=12):
|
||||||
self.indent_amt = 4
|
self.idamt, self.opm, self.vm = indent_amt, oplen_max, varlen_max
|
||||||
def _flatten(self, val):
|
|
||||||
if isinstance(val, (list, tuple)):
|
|
||||||
return ''.join(map(self._flatten, val))
|
|
||||||
return str(val)
|
|
||||||
def format(self, code):
|
def format(self, code):
|
||||||
out = []
|
out = []
|
||||||
indent = 0
|
indent = 0
|
||||||
for (pfx, op, vars, semi, indent_change) in code:
|
idx = 0
|
||||||
pfx = self._flatten(pfx)
|
while idx < len(code):
|
||||||
op = self._flatten(op)
|
opl, vl = 0, 0
|
||||||
vars = map(self._flatten, vars)
|
flat = []
|
||||||
if indent_change < 0:
|
while idx < len(code):
|
||||||
indent = max(0, indent + self.indent_amt * indent_change)
|
pfx, op, vars, semi, indent_change = code[idx]
|
||||||
# TODO: make this a lot prettier
|
idx += 1
|
||||||
line = ((('%%-%ds' % indent) % pfx) + op + ' ' + ''.join(vars))
|
if indent_change: break
|
||||||
|
pfx, op = _flatten(pfx), _flatten(op)
|
||||||
|
vars = map(_flatten, vars)
|
||||||
|
if len(op) <= self.opm:
|
||||||
|
opl = max(opl, len(op)+2)
|
||||||
|
for var in vars:
|
||||||
|
if len(var) <= self.vm:
|
||||||
|
vl = max(vl, len(var)+1)
|
||||||
|
flat.append((pfx, op, vars, semi))
|
||||||
|
for pfx, op, vars, semi in flat:
|
||||||
|
if pfx:
|
||||||
|
line = ('%%-%ds ' % (indent-1)) % pfx
|
||||||
|
else:
|
||||||
|
line = ' '*indent
|
||||||
|
line = ('%%-%ds ' % (indent+opl)) % (line+op)
|
||||||
|
for i, var in enumerate(vars):
|
||||||
|
line = ('%%-%ds ' % (indent+opl+vl*(i+1))) % (line+var)
|
||||||
if semi:
|
if semi:
|
||||||
line = line.rstrip() + ';'
|
line = line.rstrip() + ';'
|
||||||
out.append(line)
|
out.append(line)
|
||||||
if indent_change > 0:
|
indent = max(0, indent + self.idamt * indent_change)
|
||||||
indent += self.indent_amt * indent_change
|
|
||||||
return '\n'.join(out)
|
return '\n'.join(out)
|
||||||
|
|
||||||
_TExp = namedtuple('_TExp', 'type exprlist')
|
_TExp = namedtuple('_TExp', 'type exprlist')
|
||||||
|
15
main.py
15
main.py
@ -21,12 +21,15 @@ from fr0stlib.pyflam3 import *
|
|||||||
from fr0stlib.pyflam3._flam3 import *
|
from fr0stlib.pyflam3._flam3 import *
|
||||||
from cuburnlib.render import *
|
from cuburnlib.render import *
|
||||||
|
|
||||||
def main(genome_path):
|
def main(args):
|
||||||
|
verbose = 1
|
||||||
|
if '-d' in args:
|
||||||
|
verbose = 3
|
||||||
ctx = LaunchContext([MWCRNGTest], block=(256,1,1), grid=(64,1), tests=True)
|
ctx = LaunchContext([MWCRNGTest], block=(256,1,1), grid=(64,1), tests=True)
|
||||||
ctx.compile(verbose=True)
|
ctx.compile(verbose=verbose)
|
||||||
ctx.run_tests()
|
ctx.run_tests()
|
||||||
|
|
||||||
with open(genome_path) as fp:
|
with open(args[-1]) as fp:
|
||||||
genomes = Genome.from_string(fp.read())
|
genomes = Genome.from_string(fp.read())
|
||||||
render = Render(genomes)
|
render = Render(genomes)
|
||||||
render.render_frame()
|
render.render_frame()
|
||||||
@ -52,8 +55,8 @@ def main(genome_path):
|
|||||||
#pyglet.app.run()
|
#pyglet.app.run()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if len(sys.argv) < 2 or not os.path.isfile(sys.argv[1]):
|
if len(sys.argv) < 2 or not os.path.isfile(sys.argv[-1]):
|
||||||
print "First argument must be a path to a genome file"
|
print "Last argument must be a path to a genome file"
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
main(sys.argv[1])
|
main(sys.argv)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user