mirror of
				https://github.com/stevenrobertson/cuburn.git
				synced 2025-11-03 18:00:55 -05:00 
			
		
		
		
	Formatter improvements
This commit is contained in:
		@ -43,6 +43,10 @@ class LaunchContext(object):
 | 
			
		||||
    def threads(self):
 | 
			
		||||
        return reduce(lambda a, b: a*b, self.block + self.grid)
 | 
			
		||||
 | 
			
		||||
    def print_source(self):
 | 
			
		||||
        print '\n'.join(["%03d %s" % (i+1, l) for (i, l) in
 | 
			
		||||
                        enumerate(self.ptx.source.split('\n'))])
 | 
			
		||||
 | 
			
		||||
    def compile(self, to_inject={}, verbose=False):
 | 
			
		||||
        inj = dict(to_inject)
 | 
			
		||||
        inj['ctx'] = self
 | 
			
		||||
@ -51,10 +55,11 @@ class LaunchContext(object):
 | 
			
		||||
            self.mod = cuda.module_from_buffer(self.ptx.source)
 | 
			
		||||
        except (cuda.CompileError, cuda.RuntimeError), e:
 | 
			
		||||
            print "Aww, dang, compile error. Here's the source:"
 | 
			
		||||
            print '\n'.join(["%03d %s" % (i+1, l) for (i, l) in
 | 
			
		||||
                            enumerate(self.ptx.source.split('\n'))])
 | 
			
		||||
            self.print_source()
 | 
			
		||||
            raise e
 | 
			
		||||
        if verbose:
 | 
			
		||||
            if verbose >= 3:
 | 
			
		||||
                self.print_source()
 | 
			
		||||
            for entry in self.ptx.entries:
 | 
			
		||||
                func = self.mod.get_function(entry.entry_name)
 | 
			
		||||
                print "Compiled %s: used %d regs, %d sm, %d local" % (
 | 
			
		||||
 | 
			
		||||
@ -74,7 +74,7 @@ from collections import namedtuple
 | 
			
		||||
 | 
			
		||||
def _softjoin(args, sep):
 | 
			
		||||
    """Intersperses 'sep' between 'args' without coercing to string."""
 | 
			
		||||
    return [(arg, sep) for arg in args[:-1]] + args[-1:]
 | 
			
		||||
    return [[arg, sep] for arg in args[:-1]] + list(args[-1:])
 | 
			
		||||
 | 
			
		||||
BlockCtx = namedtuple('BlockCtx', 'locals code injectors')
 | 
			
		||||
PTXStmt = namedtuple('PTXStmt', 'prefix op vars semi indent')
 | 
			
		||||
@ -292,12 +292,14 @@ class Block(object):
 | 
			
		||||
        return self
 | 
			
		||||
    def __enter__(self):
 | 
			
		||||
        self.block.push_ctx()
 | 
			
		||||
        self.block.code(op='{', indent=1, semi=False)
 | 
			
		||||
        self.block.code(op='{', semi=False)
 | 
			
		||||
        self.block.code(indent=1)
 | 
			
		||||
        if self.comment:
 | 
			
		||||
            self.block.code(op=['// ', self.comment], semi=False)
 | 
			
		||||
        self.comment = None
 | 
			
		||||
    def __exit__(self, exc_type, exc_value, tb):
 | 
			
		||||
        self.block.code(op='}', indent=-1, semi=False)
 | 
			
		||||
        self.block.code(indent=-1)
 | 
			
		||||
        self.block.code(op='}', semi=False)
 | 
			
		||||
        self.block.pop_ctx()
 | 
			
		||||
 | 
			
		||||
class _CallChain(object):
 | 
			
		||||
@ -343,7 +345,7 @@ class _RegFactory(_CallChain):
 | 
			
		||||
        type = type[0]
 | 
			
		||||
        names = names.split()
 | 
			
		||||
        regs = map(lambda n: Reg(type, n), names)
 | 
			
		||||
        self.block.code(op='.reg .' + type, vars=_softjoin(names, ', '))
 | 
			
		||||
        self.block.code(op='.reg .' + type, vars=_softjoin(names, ','))
 | 
			
		||||
        [self.block.inject(r.name, r) for r in regs]
 | 
			
		||||
 | 
			
		||||
class Op(_CallChain):
 | 
			
		||||
@ -374,7 +376,7 @@ class Op(_CallChain):
 | 
			
		||||
            pred = ['@', kwargs['ifp']]
 | 
			
		||||
        if 'ifnotp' in kwargs:
 | 
			
		||||
            pred = ['@!', kwargs['ifnotp']]
 | 
			
		||||
        self.block.code(pred, '.'.join(op), _softjoin(args, ', '))
 | 
			
		||||
        self.block.code(pred, '.'.join(op), _softjoin(args, ','))
 | 
			
		||||
 | 
			
		||||
class Mem(object):
 | 
			
		||||
    """
 | 
			
		||||
@ -408,11 +410,8 @@ class Mem(object):
 | 
			
		||||
        >>> op.ld.global.v2.u32(vec(reg1, reg2), addr(areg))
 | 
			
		||||
        """
 | 
			
		||||
        assert len(args) >= 2, "vector loads make no sense with < 2 args"
 | 
			
		||||
        joined = _softjoin(args, ', ')
 | 
			
		||||
        # This makes a list like [('{', ('arg1', ', ')), ('argf', '}')], which
 | 
			
		||||
        # when compacted comes out like ['{arg1,', 'argf}'], which lets the
 | 
			
		||||
        # formatter work properly. There's no semantic value to this.
 | 
			
		||||
        return [('{', joined[0])] + joined[1:-1] + [(joined[-1], '}')]
 | 
			
		||||
        # TODO: fix the way this looks (not so easy)
 | 
			
		||||
        return ['{', _softjoin(args, ','), '}']
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def addr(areg, aoffset=''):
 | 
			
		||||
@ -786,33 +785,49 @@ class PTXModule(object):
 | 
			
		||||
                raise ValueError("Too many recompiles scheduled!")
 | 
			
		||||
            self.__needs_recompilation = True
 | 
			
		||||
 | 
			
		||||
def _flatten(val):
 | 
			
		||||
    if isinstance(val, (list, tuple)):
 | 
			
		||||
        return ''.join(map(_flatten, val))
 | 
			
		||||
    return str(val)
 | 
			
		||||
 | 
			
		||||
class PTXFormatter(object):
 | 
			
		||||
    """
 | 
			
		||||
    Formats PTXStmt items into beautiful code. Well, the beautiful part is
 | 
			
		||||
    postponed for now.
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self, indent=4):
 | 
			
		||||
        self.indent_amt = 4
 | 
			
		||||
    def _flatten(self, val):
 | 
			
		||||
        if isinstance(val, (list, tuple)):
 | 
			
		||||
            return ''.join(map(self._flatten, val))
 | 
			
		||||
        return str(val)
 | 
			
		||||
    def __init__(self, indent_amt=2, oplen_max=20, varlen_max=12):
 | 
			
		||||
        self.idamt, self.opm, self.vm = indent_amt, oplen_max, varlen_max
 | 
			
		||||
    def format(self, code):
 | 
			
		||||
        out = []
 | 
			
		||||
        indent = 0
 | 
			
		||||
        for (pfx, op, vars, semi, indent_change) in code:
 | 
			
		||||
            pfx = self._flatten(pfx)
 | 
			
		||||
            op = self._flatten(op)
 | 
			
		||||
            vars = map(self._flatten, vars)
 | 
			
		||||
            if indent_change < 0:
 | 
			
		||||
                indent = max(0, indent + self.indent_amt * indent_change)
 | 
			
		||||
            # TODO: make this a lot prettier
 | 
			
		||||
            line = ((('%%-%ds' % indent) % pfx) + op + ' ' + ''.join(vars))
 | 
			
		||||
            if semi:
 | 
			
		||||
                line = line.rstrip() + ';'
 | 
			
		||||
            out.append(line)
 | 
			
		||||
            if indent_change > 0:
 | 
			
		||||
                indent += self.indent_amt * indent_change
 | 
			
		||||
        idx = 0
 | 
			
		||||
        while idx < len(code):
 | 
			
		||||
            opl, vl = 0, 0
 | 
			
		||||
            flat = []
 | 
			
		||||
            while idx < len(code):
 | 
			
		||||
                pfx, op, vars, semi, indent_change = code[idx]
 | 
			
		||||
                idx += 1
 | 
			
		||||
                if indent_change: break
 | 
			
		||||
                pfx, op = _flatten(pfx), _flatten(op)
 | 
			
		||||
                vars = map(_flatten, vars)
 | 
			
		||||
                if len(op) <= self.opm:
 | 
			
		||||
                    opl = max(opl, len(op)+2)
 | 
			
		||||
                for var in vars:
 | 
			
		||||
                    if len(var) <= self.vm:
 | 
			
		||||
                        vl = max(vl, len(var)+1)
 | 
			
		||||
                flat.append((pfx, op, vars, semi))
 | 
			
		||||
            for pfx, op, vars, semi in flat:
 | 
			
		||||
                if pfx:
 | 
			
		||||
                    line = ('%%-%ds ' % (indent-1)) % pfx
 | 
			
		||||
                else:
 | 
			
		||||
                    line = ' '*indent
 | 
			
		||||
                line = ('%%-%ds ' % (indent+opl)) % (line+op)
 | 
			
		||||
                for i, var in enumerate(vars):
 | 
			
		||||
                    line = ('%%-%ds ' % (indent+opl+vl*(i+1))) % (line+var)
 | 
			
		||||
                if semi:
 | 
			
		||||
                    line = line.rstrip() + ';'
 | 
			
		||||
                out.append(line)
 | 
			
		||||
            indent = max(0, indent + self.idamt * indent_change)
 | 
			
		||||
        return '\n'.join(out)
 | 
			
		||||
 | 
			
		||||
_TExp = namedtuple('_TExp', 'type exprlist')
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user