diff --git a/cuburnlib/device_code.py b/cuburnlib/device_code.py index 321f0fc..1bf21ed 100644 --- a/cuburnlib/device_code.py +++ b/cuburnlib/device_code.py @@ -1,3 +1,7 @@ +""" +Contains the PTX fragments which will drive the device. +""" + import os import time @@ -6,60 +10,147 @@ import numpy as np from cuburnlib.ptx import PTXFragment, PTXEntryPoint, PTXTest +""" +Here's the current draft of the full algorithm implementation. + +declare xform jump table + +load random state + +clear x_coord, y_coord, z_coord, w_coord; +store -(FUSE+1) to shared (per-warp) num_samples_sh +clear badvals [1] + +load param (global_cp_idx_addr) +index table start (global_cp_idx) [2] +load count of indexes from global cp index => + store to qlocal current_cp_num [3] + +outermost loop start: + load current_cp_num + if current_cp_num <= 0: + exit + + load param global_cp_idx_addr + calculate offset into address with current_cp_num, global_cp_idx_addr + load cp_base_address + stream_start (cp_base, cp_base_addr) [4] + +FUSE_START: + num_samples += 1 + if num_samples >= 0: + # Okay, we're done FUSEing, prepare to enter normal loop + load num_samples => store to shared (per-warp) num_samples + + +ITER_LOOP_START: + reg xform_addr, xform_stream_addr, xform_select + + mwc_next_u32 to xform_select + # Performance test: roll/unroll this loop? + stream_load xform_prob (cp_stream) + if xform_select <= xform_prob: + bra.uni XFORM_1_LBL + ... + stream_load xform_prob (cp_stream) + if xform_select <= xform_prob: + bra.uni XFORM_N_LBL + +XFORM_1_LBL: + stream_load xform_1_ (cp_stream) + ... + bra.uni XFORM_POST + +XFORM_POST: + [if final_xform:] + [do final_xform] + + if num_samples < 0: + # FUSE still in progress + bra.uni FUSE_START + +FRAGMENT_WRITEBACK: + # Unknown at this time. + +SHUFFLE: + # Unknown at this time. 
+ + load num_samples from num_samples_sh + num_samples -= 1 + if num_samples > 0: + bra.uni ITER_LOOP_START + + +[1] Tracking 'badvals' can impose a pretty large hit on performance, particularly + for images that sample a small portion of the grid. So this might be cut + when rendering for performance. On the other hand, it might actually help + tune the algorithm later, so it'll definitely be an option. + +[2] Control points for each temporal sample will be preloaded to the + device in the compact DataStream format (more on this later). Their + locations are represented in an index table, which starts with a single + `.u32 length`, followed by `length` pointers. To avoid having to keep + reloading `length`, or worse, tying up a register to hold it, we + instead count *down* to zero. This is a very common idiom. + +[3] 'qlocal' is quasi-local storage. It could easily be actual local storage, + depending on how local storage is implemented, but the extra 128-byte loads + for such values might make a performance difference. 'qlocal' variables may + be identical across a warp or even a CTA, and so variables noted as + "qlocal" here might end up in shared memory or even a small per-warp or + per-CTA buffer in global memory created specifically for this purpose, + after benchmarking is done. + +[4] DataStreams are "opaque" data serialization structures defined below. The + structure of a stream is actually created while parsing the DSL by the load + statements themselves. Some benchmarks need to be done before DataStreams + stop being "opaque" and become simply "dynamic".
+""" + class MWCRNG(PTXFragment): def __init__(self): self.threads_ready = 0 if not os.path.isfile('primes.bin'): raise EnvironmentError('primes.bin not found') - prelude = (".global .u32 mwc_rng_mults[{{ctx.threads}}];\n" - ".global .u64 mwc_rng_state[{{ctx.threads}}];") + def module_setup(self): + mem.global_.u32('mwc_rng_mults', ctx.threads) + mem.global_.u32('mwc_rng_state', ctx.threads) - def _next_b32(self, dreg): - # TODO: make sure PTX optimizes away superfluous move instrs - return """ - { - // MWC next b32 - .reg .u64 mwc_out; - cvt.u64.u32 mwc_out, mwc_car; - mad.wide.u32 mwc_out, mwc_st, mwc_mult, mwc_out; - mov.b64 {mwc_st, mwc_car}, mwc_out; - mov.u32 %s, mwc_st; - } - """ % dreg + def entry_setup(self): + reg.u32('mwc_st mwc_mult mwc_car') + with block('Load MWC multipliers and states'): + reg.u32('mwc_off mwc_addr') + get_gtid(mwc_off) + op.mov.u32(mwc_addr, mwc_rng_mults) + op.mad.lo.u32(mwc_addr, mwc_off, 4, mwc_addr) + op.ld.global_.u32(mwc_mult, addr(mwc_addr)) - def subs(self, ctx): - return {'mwc_next_b32': self._next_b32} + op.mov.u32(mwc_addr, mwc_rng_state) + op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr) + op.ld.global_.v2.u32(vec(mwc_st, mwc_car), addr(mwc_addr)) - entry_start = """ - .reg .u32 mwc_st, mwc_mult, mwc_car; - { - // MWC load multipliers and RNG states - .reg .u32 mwc_off, mwc_addr; - {{ get_gtid('mwc_off') }} - mov.u32 mwc_addr, mwc_rng_mults; - mad.lo.u32 mwc_addr, mwc_off, 4, mwc_addr; - ld.global.u32 mwc_mult, [mwc_addr]; - mov.u32 mwc_addr, mwc_rng_state; - mad.lo.u32 mwc_addr, mwc_off, 8, mwc_addr; - ld.global.v2.u32 {mwc_st, mwc_car}, [mwc_addr]; - } - """ + def entry_teardown(self): + with block('Save MWC states'): + reg.u32('mwc_off mwc_addr') + get_gtid(mwc_off) + op.mov.u32(mwc_addr, mwc_rng_state) + op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr) + op.st.global_.v2.u32(addr(mwc_addr), vec(mwc_st, mwc_car)) - entry_end = """ - { - // MWC save states - .reg .u32 mwc_addr, mwc_off; - {{ get_gtid('mwc_off') }} - mov.u32 
mwc_addr, mwc_rng_state; - mad.lo.u32 mwc_addr, mwc_off, 8, mwc_addr; - st.global.v2.u32 [mwc_addr], {mwc_st, mwc_car}; - } - """ + def next_b32(self, dst_reg): + with block('Load next random into ' + dst_reg.name): + reg.u64('mwc_out') + op.cvt.u64.u32(mwc_out, mwc_car) + mad.wide.u32(mwc_out, mwc_st) + mov.b64(vec(mwc_st, mwc_car), mwc_out) + mov.u32(dst_reg, mwc_st) def set_up(self, ctx): if self.threads_ready >= ctx.threads: + # Already set up enough random states, don't push again return + # Load raw big-endian u32 multipliers from primes.bin. with open('primes.bin') as primefp: dt = np.dtype(np.uint32).newbyteorder('B') @@ -87,34 +178,35 @@ class MWCRNGTest(PTXTest): name = "MWC RNG sum-of-threads" deps = [MWCRNG] rounds = 10000 + entry_name = 'MWC_RNG_test' + entry_params = '' - prelude = ".global .u64 mwc_rng_test_sums[{{ctx.threads}}];" + def module_setup(self): + mem.global_.u64(mwc_rng_test_sums, ctx.threads) - def entry(self, ctx): - return ('MWC_RNG_test', '', """ - .reg .u64 sum, addl; - .reg .u32 addend; - mov.u64 sum, 0; - { - .reg .u32 loopct; - .reg .pred p; - mov.u32 loopct, %s; -loopstart: - {{ mwc_next_b32('addend') }} - cvt.u64.u32 addl, addend; - add.u64 sum, sum, addl; - sub.u32 loopct, loopct, 1; - setp.gt.u32 p, loopct, 0; - @p bra.uni loopstart; - } - { - .reg .u32 addr, offset; - {{ get_gtid('offset') }} - mov.u32 addr, mwc_rng_test_sums; - mad.lo.u32 addr, offset, 8, addr; - st.global.u64 [addr], sum; - } - """ % self.rounds) + @ptx_func + def entry(self): + reg.u64('sum addl') + reg.u32('addend') + op.mov.u64(sum, 0) + with block('Sum next %d random numbers' % self.rounds): + reg.u32('loopct') + pred('p') + op.mov.u32(loopct, self.rounds) + label('loopstart') + mwc_next_b32(addend) + op.cvt.u64.u32(addl, addend) + op.add.u64(sum, sum, addl) + op.sub.u32(loopct, loopct, 1) + op.setp.gt.u32(p, loopct, 0) + op.bra.uni(loopstart, ifp=p) + + with block('Store sum and state'): + reg.u32('adr offset') + get_gtid(offset) + op.mov.u32(adr, 
mwc_rng_test_sums) + op.mad.lo.u32(adr, offset, 8, adr) + st.global_.u64(addr(adr), sum) def call(self, ctx): # Get current multipliers and seeds from the device diff --git a/cuburnlib/ptx.py b/cuburnlib/ptx.py index e9e40f3..9deb525 100644 --- a/cuburnlib/ptx.py +++ b/cuburnlib/ptx.py @@ -1,38 +1,545 @@ +""" +PTX DSL, a domain-specific language for NVIDIA's PTX. + +The DSL doesn't really provide any benefits over raw PTX in terms of type +safety or error checking. Where it shines is in enabling code reuse, +modularization, and dynamic data structures. In particular, the "data stream" +that controls the iterations and xforms in cuflame's device code are much +easier to maintain using this system. +""" + +# If you see 'import inspect', you know you're in for a good time +import inspect import ctypes -import tempita +from collections import namedtuple -def ppr_ptx(src): - # TODO: Add variable realignment - indent = 0 - out = [] - for line in [l.strip() for l in src.split('\n')]: - if not line: - continue - if len(line.split()) == 1 and line.endswith(':'): - out.append(line) - continue - if '}' in line and '{' not in line: - indent -= 1 - if line.startswith('@'): - out.append(' ' * ((indent - 1) * 4) + line) +# Okay, so here's what's going on. +# +# We're using Python to create PTX. If we just use Python to make one giant PTX +# module, there's no real reason of going to the trouble of using Python to +# begin with, as the things that this system is good for - modularization, unit +# testing, automated analysis, and data structure generation and optimization - +# pretty much require splitting code up into manageable units. However, +# splitting things up at the level of PTX will greatly reduce performance, as +# the cost of accessing the stack, spilling registers, and reloading data from +# system memory is unacceptably high even on Fermi GPUs. So we want to split +# code up into functions within Python, but not within the PTX. 
+# +# The challenge here is variable lifetime. A PTX function might declare a +# register at the top of the main block and use it several times throughout the +# function. In Python, we split that up into multiple functions, one to declare +# the registers at the start of the scope and another to make use of them later +# on. This makes it very easy to reuse a class of related PTX functions in +# different device entry points, do unit tests, and so on. +# +# The scope of the class instance is unrelated to the normal scope of names in +# Python. In fact, a function call frequently declares a register that may be +# needed by the parent function. So where to store the information regarding +# the register that was declared at the top of the file (name, type, etc.)? +# Well, once declared, a variable remains in scope in PTX until the closing +# brace of the block (curly-braces segment) it was declared in. The natural +# place to store it would be in a Pythonic representation of the block: a block +# object that implements the context manager protocol. +# +# This works well in terms of tracking object lifetime, but it adds a great +# deal of ugliness to the code. What I originally sought was this:: +# +# def load_zero(dest_reg): +# op.mov.u32(dest_reg, 0) +# def init_module(): +# reg.u32('hooray_reg') +# load_zero(hooray_reg) +# +# But using blocks to track state, it would turn into this ugliness:: +# +# def load_zero(block, dest_reg): +# block.op.mov.u32(dest_reg, 0) +# def init_module(): +# with Block() as block: +# block.regs.hooray_reg = block.reg.u32('hooray_reg') +# load_zero(block, block.regs.hooray_reg) +# +# Eeugh. +# +# Anyway, never one to use an acceptable solution when an ill-conceived hack +# was available, I poked and prodded until I found a way to attain my ideal.
+# In short, a function with a 'ptx_func' decorator will be wrapped in a +# _BlockInjector context manager, which will temporarily add values to the +# function's global dictionary in such a way as to mimic the desired behavior. +# The decorator is kind enough to pop the values when exiting. The examples +# below give a clear picture of how to use it, but now you know why this +# abomination was crafted to begin with. + + +BlockCtx = namedtuple('BlockCtx', 'locals code injectors') +PTXStmt = namedtuple('PTXStmt', 'prefix op vars semi indent') + +class _BlockInjector(object): + """ + A ContextManager that, upon entering a context, loads some keys into a + dictionary, and upon leaving it, removes those keys. If any keys are + already in the destination dictionary with a different value, an exception + is raised. + + Useful if the destination dictionary is a func's __globals__. + """ + def __init__(self, to_inject, inject_into): + self.to_inject, self.inject_into = to_inject, inject_into + self.injected = set() + self.dead = True + def inject(self, kv, v=None): + """Inject a key-value pair (passed either as a tuple or separately.)""" + k, v = v and (kv, v) or kv + if k not in self.to_inject: + self.to_inject[k] = v + if self.dead: + return + if k in self.inject_into: + if self.inject_into[k] is not v: + raise KeyError("Key with different value already in dest") else: - out.append(' ' * (indent * 4) + line) - if '{' in line and '}' not in line: - indent += 1 - return '\n'.join(out) + self.inject_into[k] = v + self.injected.add(k) + def __enter__(self): + self.dead = False + map(self.inject, self.to_inject.items()) + def __exit__(self, exc_type, exc_val, tb): + for k in self.injected: + del self.inject_into[k] + self.dead = True -def multisub(tmpl, subs): - while '{{' in tmpl: - tmpl = tempita.Template(tmpl).substitute(subs) - return tmpl +class _Block(object): + """ + State-tracker for PTX fragments. You should really look at Block and + PTXModule instead of here. 
-class PTXAssembler(object): + For important reasons, the instance must be bound locally as "_block". + """ + name = '_block' + def __init__(self): + self.outer_ctx = BlockCtx({self.name: self}, [], []) + self.stack = [self.outer_ctx] + def push_ctx(self): + self.stack.append(BlockCtx(dict(self.stack[-1].locals), [], [])) + def pop_ctx(self): + bs = self.stack.pop() + self.stack[-1].code.append(bs.code) + def injector(self, func_globals): + inj = BlockInjector(self.stack[-1].locals, func_globals) + self.stack[-1].injectors.append(inj) + return inj + def inject(self, name, object): + if name in self.stack[-1].locals: + raise KeyError("Duplicate name already exists in this scope.") + self.stack[-1].locals[name] = object + [inj.inject(name, object) for inj in self.stack[-1].injectors] + def code(self, prefix='', op='', vars=[], semi=True, indent=0): + """ + Append a PTX statement (or thereabouts) to the current block. + + - `prefix`: a string which will not be indented, regardless of the + current indent level, for labels and predicates. + - `op`: a string, aligned to current indent level. + - `vars`: a list of strings, with best-effort alignment. + - `semi`: whether to terminate the current line with a semicolon. + - `indent`: integer adjustment to the current indent level. + + For `prefix`, `op`, and `vars`, a "string" can also mean a sequence of + objects that can be coerced to strings, which will be joined without + spacing. To keep things simple, nested lists and tuples will be reduced + in this manner (but not other iterable types). Coercion will not happen + until after the entire DSL call tree has been walked. This allows a + class to submit a mutable type (e.g. the trivial `StrVar`) when first + walked with an undefined value, then substitute the correct value on + being finalized. + + Details about alignment are available in the `PTXFormatter` class. 
And + yes, the only real difference between `prefix`, `op`, and `vars` is in + final appearance, but it is in fact quite helpful for debugging. + """ + self.stack[-1].append(PTXStmt(prefix, op, vars, indent)) + +class StrVar(object): + """ + Trivial wrapper to allow deferred variable substitution. + """ + def __init__(self, val=None): + self.val = val + def __str__(self): + return str(val) + +def ptx_func(func): + """ + Decorator function for code in the DSL. Any function which accesses the DSL + namespace, including declared device variables and objects such as "reg" + or "op", should be wrapped with this. See Block for some examples. + """ + def ptx_eval(*args, **kwargs): + if self.name not in globals(): + parent = inspect.stack()[-2][0] + if self.name in parent.f_locals: + block = parent.f_locals[self.name] + elif self.name in parent.f_globals: + block = parent.f_globals[self.name] + else: + # Couldn't find the _block instance. Fail cryptically to + # encourage users to read the source (for now) + raise SyntaxError("Black magic") + else: + block = globals()['block'] + with block.injector(func.func_globals): + func(*args, **kwargs) + return ptx_eval + +class Block(object): + """ + Limits the lifetime of variables in both PTX (using curly-braces) and in + the Python DSL (via black magic). This is semantically useful, but should + not otherwise affect device code (the lifetime of a register is + aggressively minimized by the compiler). + + >>> with block('This comment will appear at the top of the block'): + >>> reg.u32('same_name') + >>> with block(): + >>> reg.u64('same_name') # OK, because 'same_name' went out of scope + + PTX variables declared inside a block will be available in any other + ptx_func called within that block. Note that this flies in the face of + normal Python behavior! That's why it's a DSL. (This doesn't apply to + non-PTX variables.) 
+ + >>> @ptx_func + >>> def fn1(): + >>> op.mov.u32(reg1, 0) + >>> + >>> @ptx_func + >>> def fn2(): + >>> print x + >>> + >>> @ptx_func + >>> def fn3(): + >>> with block(): + >>> reg.u32('reg1') + >>> x = 4 + >>> fn1() # OK: DSL magic propagates 'reg1' to fn1's namespace + >>> fn2() # FAIL: DSL magic doesn't touch regular variables + >>> fn1() # FAIL: 'reg1' went out of scope along with the block + + This constructor is available as 'block' in the DSL namespace. + """ + def __init__(self, block): + # `block` is the real _block + self.block = block + self.comment = None + def __call__(self, comment=None) + self.comment = comment + return self + def __enter__(self): + self.block.push_ctx() + self.block.code(op='{', indent=4) + def __exit__(self, exc_type, exc_value, tb): + self.block.code(op='}', indent=-4) + self.block.pop_ctx() + +class _CallChain(object): + """Handles the syntax for the operator chaining in PTX, like op.mul.u32.""" + def __init__(self, block): + self.block = block + self.__chain = [] + def __call__(self, *args, **kwargs): + assert(self.__chain) + self._call(chain, *args, **kwargs) + self.__chain = [] + def __getattr__(self, name): + if name == 'global_': + name = 'global' + self.chain.append(name) + # Another great crime against the universe: + return self + +class Reg(object): + """ + Creates one or more registers. The argument should be a string containing + one or more register names, separated by whitespace; the registers will be + injected into the DSL namespace on creation, so you do not need to + rebind them to the same name before use. + + >>> with block(): + >>> reg.u32('addend product') + >>> op.mov.u32(addend, 0) + >>> op.mov.u32(product, 0) + >>> op.mov.u32(addend, 1) # Fails, block unbinds globals on leaving scope + + This constructor is available as 'reg' in the DSL namespace. 
+ """ + def __init__(self, type, name): + self.type, self.name = type, name + def __str__(self): + return self.name + +class _RegFactory(_CallChain): + """The actual 'reg' object in the DSL namespace.""" + def _call(self, type, names): + assert len(type) == 1 + type = type[0] + names = names.split() + regs = map(lambda n: Reg(type, n), names) + self.block.code(op='.reg .' + type, vars=names) + [self.block.inject(r.name, r) for r in regs] + +# Pending resolution of the op(regs, guard=x) debate +#class Pred(object): + #""" + #Allows for predicated execution of operations. + + #>>> pred('p_some_test p_another_test') + #>>> op.setp.eq.u32(p_some_test, reg1, reg2) + #>>> op.setp.and.eq.u32(p_another_test, reg1, reg2, p_some_test) + #>>> with p_some_test.is_set(): + #>>> op.ld.global.u32(reg1, addr(areg)) + + #Predication supports nested function calls, and will cover all code + #generated inside the predicate block: + + #>>> with p_another_test.is_unset(): + #>>> some_ptxdsl_function(reg2) + #>>> op.st.global.u32(addr(areg), reg2) + + #It is a syntax error to declare registers, + #However, multiple predicate blocks cannot be nested. Doing so is a syntax + #error. + + #>>> with p_some_test.is_set(): + #>>> with p_another_test.is_unset(): + #>>> pass + #SyntaxError: ... + #""" + #def __init__(self, name): + #self.name = name + #def is_set(self, isnot=False): + + + +class Op(_CallChain): + """ + Performs an operation. + + >>> op.mov.u32(address, mwc_rng_test_sums) + >>> op.mad.lo.u32(address, offset, 8, address) + >>> op.st.global_.v2.u32(addr(address), vec(mwc_a, mwc_b)) + + To make an operation conditional on a predicate, use 'ifp' or 'ifnotp': + + >>> reg.pred('p1') + >>> op.setp.eq.u32(p1, reg1, reg2) + >>> op.mul.lo.u32(reg1, reg1, reg2, ifp=p1) + >>> op.add.u32(reg2, reg1, reg2, ifnotp=p1) + + Note that the global state-space should be written 'global_' to avoid + conflict with the Python keyword. `addr` and `vec` are defined in Mem. 
+ + This constructor is available as 'op' in DSL blocks. + """ + def _call(self, op, *args, ifp=None, ifnotp=None): + pred = '' + if ifp: + if ifnotp: + raise SyntaxError("can't use both, fool") + pred = ['@', ifp] + if ifnotp: + pred = ['@!', ifnotp] + self.block.append_code(pred, '.'.join(op), map(str, args)) + +class Mem(object): + """ + Reserve memory, optionally with an array size attached. + + >>> mem.global_.u32('global_scalar') + >>> mem.local.u32('context_sized_local_array', ctx.threads*4) + >>> mem.shared.u32('shared_array', 12) + >>> mem.const.u32('const_array_of_unknown_length', True) + + Like registers, memory allocations are injected into the global namespace + for use by any functions inside the scope without extra effort. + + >>> with block('move address into memory'): + >>> reg.u32('mem_address') + >>> op.mov.u32(mem_address, global_scalar) + + This constructor is available as 'mem' in DSL blocks. + """ + # Pretty much the same as 'Reg', duplicated only for clarity + def __init__(self, type, name, array, init): + self.type, self.name, self.array, self.init = type, name, array, init + def __str__(self): + return self.name + + @staticmethod + def vec(*args): + """ + Prepare vector arguments to a memory operation. + + >>> op.ld.global.v2.u32(vec(reg1, reg2), addr(areg)) + """ + return ['{', [(a, ', ') for a in args][:-1], '}'] + + @staticmethod + def addr(areg, aoffset=''): + """ + Prepare an address to a memory operation, optionally specifying offset. 
+ + >>> op.st.global.v2.u32(addr(areg), vec(reg1, reg2)) + >>> op.ld.global.v2.u32(vec(reg1, reg2), addr(areg, 8)) + """ + return ['[', areg, aoffset and '+' or '', aoffset, ']'] + +class _MemFactory(_CallChain): + """Actual `mem` object""" + def _call(self, type, name, array=False, initializer=None): + assert len(type) == 2 + memobj = Mem(type, name, array) + self.dsl.inject(name, memobj) + if array is True: + array = ['[]'] + elif array: + array = ['[', array, ']'] + else: + array = [] + if initializer: + array += [' = ', initializer] + self.block.code(op=['.%s.%s ' % type, name, array]) + +class Label(object): + """ + Specifies the target for a branch. Scoped in PTX? TODO: test. + + >>> label('infinite_loop') + >>> op.bra.uni('label') + """ + def __init__(self, name): + self.name = name + def __str__(self): + return self.name + +class _LabelFactory(object): + def __init__(self, block): + self.block = block + def __call__(self, name): + self.block.inject(name, Label(name)) + +class PTXFragment(object): + def module_setup(self): + pass + + def entry_setup(self): + pass + + def entry_teardown(self): + pass + + def globals(self): + pass + + def tests(self): + pass + + def device_init(self, ctx): + pass + +class PTXFragment(object): + """ + An object containing PTX DSL functions. + + In cuflame, several different versions of a given function may be + regenerated in rapid succession + + The final compilation pass is guaranteed to have all "tuned" values fixed + in their final values for the stream. + + Template code will be processed recursively until all "{{" instances have + been replaced, using the same namespace each time. + + Note that any method which does not depend on 'ctx' can be replaced with + an instance of the appropriate return type. So, for example, the 'deps' + property can be a flat list instead of a function. + """ + + def deps(self): + """ + Returns a list of PTXFragment types on which this object depends + for successful compilation. 
Circular dependencies are forbidden, + but multi-level dependencies should be fine. + """ + return [DeviceHelpers] + + def inject(self): + """ + Returns a dict of items to add to the DSL namespace. The namespace will + be assembled in dependency order before any ptx_funcs are called. + """ + return {} + + def module_setup(self): + """ + PTX function to declare things at module scope. It's a PTX syntax error + to perform operations at this scope, but we don't yet validate that at + the Python level. A module will call this function on all fragments in + dependency order. + + If implemented, this function should use an @ptx_func decorator. + """ + pass + + def entry_setup(self): + """ + PTX DSL function which will insert code at the start of an entry, for + initializing variables and stuff like that. An entry point will call + this function on all fragments used in that entry point in dependency + order. + + If implemented, this function should use an @ptx_func decorator. + """ + pass + + def entry_teardown(self): + """ + PTX DSL function which will insert code at the end of an entry, for any + clean-up that needs to be performed. An entry point will call this + function on all fragments used in the entry point in *reverse* + dependency order (i.e. fragments which this fragment depends on will be + cleaned up after this one). + + If implemented, this function should use an @ptx_func decorator. + """ + pass + + def tests(self, ctx): + """ + Returns a list of PTXTest classes which will test this fragment. + """ + return [] + + def set_up(self, ctx): + """ + Do start-of-stream initialization, such as copying data to the device. + """ + pass +class PTXModule(object): """ Assembles PTX fragments into a module. 
""" - def __init__(self, ctx, entries, build_tests=False): - self.assemble(ctx, entries, build_tests) + def __init__(self, entries, inject={}, build_tests=False): + self._block = b = _Block() + self.initial_inject = dict(inject) + self._safeupdate(self.initial_inject, dict(block=Block(b), + mem=_MemFactory(b), reg=_RegFactory(b), op=Op(b), + label=_LabelFactory(b), _block=b) + self.needs_recompilation = True + self.max_compiles = 10 + while self.needs_recompilation: + self.assemble(entries, build_tests) + self.max_compiles -= 1 def deporder(self, unsorted_instances, instance_map, ctx): """ @@ -57,7 +564,7 @@ class PTXAssembler(object): if non_uniq: raise KeyError("Duplicate keys: %s" % ','.join(key)) dst.update(src) - def assemble(self, ctx, entries, build_tests): + def assemble(self, entries, build_tests): """ Build the PTX source for the given set of entries. """ @@ -121,78 +628,7 @@ class PTXAssembler(object): self.instances = instances self.tests = tests -class PTXFragment(object): - """ - Wrapper for sections of template PTX. - In order to provide the best optimization, and avoid a web of hard-coded - parameters, the PTX module may be regenerated and recompiled several times - with different or incomplete launch context parameters. To this end, avoid - accessing the GPU in such functions, and do not depend on context values - which are marked as "tuned" in the LaunchContext docstring being - available. - - The final compilation pass is guaranteed to have all "tuned" values fixed - in their final values for the stream. - - Template code will be processed recursively until all "{{" instances have - been replaced, using the same namespace each time. - - Note that any method which does not depend on 'ctx' can be replaced with - an instance of the appropriate return type. So, for example, the 'deps' - property can be a flat list instead of a function. 
- """ - - def deps(self, ctx): - """ - Returns a list of PTXFragment objects on which this object depends - for successful compilation. Circular dependencies are forbidden, - but multi-level dependencies should be fine. - """ - return [DeviceHelpers] - - def subs(self, ctx): - """ - Returns a dict of items to add to the template substitution namespace. - The entire dict will be assembled, including all dependencies, before - any templates are evaluated. - """ - return {} - - def prelude(self, ctx): - """ - Returns a template string containing any code (variable declarations, - probably) that should be inserted at module scope. The prelude of - all deps will be inserted above this prelude. - """ - return "" - - def entry_start(self, ctx): - """ - Returns a template string that should be inserted at the top of any - entry point which depends on this method. The entry starts of all - deps will be inserted above this entry prelude. - """ - return "" - - def entry_end(self, ctx): - """ - As above, but at the end of the calling function, and with the order - reversed (all dependencies will be inserted after this). - """ - return "" - - def tests(self, ctx): - """ - Returns a list of PTXTest classes which will test this fragment. - """ - return [] - - def set_up(self, ctx): - """ - Do start-of-stream initialization, such as copying data to the device. - """ - pass class PTXEntryPoint(PTXFragment): # Human-readable entry point name