From a3660ec6e418972afcf84f0041a4217ceabe817e Mon Sep 17 00:00:00 2001 From: Steven Robertson Date: Wed, 1 Sep 2010 21:09:40 -0400 Subject: [PATCH] PTX DSL working, at least well enough to pass MWCRNGTest --- cuburnlib/cuda.py | 21 +- cuburnlib/device_code.py | 33 ++- cuburnlib/ptx.py | 549 ++++++++++++++++++++++++--------------- main.py | 10 +- 4 files changed, 385 insertions(+), 228 deletions(-) diff --git a/cuburnlib/cuda.py b/cuburnlib/cuda.py index c094c42..d681045 100644 --- a/cuburnlib/cuda.py +++ b/cuburnlib/cuda.py @@ -10,7 +10,7 @@ import pycuda.gl.autoinit import numpy as np -from cuburnlib.ptx import PTXAssembler +from cuburnlib.ptx import PTXModule class LaunchContext(object): """ @@ -44,8 +44,10 @@ class LaunchContext(object): def threads(self): return reduce(lambda a, b: a*b, self.block + self.grid) - def compile(self, verbose=False): - self.ptx = PTXAssembler(self, self.entry_types, self.build_tests) + def compile(self, to_inject={}, verbose=False): + inj = dict(to_inject) + inj['ctx'] = self + self.ptx = PTXModule(self.entry_types, inj, self.build_tests) try: self.mod = cuda.module_from_buffer(self.ptx.source) except (cuda.CompileError, cuda.RuntimeError), e: @@ -54,15 +56,16 @@ class LaunchContext(object): enumerate(self.ptx.source.split('\n'))]) raise e if verbose: - for name in self.ptx.entry_names.values(): - func = self.mod.get_function(name) - print "Compiled %s: used %d regs, %d sm, %d local" % (func, - func.num_regs, func.shared_size_bytes, func.local_size_bytes) + for entry in self.ptx.entries: + func = self.mod.get_function(entry.entry_name) + print "Compiled %s: used %d regs, %d sm, %d local" % ( + entry.entry_name, func.num_regs, + func.shared_size_bytes, func.local_size_bytes) def set_up(self): for inst in self.ptx.deporder(self.ptx.instances.values(), - self.ptx.instances, self): - inst.set_up(self) + self.ptx.instances): + inst.device_init(self) def run(self): if not self.setup_done: self.set_up() diff --git a/cuburnlib/device_code.py b/cuburnlib/device_code.py index 1bf21ed..439737b 100644 --- a/cuburnlib/device_code.py +++ b/cuburnlib/device_code.py @@ -8,7 +8,7 @@ import time import pycuda.driver as cuda import numpy as np -from cuburnlib.ptx import PTXFragment, PTXEntryPoint, PTXTest +from cuburnlib.ptx import * """ Here's the current draft of the full algorithm implementation. @@ -113,10 +113,12 @@ class MWCRNG(PTXFragment): if not os.path.isfile('primes.bin'): raise EnvironmentError('primes.bin not found') + @ptx_func def module_setup(self): mem.global_.u32('mwc_rng_mults', ctx.threads) - mem.global_.u32('mwc_rng_state', ctx.threads) + mem.global_.u64('mwc_rng_state', ctx.threads) + @ptx_func def entry_setup(self): reg.u32('mwc_st mwc_mult mwc_car') with block('Load MWC multipliers and states'): @@ -130,6 +132,7 @@ class MWCRNG(PTXFragment): op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr) op.ld.global_.v2.u32(vec(mwc_st, mwc_car), addr(mwc_addr)) + @ptx_func def entry_teardown(self): with block('Save MWC states'): reg.u32('mwc_off mwc_addr') @@ -138,15 +141,19 @@ class MWCRNG(PTXFragment): op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr) op.st.global_.v2.u32(addr(mwc_addr), vec(mwc_st, mwc_car)) + @ptx_func def next_b32(self, dst_reg): with block('Load next random into ' + dst_reg.name): reg.u64('mwc_out') op.cvt.u64.u32(mwc_out, mwc_car) - mad.wide.u32(mwc_out, mwc_st) - mov.b64(vec(mwc_st, mwc_car), mwc_out) - mov.u32(dst_reg, mwc_st) + op.mad.wide.u32(mwc_out, mwc_st, mwc_mult, mwc_out) + op.mov.b64(vec(mwc_st, mwc_car), mwc_out) + op.mov.u32(dst_reg, mwc_st) - def set_up(self, ctx): + def to_inject(self): + return dict(mwc_next_b32=self.next_b32) + + def device_init(self, ctx): if self.threads_ready >= ctx.threads: # Already set up enough random states, don't push again return @@ -168,21 +175,25 @@ class MWCRNG(PTXFragment): states = np.array(ctx.rand.randint(1, 0xffffffff, size=2*ctx.threads), dtype=np.uint32) statedp, statel = ctx.mod.get_global('mwc_rng_state') + print states, len(states.tostring()) cuda.memcpy_htod_async(statedp, states.tostring()) self.threads_ready = ctx.threads - def tests(self, ctx): + def tests(self): return [MWCRNGTest] class MWCRNGTest(PTXTest): name = "MWC RNG sum-of-threads" - deps = [MWCRNG] rounds = 10000 entry_name = 'MWC_RNG_test' entry_params = '' + def deps(self): + return [MWCRNG] + + @ptx_func def module_setup(self): - mem.global_.u64(mwc_rng_test_sums, ctx.threads) + mem.global_.u64('mwc_rng_test_sums', ctx.threads) @ptx_func def entry(self): @@ -191,7 +202,7 @@ class MWCRNGTest(PTXTest): op.mov.u64(sum, 0) with block('Sum next %d random numbers' % self.rounds): reg.u32('loopct') - pred('p') + reg.pred('p') op.mov.u32(loopct, self.rounds) label('loopstart') mwc_next_b32(addend) @@ -206,7 +217,7 @@ class MWCRNGTest(PTXTest): get_gtid(offset) op.mov.u32(adr, mwc_rng_test_sums) op.mad.lo.u32(adr, offset, 8, adr) - st.global_.u64(addr(adr), sum) + op.st.global_.u64(addr(adr), sum) def call(self, ctx): # Get current multipliers and seeds from the device diff --git a/cuburnlib/ptx.py b/cuburnlib/ptx.py index 9deb525..7e67e37 100644 --- a/cuburnlib/ptx.py +++ b/cuburnlib/ptx.py @@ -10,7 +10,8 @@ easier to maintain using this system. # If you see 'import inspect', you know you're in for a good time import inspect -import ctypes +import types +import traceback from collections import namedtuple # Okay, so here's what's going on. @@ -23,7 +24,7 @@ from collections import namedtuple # splitting things up at the level of PTX will greatly reduce performance, as # the cost of accessing the stack, spilling registers, and reloading data from # system memory is unacceptably high even on Fermi GPUs. So we want to split -# code up into functions within Python, but not within the PTX. +# code up into functions within Python, but not within the PTX source. # # The challenge here is variable lifetime. A PTX function might declare a # register at the top of the main block and use it several times throughout the @@ -50,10 +51,10 @@ from collections import namedtuple # reg.u32('hooray_reg') # load_zero(hooray_reg) # -# But using blocks to track state, it would turn in to this ugliness:: +# But using blocks alone to track names, it would turn in to this ugliness:: # # def load_zero(block, dest_reg): -# block.op.mov.u32(op.dest_reg, 0) +# block.op.mov.u32(block.op.dest_reg, 0) # def init_module(): # with Block() as block: # block.regs.hooray_reg = block.reg.u32('hooray_reg') @@ -70,6 +71,9 @@ from collections import namedtuple # below give a clear picture of how to use it, but now you know why this # abomination was crafted to begin with. +def _softjoin(args, sep): + """Intersperses 'sep' between 'args' without coercing to string.""" + return reduce(lambda l, x: l + [x, sep], args, [])[:-1] BlockCtx = namedtuple('BlockCtx', 'locals code injectors') PTXStmt = namedtuple('PTXStmt', 'prefix op vars semi indent') @@ -100,6 +104,13 @@ class _BlockInjector(object): else: self.inject_into[k] = v self.injected.add(k) + def pop(self, keys): + """Remove keys from a dictionary, as long as we added them.""" + assert not self.dead + for k in keys: + if k in self.injected: + self.inject_into.pop(k) + self.injected.remove(k) def __enter__(self): self.dead = False map(self.inject, self.to_inject.items()) @@ -115,24 +126,49 @@ class _Block(object): For important reasons, the instance must be bound locally as "_block". """ - name = '_block' + name = '_block' # For retrieving from parent scope on first call def __init__(self): + self.reset() + def reset(self): self.outer_ctx = BlockCtx({self.name: self}, [], []) self.stack = [self.outer_ctx] + def clean_injectors(self): + inj = self.stack[-1].injectors + [inj.remove(i) for i in inj if i.dead] def push_ctx(self): - self.stack.append(BlockCtx(dict(self.stack[-1].locals), [], [])) + # Move most recent active injector to new context + self.clean_injectors() + last_inj = self.stack[-1].injectors.pop() + self.stack.append(BlockCtx(dict(self.stack[-1].locals), [], + [last_inj])) def pop_ctx(self): + self.clean_injectors() bs = self.stack.pop() - self.stack[-1].code.append(bs.code) + self.stack[-1].code.extend(bs.code) + if len(self.stack) == 1: + # We're on outer_ctx, so all injectors should be gone + assert len(bs.injectors) == 0, "Injector/context mismatch" + return + # The only injector should be the one added in push_ctx + assert len(bs.injectors) == 1, "Injector/context mismatch" + # Find out which keys were injected while in this context + diff = set(bs.locals.keys()).difference( + set(self.stack[-1].locals.keys())) + # Pop keys and move current injector back down to last context + last_inj = bs.injectors.pop() + last_inj.pop(diff) + self.stack[-1].injectors.append(last_inj) def injector(self, func_globals): - inj = BlockInjector(self.stack[-1].locals, func_globals) + inj = _BlockInjector(dict(self.stack[-1].locals), func_globals) self.stack[-1].injectors.append(inj) return inj def inject(self, name, object): if name in self.stack[-1].locals: - raise KeyError("Duplicate name already exists in this scope.") - self.stack[-1].locals[name] = object - [inj.inject(name, object) for inj in self.stack[-1].injectors] + if self.stack[-1].locals[name] is not object: + raise KeyError("'%s' already exists in this scope." % name) + else: + self.stack[-1].locals[name] = object + [inj.inject(name, object) for inj in self.stack[-1].injectors] def code(self, prefix='', op='', vars=[], semi=True, indent=0): """ Append a PTX statement (or thereabouts) to the current block. @@ -157,7 +193,7 @@ class _Block(object): yes, the only real difference between `prefix`, `op`, and `vars` is in final appearance, but it is in fact quite helpful for debugging. """ - self.stack[-1].append(PTXStmt(prefix, op, vars, indent)) + self.stack[-1].code.append(PTXStmt(prefix, op, vars, semi, indent)) class StrVar(object): """ @@ -168,28 +204,50 @@ class StrVar(object): def __str__(self): return str(val) +class _PTXFuncWrapper(object): + """Enables ptx_func""" + def __init__(self, func): + self.func = func + def __call__(self, *args, **kwargs): + if _Block.name in globals(): + block = globals()['block'] + else: + # Find the '_block' from the enclosing scope + parent = inspect.stack()[2][0] + if _Block.name in parent.f_locals: + block = parent.f_locals[_Block.name] + elif _Block.name in parent.f_globals: + block = parent.f_globals[_Block.name] + else: + # Couldn't find the _block instance. Fail cryptically to + # encourage users to read the source (for now) + raise SyntaxError("Black magic") + # Create a new function with the modified scope and call it. We could + # do this in __init__, but it would hide any changes to globals from + # the module's original scope. Still an option if performance sucks. + newglobals = dict(self.func.func_globals) + func = types.FunctionType(self.func.func_code, newglobals, + self.func.func_name, self.func.func_defaults, + self.func.func_closure) + # TODO: if we generate a new dict every time, we can kill the + # _BlockInjector and move BI.inject() back to _Block, but I don't want + # to delete working code just yet + with block.injector(func.func_globals): + func(*args, **kwargs) + def ptx_func(func): """ Decorator function for code in the DSL. Any function which accesses the DSL namespace, including declared device variables and objects such as "reg" or "op", should be wrapped with this. See Block for some examples. + + Note that writes to global variables will silently fail for now. """ - def ptx_eval(*args, **kwargs): - if self.name not in globals(): - parent = inspect.stack()[-2][0] - if self.name in parent.f_locals: - block = parent.f_locals[self.name] - elif self.name in parent.f_globals: - block = parent.f_globals[self.name] - else: - # Couldn't find the _block instance. Fail cryptically to - # encourage users to read the source (for now) - raise SyntaxError("Black magic") - else: - block = globals()['block'] - with block.injector(func.func_globals): - func(*args, **kwargs) - return ptx_eval + # Attach most of the code to the wrapper class + fw = _PTXFuncWrapper(func) + def wr(*args, **kwargs): + fw(*args, **kwargs) + return wr class Block(object): """ @@ -231,14 +289,17 @@ class Block(object): # `block` is the real _block self.block = block self.comment = None - def __call__(self, comment=None) + def __call__(self, comment=None): self.comment = comment return self def __enter__(self): self.block.push_ctx() - self.block.code(op='{', indent=4) + self.block.code(op='{', indent=1, semi=False) + if self.comment: + self.block.code(op=['// ', self.comment], semi=False) + self.comment = None def __exit__(self, exc_type, exc_value, tb): - self.block.code(op='}', indent=-4) + self.block.code(op='}', indent=-1, semi=False) self.block.pop_ctx() class _CallChain(object): @@ -248,12 +309,12 @@ class _CallChain(object): self.__chain = [] def __call__(self, *args, **kwargs): assert(self.__chain) - self._call(chain, *args, **kwargs) + self._call(self.__chain, *args, **kwargs) self.__chain = [] def __getattr__(self, name): if name == 'global_': name = 'global' - self.chain.append(name) + self.__chain.append(name) # Another great crime against the universe: return self @@ -284,7 +345,7 @@ class _RegFactory(_CallChain): type = type[0] names = names.split() regs = map(lambda n: Reg(type, n), names) - self.block.code(op='.reg .' + type, vars=names) + self.block.code(op='.reg .' + type, vars=_softjoin(names, ', ')) [self.block.inject(r.name, r) for r in regs] # Pending resolution of the op(regs, guard=x) debate @@ -318,8 +379,6 @@ class _RegFactory(_CallChain): #self.name = name #def is_set(self, isnot=False): - - class Op(_CallChain): """ Performs an operation. @@ -340,15 +399,15 @@ class Op(_CallChain): This constructor is available as 'op' in DSL blocks. """ - def _call(self, op, *args, ifp=None, ifnotp=None): + def _call(self, op, *args, **kwargs): pred = '' - if ifp: - if ifnotp: + if 'ifp' in kwargs: + if 'ifnotp' in kwargs: raise SyntaxError("can't use both, fool") - pred = ['@', ifp] - if ifnotp: - pred = ['@!', ifnotp] - self.block.append_code(pred, '.'.join(op), map(str, args)) + pred = ['@', kwargs['ifp']] + if 'ifnotp' in kwargs: + pred = ['@!', kwargs['ifnotp']] + self.block.code(pred, '.'.join(op), _softjoin(args, ', ')) class Mem(object): """ @@ -381,7 +440,7 @@ class Mem(object): >>> op.ld.global.v2.u32(vec(reg1, reg2), addr(areg)) """ - return ['{', [(a, ', ') for a in args][:-1], '}'] + return ['{', _softjoin(args, ', '), '}'] @staticmethod def addr(areg, aoffset=''): @@ -397,8 +456,7 @@ class _MemFactory(_CallChain): """Actual `mem` object""" def _call(self, type, name, array=False, initializer=None): assert len(type) == 2 - memobj = Mem(type, name, array) - self.dsl.inject(name, memobj) + memobj = Mem(type, name, array, initializer) if array is True: array = ['[]'] elif array: @@ -407,11 +465,12 @@ class _MemFactory(_CallChain): array = [] if initializer: array += [' = ', initializer] - self.block.code(op=['.%s.%s ' % type, name, array]) + self.block.code(op=['.%s.%s ' % (type[0], type[1]), name, array]) + self.block.inject(name, memobj) class Label(object): """ - Specifies the target for a branch. Scoped in PTX? TODO: test. + Specifies the target for a branch. Scoped in PTX? TODO: test that it is. >>> label('infinite_loop') >>> op.bra.uni('label') @@ -426,25 +485,7 @@ class _LabelFactory(object): self.block = block def __call__(self, name): self.block.inject(name, Label(name)) - -class PTXFragment(object): - def module_setup(self): - pass - - def entry_setup(self): - pass - - def entry_teardown(self): - pass - - def globals(self): - pass - - def tests(self): - pass - - def device_init(self, ctx): - pass + self.block.code(prefix='%s:' % name, semi=False) class PTXFragment(object): """ @@ -470,12 +511,14 @@ class PTXFragment(object): for successful compilation. Circular dependencies are forbidden, but multi-level dependencies should be fine. """ - return [DeviceHelpers] + return [_PTXStdLib] - def inject(self): + def to_inject(self): """ Returns a dict of items to add to the DSL namespace. The namespace will be assembled in dependency order before any ptx_funcs are called. + + This is only called once per PTXModule (== once per instance). """ return {} @@ -483,8 +526,8 @@ class PTXFragment(object): """ PTX function to declare things at module scope. It's a PTX syntax error to perform operations at this scope, but we don't yet validate that at - the Python level. A module will call this function on all fragments in - dependency order. + the Python level. A module will call this function on all fragments + used in that module in dependency order. If implemented, this function should use an @ptx_func decorator. """ @@ -513,131 +556,47 @@ class PTXFragment(object): """ pass - def tests(self, ctx): + def finalize_code(self): """ - Returns a list of PTXTest classes which will test this fragment. + Called after running all PTX DSL functions, but before code generation, + to allow fragments which postponed variable evaluation (e.g. using + `StrVar`) to fill in the resulting values. Most fragments should not + use this. + + If implemented, this function *may* use an @ptx_func decorator to + access the global DSL scope, but pretty please don't emit any code + while you're in there. + """ + pass + + def tests(self): + """ + Returns a list of PTXTest types which will test this fragment. """ return [] - def set_up(self, ctx): + def device_init(self, ctx): """ - Do start-of-stream initialization, such as copying data to the device. + Do stuff on the host to prepare the device for execution. 'ctx' is a + LaunchContext or similar. This will get called (in dependency order, of + course) *either* before any entry point invocation, or before *each* + invocation, I'm not sure which yet. (For now it's "each".) """ pass -class PTXModule(object): - """ - Assembles PTX fragments into a module. - """ - - def __init__(self, entries, inject={}, build_tests=False): - self._block = b = _Block() - self.initial_inject = dict(inject) - self._safeupdate(self.initial_inject, dict(block=Block(b), - mem=_MemFactory(b), reg=_RegFactory(b), op=Op(b), - label=_LabelFactory(b), _block=b) - self.needs_recompilation = True - self.max_compiles = 10 - while self.needs_recompilation: - self.assemble(entries, build_tests) - self.max_compiles -= 1 - - def deporder(self, unsorted_instances, instance_map, ctx): - """ - Do a DFS on PTXFragment dependencies, and return an ordered list of - instances where no fragment depends on any before it in the list. - - `unsorted_instances` is the list of instances to sort. - `instance_map` is a dict of types to instances. - """ - seen = {} - def rec(inst): - if inst in seen: return seen[inst] - deps = filter(lambda d: d is not inst, map(instance_map.get, - callable(inst.deps) and inst.deps(self) or inst.deps)) - return seen.setdefault(inst, 1+max([0]+map(rec, deps))) - map(rec, unsorted_instances) - return sorted(unsorted_instances, key=seen.get) - - def _safeupdate(self, dst, src): - """dst.update(src), but no duplicates allowed""" - non_uniq = [k for k in src if k in dst] - if non_uniq: raise KeyError("Duplicate keys: %s" % ','.join(key)) - dst.update(src) - - def assemble(self, entries, build_tests): - """ - Build the PTX source for the given set of entries. - """ - # Get a property, dealing with the callable-or-data thing. This is - # cumbersome, but flexible; when finished, it may be simplified. - def pget(prop): - if callable(prop): return prop(ctx) - return prop - - instances = {} - unvisited_entries = list(entries) - entry_names = {} - tests = [] - parsed_entries = [] - while unvisited_entries: - ent = unvisited_entries.pop(0) - seen, unvisited = set(), [ent] - while unvisited: - frag = unvisited.pop(0) - seen.add(frag) - inst = instances.setdefault(frag, frag()) - for dep in pget(inst.deps): - if dep not in seen: - unvisited.append(dep) - if build_tests: - for test in pget(inst.tests): - if test not in tests: - if test not in instances: - unvisited_entries.append(test) - tests.append(test) - - tmpl_namespace = {'ctx': ctx} - entry_start, entry_end = [], [] - for inst in self.deporder(map(instances.get, seen), instances, ctx): - self._safeupdate(tmpl_namespace, pget(inst.subs)) - entry_start.append(pget(inst.entry_start)) - entry_end.append(pget(inst.entry_end)) - entry_start_tmpl = '\n'.join(filter(None, entry_start)) - entry_end_tmpl = '\n'.join(filter(None, reversed(entry_end))) - name, args, body = pget(instances[ent].entry) - tmpl_namespace.update({'_entry_name_': name, '_entry_args_': args, - '_entry_body_': body, '_entry_start_': entry_start_tmpl, - '_entry_end_': entry_end_tmpl}) - - entry_tmpl = (".entry {{ _entry_name_ }} ({{ _entry_args_ }})\n" - "{\n{{_entry_start_}}\n{{_entry_body_}}\n{{_entry_end_}}\n}\n") - parsed_entries.append(multisub(entry_tmpl, tmpl_namespace)) - entry_names[ent] = name - - prelude = [] - tmpl_namespace = {'ctx': ctx} - for inst in self.deporder(instances.values(), instances, ctx): - prelude.append(pget(inst.prelude)) - self._safeupdate(tmpl_namespace, pget(inst.subs)) - tmpl_namespace['_prelude_'] = '\n'.join(filter(None, prelude)) - tmpl_namespace['_entries_'] = '\n\n'.join(parsed_entries) - tmpl = "{{ _prelude_ }}\n{{ _entries_ }}" - - self.entry_names = entry_names - self.source = ppr_ptx(multisub(tmpl, tmpl_namespace)) - self.instances = instances - self.tests = tests - - class PTXEntryPoint(PTXFragment): # Human-readable entry point name name = "" + # Device code entry name + entry_name = "" + # List of (type, name) pairs for entry params, e.g. [('u32', 'thing')] + entry_params = [] def entry(self, ctx): """ - Returns a 3-tuple of (name, args, body), which will be assembled into - a function. + PTX DSL function that comprises the body of the PTX statement. + + Must be implemented and decorated with ptx_func. """ raise NotImplementedError @@ -660,32 +619,216 @@ class PTXTest(PTXEntryPoint): """ pass -class DeviceHelpers(PTXFragment): - def __init__(self): - self._forstack = [] +class _PTXStdLib(PTXFragment): + def __init__(self, block): + # Only module that gets the privilege of seeing 'block' directly. + self.block = block - prelude = ".version 2.1\n.target sm_20\n\n" + def deps(self): + return [] + @ptx_func + def module_setup(self): + # TODO: make this modular, maybe? of course, we'd have to support + # multiple devices first, which we definitely do not yet do + self.block.code(prefix='.version 2.1', semi=False) + self.block.code(prefix='.target sm_20', semi=False) + + @ptx_func def _get_gtid(self, dst): - return "{\n// Load GTID into " + dst + """ - .reg .u16 tmp; - .reg .u32 cta, ncta, tid, gtid; + with block("Load GTID into %s" % str(dst)): + reg.u16('tmp') + reg.u32('cta ncta tid gtid') - mov.u16 tmp, %ctaid.x; - cvt.u32.u16 cta, tmp; - mov.u16 tmp, %ntid.x; - cvt.u32.u16 ncta, tmp; - mul.lo.u32 gtid, cta, ncta; + op.mov.u16(tmp, '%ctaid.x') + op.cvt.u32.u16(cta, tmp) + op.mov.u16(tmp, '%ntid.x') + op.cvt.u32.u16(ncta, tmp) + op.mul.lo.u32(gtid, cta, ncta) - mov.u16 tmp, %tid.x; - cvt.u32.u16 tid, tmp; - add.u32 gtid, gtid, tid; - mov.b32 """ + dst + ", gtid;\n}" + op.mov.u16(tmp, '%tid.x') + op.cvt.u32.u16(tid, tmp) + op.add.u32(gtid, gtid, tid) + op.mov.b32(dst, gtid) - def subs(self, ctx): - return { - 'PTRT': ctypes.sizeof(ctypes.c_void_p) == 8 and '.u64' or '.u32', - 'get_gtid': self._get_gtid - } + def to_inject(self): + return dict( + _block=self.block, + block=Block(self.block), + op=Op(self.block), + reg=_RegFactory(self.block), + mem=_MemFactory(self.block), + addr=Mem.addr, + vec=Mem.vec, + label=_LabelFactory(self.block), + get_gtid=self._get_gtid) + +class PTXModule(object): + """ + Assembles PTX fragments into a module. The following properties are + available: + + `instances`: Mapping of type to instance for the PTXFragments used in + the creation of this PTXModule. + `entries`: List of PTXEntry types in this module, including any tests. + `tests`: List of PTXTest types in this module. + `source`: PTX source code for this module. + """ + max_compiles = 10 + + def __init__(self, entries, inject={}, build_tests=False, formatter=None): + """ + Construct a PTXModule. + + `entries`: List of PTXEntry types to include in this module. + `inject`: Dict of items to inject into the DSL namespace. + `build_tests`: If true, build tests into the module. + `formatter`: PTXFormatter instance, or None to use defaults. + """ + block = _Block() + insts, tests, all_deps, entry_deps = ( + self.deptrace(block, entries, build_tests)) + self.instances = insts + self.tests = tests + + inject = dict(inject) + self._safeupdate(inject, {'module': self}) + for inst in all_deps: + self._safeupdate(inject, inst.to_inject()) + [block.inject(k, v) for k, v in inject.items()] + + self.__needs_recompilation = True + self.compiles = 0 + while self.__needs_recompilation: + self.compiles += 1 + self.__needs_recompilation = False + self.assemble(block, all_deps, entry_deps) + self.instances.pop(_PTXStdLib) + print self.instances + + if not formatter: + formatter = PTXFormatter() + self.source = formatter.format(block.outer_ctx.code) + self.entries = list(set(entries + tests)) + + def deporder(self, unsorted_instances, instance_map): + """ + Do a DFS on PTXFragment dependencies, and return an ordered list of + instances where no fragment depends on any before it in the list. + + `unsorted_instances` is the list of instances to sort. + `instance_map` is a dict of types to instances. + """ + seen = {} + def rec(inst): + if inst in seen: return seen[inst] + if inst is None: return 0 + deps = filter(lambda d: d is not inst, + map(instance_map.get, inst.deps())) + return seen.setdefault(inst, 1+max([0]+map(rec, deps))) + map(rec, unsorted_instances) + return sorted(unsorted_instances, key=seen.get) + + def _safeupdate(self, dst, src): + """dst.update(src), but no duplicates allowed""" + non_uniq = [k for k in src if k in dst] + if non_uniq: raise KeyError("Duplicate keys: %s" % ','.join(key)) + dst.update(src) + + def deptrace(self, block, entries, build_tests): + instances = {_PTXStdLib: _PTXStdLib(block)} + unvisited_entries = list(entries) + tests = set() + entry_deps = {} + + # For each PTXEntry or PTXTest, use a BFS to recursively find and + # instantiate all fragments that are dependencies. If tests are + # discovered, add those to the list of entries. + while unvisited_entries: + ent = unvisited_entries.pop(0) + seen, unvisited = set(), [ent] + while unvisited: + frag = unvisited.pop(0) + seen.add(frag) + # setdefault doesn't work because of _PTXStdLib + if frag not in instances: + inst = frag() + instances[frag] = inst + else: + inst = instances[frag] + for dep in inst.deps(): + if dep not in seen: + unvisited.append(dep) + if build_tests: + for test in inst.tests(): + if test not in tests: + tests.add(test) + if test not in instances: + unvisisted_entries.append(tests) + # For this entry, store insts of all dependencies in order. + entry_deps[ent] = self.deporder(map(instances.get, seen), + instances) + # Find the order for all dependencies in the program. + all_deps = self.deporder(instances.values(), instances) + + return instances, sorted(tests, key=str), all_deps, entry_deps + + def assemble(self, block, all_deps, entry_deps): + # Rebind to local namespace to allow proper retrieval + _block = block + for inst in all_deps: + inst.module_setup() + + for ent, insts in entry_deps.items(): + # This is kind of hackish compared to everything else + params = [Reg('.param.' + str(type), name) + for (type, name) in ent.entry_params] + _block.code(op='.entry %s ' % ent.entry_name, semi=False, + vars=['(', ['%s %s' % (r.type, r.name) for r in params], ')']) + with Block(_block): + [_block.inject(r.name, r) for r in params] + for dep in insts: + dep.entry_setup() + self.instances[ent].entry() + for dep in reversed(insts): + dep.entry_teardown() + + for inst in all_deps: + inst.finalize_code() + + def set_needs_recompilation(self): + if not self.__needs_recompilation: + if self.compiles >= self.max_compiles: + raise ValueError("Too many recompiles scheduled!") + self.__needs_recompilation = True + +class PTXFormatter(object): + """ + Formats PTXStmt items into beautiful code. Well, the beautiful part is + postponed for now. + """ + def __init__(self, indent=4): + self.indent_amt = 4 + def _flatten(self, val): + if isinstance(val, (list, tuple)): + return ''.join(map(self._flatten, val)) + return str(val) + def format(self, code): + out = [] + indent = 0 + for (pfx, op, vars, semi, indent_change) in code: + pfx = self._flatten(pfx) + op = self._flatten(op) + vars = map(self._flatten, vars) + if indent_change < 0: + indent = max(0, indent + self.indent_amt * indent_change) + # TODO: make this a lot prettier + line = ((('%%-%ds' % indent) % pfx) + op + ' ' + ''.join(vars)) + if semi: + line = line.rstrip() + ';' + out.append(line) + if indent_change > 0: + indent += self.indent_amt * indent_change + return '\n'.join(out) diff --git a/main.py b/main.py index 8bc8636..9ec52f1 100644 --- a/main.py +++ b/main.py @@ -15,16 +15,16 @@ from ctypes import * import numpy as np -#from cuburnlib.device_code import MWCRNGTest -#from cuburnlib.cuda import LaunchContext +from cuburnlib.device_code import MWCRNGTest +from cuburnlib.cuda import LaunchContext from fr0stlib.pyflam3 import * from fr0stlib.pyflam3._flam3 import * from cuburnlib.render import * def main(genome_path): - #ctx = LaunchContext([MWCRNGTest], block=(256,1,1), grid=(64,1), tests=True) - #ctx.compile(True) - #ctx.run_tests() + ctx = LaunchContext([MWCRNGTest], block=(256,1,1), grid=(64,1), tests=True) + ctx.compile(verbose=True) + ctx.run_tests() with open(genome_path) as fp: genomes = Genome.from_string(fp.read())