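Switch from to_inject() dicts to injecting each fragment object into the DSL namespace under its shortname (e.g. mwc.next_b32, std.get_gtid, cp.get). One less kludge to deal with.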

This commit is contained in:
Steven Robertson 2010-09-06 16:09:37 -04:00
parent ada0fe20c7
commit e03f20392d
3 changed files with 63 additions and 85 deletions

1
TODO
View File

@ -39,6 +39,7 @@ Things to do (rather severely incomplete):
- Test effects on quality by masking off writes on all but one lane and - Test effects on quality by masking off writes on all but one lane and
boosting the sample density to compensate (muuuuuch later on) boosting the sample density to compensate (muuuuuch later on)
- DE - DE
- Clean up code (particularly DSL stuff incl. injector)
Things to test: Things to test:

View File

@ -24,7 +24,7 @@ class IterThread(PTXTest):
@ptx_func @ptx_func
def module_setup(self): def module_setup(self):
mem.global_.u32('g_cp_array', mem.global_.u32('g_cp_array',
cp_stream_size*features.max_ntemporal_samples) cp.stream_size*features.max_ntemporal_samples)
mem.global_.u32('g_num_cps') mem.global_.u32('g_num_cps')
# TODO move into debug statement # TODO move into debug statement
mem.global_.u32('g_num_rounds', ctx.threads) mem.global_.u32('g_num_rounds', ctx.threads)
@ -40,10 +40,10 @@ class IterThread(PTXTest):
op.mov.u32(num_writes, 0) op.mov.u32(num_writes, 0)
# TODO: MWC float output types # TODO: MWC float output types
mwc_next_f32_01(x_coord) mwc.next_f32_01(x_coord)
mwc_next_f32_01(y_coord) mwc.next_f32_01(y_coord)
mwc_next_f32_01(color_coord) mwc.next_f32_01(color_coord)
mwc_next_f32_01(alpha_coord) mwc.next_f32_01(alpha_coord)
# Registers are hard to come by. To avoid having to track both the count # Registers are hard to come by. To avoid having to track both the count
# of samples processed and the number of samples to generate, # of samples processed and the number of samples to generate,
@ -81,13 +81,13 @@ class IterThread(PTXTest):
with block('Load CP address'): with block('Load CP address'):
op.mov.u32(cpA, g_cp_array) op.mov.u32(cpA, g_cp_array)
op.mad.lo.u32(cpA, cp_idx, cp_stream_size, cpA) op.mad.lo.u32(cpA, cp_idx, cp.stream_size, cpA)
with block('Increment CP index, load num_samples (unless in fuse)'): with block('Increment CP index, load num_samples (unless in fuse)'):
reg.pred('p_not_in_fuse') reg.pred('p_not_in_fuse')
op.setp.ge.s32(p_not_in_fuse, num_samples, 0) op.setp.ge.s32(p_not_in_fuse, num_samples, 0)
op.add.u32(cp_idx, cp_idx, 1, ifp=p_not_in_fuse) op.add.u32(cp_idx, cp_idx, 1, ifp=p_not_in_fuse)
cp_stream_get(cpA, num_samples, 'samples_per_thread', cp.get(cpA, num_samples, 'samples_per_thread',
ifp=p_not_in_fuse) ifp=p_not_in_fuse)
label('fuse_loop_start') label('fuse_loop_start')
@ -127,8 +127,8 @@ class IterThread(PTXTest):
label('all_cps_done') label('all_cps_done')
# TODO this is for testing, move it to a debug statement # TODO this is for testing, move it to a debug statement
store_per_thread(g_num_rounds, num_rounds) std.store_per_thread(g_num_rounds, num_rounds)
store_per_thread(g_num_writes, num_writes) std.store_per_thread(g_num_writes, num_writes)
def upload_cp_stream(self, ctx, cp_stream, num_cps): def upload_cp_stream(self, ctx, cp_stream, num_cps):
cp_array_dp, cp_array_l = ctx.mod.get_global('g_cp_array') cp_array_dp, cp_array_l = ctx.mod.get_global('g_cp_array')
@ -152,6 +152,8 @@ class IterThread(PTXTest):
print "Writes:", writes print "Writes:", writes
class MWCRNG(PTXFragment): class MWCRNG(PTXFragment):
shortname = "mwc"
def __init__(self): def __init__(self):
self.rand = np.random self.rand = np.random
self.threads_ready = 0 self.threads_ready = 0
@ -171,7 +173,7 @@ class MWCRNG(PTXFragment):
reg.u32('mwc_st mwc_mult mwc_car') reg.u32('mwc_st mwc_mult mwc_car')
with block('Load MWC multipliers and states'): with block('Load MWC multipliers and states'):
reg.u32('mwc_off mwc_addr') reg.u32('mwc_off mwc_addr')
get_gtid(mwc_off) std.get_gtid(mwc_off)
op.mov.u32(mwc_addr, mwc_rng_mults) op.mov.u32(mwc_addr, mwc_rng_mults)
op.mad.lo.u32(mwc_addr, mwc_off, 4, mwc_addr) op.mad.lo.u32(mwc_addr, mwc_off, 4, mwc_addr)
op.ld.global_.u32(mwc_mult, addr(mwc_addr)) op.ld.global_.u32(mwc_mult, addr(mwc_addr))
@ -184,7 +186,7 @@ class MWCRNG(PTXFragment):
def entry_teardown(self): def entry_teardown(self):
with block('Save MWC states'): with block('Save MWC states'):
reg.u32('mwc_off mwc_addr') reg.u32('mwc_off mwc_addr')
get_gtid(mwc_off) std.get_gtid(mwc_off)
op.mov.u32(mwc_addr, mwc_rng_state) op.mov.u32(mwc_addr, mwc_rng_state)
op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr) op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr)
op.st.global_.v2.u32(addr(mwc_addr), vec(mwc_st, mwc_car)) op.st.global_.v2.u32(addr(mwc_addr), vec(mwc_st, mwc_car))
@ -219,11 +221,6 @@ class MWCRNG(PTXFragment):
op.cvt.rn.f32.s32(dst_reg, mwc_st) op.cvt.rn.f32.s32(dst_reg, mwc_st)
op.mul.lo.f32(dst_reg, dst_reg, '0f00000030') # 1./(1<<31) op.mul.lo.f32(dst_reg, dst_reg, '0f00000030') # 1./(1<<31)
def to_inject(self):
return dict(mwc_next_b32=self.next_b32,
mwc_next_f32_01=self.next_f32_01,
mwc_next_f32_11=self.next_f32_11)
def device_init(self, ctx): def device_init(self, ctx):
if self.threads_ready >= ctx.threads: if self.threads_ready >= ctx.threads:
# Already set up enough random states, don't push again # Already set up enough random states, don't push again
@ -275,7 +272,7 @@ class MWCRNGTest(PTXTest):
reg.pred('p') reg.pred('p')
op.mov.u32(loopct, self.rounds) op.mov.u32(loopct, self.rounds)
label('loopstart') label('loopstart')
mwc_next_b32(addend) mwc.next_b32(addend)
op.cvt.u64.u32(addl, addend) op.cvt.u64.u32(addl, addend)
op.add.u64(sum, sum, addl) op.add.u64(sum, sum, addl)
op.sub.u32(loopct, loopct, 1) op.sub.u32(loopct, loopct, 1)
@ -284,7 +281,7 @@ class MWCRNGTest(PTXTest):
with block('Store sum and state'): with block('Store sum and state'):
reg.u32('adr offset') reg.u32('adr offset')
get_gtid(offset) std.get_gtid(offset)
op.mov.u32(adr, mwc_rng_test_sums) op.mov.u32(adr, mwc_rng_test_sums)
op.mad.lo.u32(adr, offset, 8, adr) op.mad.lo.u32(adr, offset, 8, adr)
op.st.global_.u64(addr(adr), sum) op.st.global_.u64(addr(adr), sum)
@ -331,5 +328,5 @@ class CameraCoordTransform(PTXFragment):
class CPDataStream(DataStream): class CPDataStream(DataStream):
"""DataStream which stores the control points.""" """DataStream which stores the control points."""
prefix = 'cp' shortname = 'cp'

View File

@ -493,22 +493,14 @@ class Comment(object):
class PTXFragment(object): class PTXFragment(object):
""" """
An object containing PTX DSL functions. An object containing PTX DSL functions. The object, and all its
dependencies, will be instantiated by a PTX module. Each object will be
In cuflame, several different versions of a given function may be bound to the name given by ``shortname`` in the DSL namespace.
regenerated in rapid succession
The final compilation pass is guaranteed to have all "tuned" values fixed
in their final values for the stream.
Template code will be processed recursively until all "{{" instances have
been replaced, using the same namespace each time.
Note that any method which does not depend on 'ctx' can be replaced with
an instance of the appropriate return type. So, for example, the 'deps'
property can be a flat list instead of a function.
""" """
# Name under which to make this code available in ptx_funcs
shortname = None
def deps(self): def deps(self):
""" """
Returns a list of PTXFragment types on which this object depends Returns a list of PTXFragment types on which this object depends
@ -517,15 +509,6 @@ class PTXFragment(object):
""" """
return [_PTXStdLib] return [_PTXStdLib]
def to_inject(self):
"""
Returns a dict of items to add to the DSL namespace. The namespace will
be assembled in dependency order before any ptx_funcs are called.
This is only called once per PTXModule (== once per instance).
"""
return {}
def module_setup(self): def module_setup(self):
""" """
PTX function to declare things at module scope. It's a PTX syntax error PTX function to declare things at module scope. It's a PTX syntax error
@ -624,6 +607,7 @@ class PTXTest(PTXEntryPoint):
pass pass
class _PTXStdLib(PTXFragment): class _PTXStdLib(PTXFragment):
shortname = "std"
def __init__(self, block): def __init__(self, block):
# Only module that gets the privilege of seeing 'block' directly. # Only module that gets the privilege of seeing 'block' directly.
self.block = block self.block = block
@ -639,7 +623,7 @@ class _PTXStdLib(PTXFragment):
self.block.code(prefix='.target sm_20', semi=False) self.block.code(prefix='.target sm_20', semi=False)
@ptx_func @ptx_func
def _get_gtid(self, dst): def get_gtid(self, dst):
""" """
Get the global thread ID (the position of this thread in a grid of Get the global thread ID (the position of this thread in a grid of
blocks of threads). Notably, this assumes that both grid and block are blocks of threads). Notably, this assumes that both grid and block are
@ -661,16 +645,17 @@ class _PTXStdLib(PTXFragment):
op.mov.b32(dst, gtid) op.mov.b32(dst, gtid)
@ptx_func @ptx_func
def _store_per_thread(self, base, val): def store_per_thread(self, base, val):
"""Store b32 at `base+gtid*4`. Super-common debug pattern.""" """Store b32 at `base+gtid*4`. Super-common debug pattern."""
with block("Per-thread store of %s" % str(val)): with block("Per-thread store of %s" % str(val)):
reg.u32('spt_base spt_offset') reg.u32('spt_base spt_offset')
op.mov.u32(spt_base, base) op.mov.u32(spt_base, base)
get_gtid(spt_offset) self.get_gtid(spt_offset)
op.mad.lo.u32(spt_base, spt_offset, 4, spt_base) op.mad.lo.u32(spt_base, spt_offset, 4, spt_base)
op.st.b32(addr(spt_base), val) op.st.b32(addr(spt_base), val)
def to_inject(self): def to_inject(self):
# Set up the initial namespace
return dict( return dict(
_block=self.block, _block=self.block,
block=Block(self.block), block=Block(self.block),
@ -680,9 +665,7 @@ class _PTXStdLib(PTXFragment):
addr=Mem.addr, addr=Mem.addr,
vec=Mem.vec, vec=Mem.vec,
label=_LabelFactory(self.block), label=_LabelFactory(self.block),
comment=Comment(self.block), comment=Comment(self.block))
get_gtid=self._get_gtid,
store_per_thread=self._store_per_thread)
class PTXModule(object): class PTXModule(object):
""" """
@ -713,9 +696,11 @@ class PTXModule(object):
self.tests = tests self.tests = tests
inject = dict(inject) inject = dict(inject)
self._safeupdate(inject, {'module': self}) inject.update(insts[_PTXStdLib].to_inject())
self._safeupdate(inject, 'module', self)
for inst in all_deps: for inst in all_deps:
self._safeupdate(inject, inst.to_inject()) if inst.shortname:
self._safeupdate(inject, inst.shortname, inst)
[block.inject(k, v) for k, v in inject.items()] [block.inject(k, v) for k, v in inject.items()]
self.__needs_recompilation = True self.__needs_recompilation = True
@ -749,11 +734,9 @@ class PTXModule(object):
map(rec, unsorted_instances) map(rec, unsorted_instances)
return sorted(unsorted_instances, key=seen.get) return sorted(unsorted_instances, key=seen.get)
def _safeupdate(self, dst, src): def _safeupdate(self, dst, k, v):
"""dst.update(src), but no duplicates allowed""" if k in dst: raise KeyError("Duplicate key %s" % k)
non_uniq = [k for k in src if k in dst] dst[k] = v
if non_uniq: raise KeyError("Duplicate keys: %s" % ','.join(key))
dst.update(src)
def deptrace(self, block, entries, build_tests): def deptrace(self, block, entries, build_tests):
instances = {_PTXStdLib: _PTXStdLib(block)} instances = {_PTXStdLib: _PTXStdLib(block)}
@ -796,6 +779,7 @@ class PTXModule(object):
def assemble(self, block, all_deps, entry_deps): def assemble(self, block, all_deps, entry_deps):
# Rebind to local namespace to allow proper retrieval # Rebind to local namespace to allow proper retrieval
_block = block _block = block
for inst in all_deps: for inst in all_deps:
inst.module_setup() inst.module_setup()
@ -886,11 +870,11 @@ class DataStream(PTXFragment):
variable positions determined at runtime. The resulting structure had to be variable positions determined at runtime. The resulting structure had to be
read strictly sequentially to be parsed, hence the name "stream".) read strictly sequentially to be parsed, hence the name "stream".)
Subclass this and give it a prefix, then depend on the subclass in your PTX Subclass this and give it a shortname, then depend on the subclass in your
fragments. PTX fragments. An instance-based approach is under consideration.
>>> class ExampleDataStream(DataStream): >>> class ExampleDataStream(DataStream):
>>> prefix = 'ex' >>> shortname = "ex"
Inside DSL functions, you can "retrieve" arbitrary Python expressions from Inside DSL functions, you can "retrieve" arbitrary Python expressions from
the data stream. the data stream.
@ -901,7 +885,7 @@ class DataStream(PTXFragment):
>>> op.mov.u32(regA, some_device_allocation_base_address) >>> op.mov.u32(regA, some_device_allocation_base_address)
>>> # From the structure at the base address in 'regA', load the value >>> # From the structure at the base address in 'regA', load the value
>>> # of 'ctx.nthreads' into reg1 >>> # of 'ctx.nthreads' into reg1
>>> ex_stream_get(regA, reg1, 'ctx.nthreads') >>> ex.get(regA, reg1, 'ctx.nthreads')
The expressions will be stored as strings and mapped to particular The expressions will be stored as strings and mapped to particular
positions in the struct. Later, the expressions will be evaluated and positions in the struct. Later, the expressions will be evaluated and
@ -922,32 +906,32 @@ class DataStream(PTXFragment):
>>> def example_func_2(): >>> def example_func_2():
>>> reg.u32('reg1 reg2') >>> reg.u32('reg1 reg2')
>>> reg.f32('regf') >>> reg.f32('regf')
>>> ex_stream_get(regA, reg1, 'ctx.nthreads * 2') >>> ex.get(regA, reg1, 'ctx.nthreads * 2')
>>> # Same expression, so load comes from same memory location >>> # Same expression, so load comes from same memory location
>>> ex_stream_get(regA, reg2, 'ctx.nthreads * 2') >>> ex.get(regA, reg2, 'ctx.nthreads * 2')
>>> # Vector loads are pre-coerced, so you can mix types >>> # Vector loads are pre-coerced, so you can mix types
>>> ex_stream_get_v2(regA, reg1, '4', regf, '5.5') >>> ex.get_v2(regA, reg1, '4', regf, '5.5')
You can even do device allocations in the file, using the post-finalized You can even do device allocations in the file, using the post-finalized
variable '[prefix]_stream_size'. It's a StrVar, so if you do any operations variable '[shortname].stream_size'. It's a DelayVar; simple things like
to it make sure you write them as a list of strings for PTX to handle (I multiplying by a number work (as long as the DelayVar comes first), but
know, it's a drag; it might be fixed later): fancy things like multiplying two DelayVars aren't implemented yet.
>>> class Whatever(PTXFragment): >>> class Whatever(PTXFragment):
>>> @ptx_func >>> @ptx_func
>>> def module_setup(self): >>> def module_setup(self):
>>> mem.global_.u32('ex_streams', [1000, '*', ex_stream_size]) >>> mem.global_.u32('ex_streams', ex.stream_size*1000)
""" """
# Must be at least as large as the largest load (.v4.u32 = 16) # Must be at least as large as the largest load (.v4.u32 = 16)
alignment = 16 alignment = 16
prefix = 'Subclass this'
def __init__(self): def __init__(self):
self.texp_map = {} self.texp_map = {}
self.cells = [] self.cells = []
self.stream_size = 0 self._size = 0
self.free = {} self.free = {}
self.size_delayvars = [] self.size_delayvars = []
self.finalized = False
_types = dict(s8='b', u8='B', s16='h', u16='H', s32='i', u32='I', f32='f', _types = dict(s8='b', u8='B', s16='h', u16='H', s32='i', u32='I', f32='f',
s64='l', u64='L', f64='d') s64='l', u64='L', f64='d')
@ -973,8 +957,8 @@ class DataStream(PTXFragment):
# No aligned free cells, allocate a new `align`-byte free cell # No aligned free cells, allocate a new `align`-byte free cell
assert alloc not in self.free assert alloc not in self.free
self.free[alloc] = idx = len(self.cells) self.free[alloc] = idx = len(self.cells)
self.cells.append(_DataCell(self.stream_size, alloc, None)) self.cells.append(_DataCell(self._size, alloc, None))
self.stream_size += alloc self._size += alloc
# Overwrite the free cell at `idx` with texp # Overwrite the free cell at `idx` with texp
assert self.cells[idx].texp is None assert self.cells[idx].texp is None
offset = self.cells[idx].offset offset = self.cells[idx].offset
@ -1015,43 +999,40 @@ class DataStream(PTXFragment):
op._call(opname, dregs, addr(areg, offset), ifp=ifp, ifnotp=ifnotp) op._call(opname, dregs, addr(areg, offset), ifp=ifp, ifnotp=ifnotp)
@ptx_func @ptx_func
def _stream_get(self, areg, dreg, expr, ifp=None, ifnotp=None): def get(self, areg, dreg, expr, ifp=None, ifnotp=None):
self._stream_get_internal(areg, [dreg], [expr], ifp, ifnotp) self._stream_get_internal(areg, [dreg], [expr], ifp, ifnotp)
@ptx_func @ptx_func
def _stream_get_v2(self, areg, dreg1, expr1, dreg2, expr2, def get_v2(self, areg, dreg1, expr1, dreg2, expr2, ifp=None, ifnotp=None):
ifp=None, ifnotp=None):
self._stream_get_internal(areg, [dreg1, dreg2], [expr1, expr2], self._stream_get_internal(areg, [dreg1, dreg2], [expr1, expr2],
ifp, ifnotp) ifp, ifnotp)
# The interleaved signature makes calls easier to read
@ptx_func @ptx_func
def _stream_get_v4(self, areg, d1, e1, d2, e2, d3, e3, d4, e4, def get_v4(self, areg, d1, e1, d2, e2, d3, e3, d4, e4,
ifp=None, ifnotp=None): ifp=None, ifnotp=None):
self._stream_get_internal(areg, [d1, d2, d3, d4], [e1, e2, e3, e4], self._stream_get_internal(areg, [d1, d2, d3, d4], [e1, e2, e3, e4],
ifp, ifnotp) ifp, ifnotp)
@property @property
def _stream_size(self): def stream_size(self):
if self.finalized:
return self._size
x = DelayVar("not_yet_determined") x = DelayVar("not_yet_determined")
self.size_delayvars.append(x) self.size_delayvars.append(x)
return x return x
def finalize_code(self): def finalize_code(self):
self.finalized = True
for dv in self.size_delayvars: for dv in self.size_delayvars:
dv.val = self.stream_size dv.val = self._size
def to_inject(self):
return {self.prefix + '_stream_get': self._stream_get,
self.prefix + '_stream_get_v2': self._stream_get_v2,
self.prefix + '_stream_get_v4': self._stream_get_v4,
self.prefix + '_stream_size': self._stream_size}
def pack(self, _out_file_ = None, **kwargs): def pack(self, _out_file_ = None, **kwargs):
""" """
Evaluates all statements in the context of **kwargs. Take this code, Evaluates all statements in the context of **kwargs. Take this code,
presumably inside a PTX func:: presumably inside a PTX func::
>>> ex_stream_get(regA, reg1, 'sum([x+frob for x in xyz.things])') >>> ex.get(regA, reg1, 'sum([x+frob for x in xyz.things])')
To pack this into a struct, call this method on an instance: To pack this into a struct, call this method on an instance:
@ -1093,4 +1074,3 @@ class DataStream(PTXFragment):
for exp in cell.texp.exprlist[1:]: for exp in cell.texp.exprlist[1:]:
print '%12s %s' % ('', exp) print '%12s %s' % ('', exp)