Switch from to_inject() to object insertion. One less kludge to deal with.

This commit is contained in:
Steven Robertson 2010-09-06 16:09:37 -04:00
parent ada0fe20c7
commit e03f20392d
3 changed files with 63 additions and 85 deletions

1
TODO
View File

@ -39,6 +39,7 @@ Things to do (rather severely incomplete):
- Test effects on quality by masking off writes on all but one lane and
boosting the sample density to compensate (muuuuuch later on)
- DE
- Clean up code (particularly DSL stuff incl. injector)
Things to test:

View File

@ -24,7 +24,7 @@ class IterThread(PTXTest):
@ptx_func
def module_setup(self):
mem.global_.u32('g_cp_array',
cp_stream_size*features.max_ntemporal_samples)
cp.stream_size*features.max_ntemporal_samples)
mem.global_.u32('g_num_cps')
# TODO move into debug statement
mem.global_.u32('g_num_rounds', ctx.threads)
@ -40,10 +40,10 @@ class IterThread(PTXTest):
op.mov.u32(num_writes, 0)
# TODO: MWC float output types
mwc_next_f32_01(x_coord)
mwc_next_f32_01(y_coord)
mwc_next_f32_01(color_coord)
mwc_next_f32_01(alpha_coord)
mwc.next_f32_01(x_coord)
mwc.next_f32_01(y_coord)
mwc.next_f32_01(color_coord)
mwc.next_f32_01(alpha_coord)
# Registers are hard to come by. To avoid having to track both the count
# of samples processed and the number of samples to generate,
@ -81,13 +81,13 @@ class IterThread(PTXTest):
with block('Load CP address'):
op.mov.u32(cpA, g_cp_array)
op.mad.lo.u32(cpA, cp_idx, cp_stream_size, cpA)
op.mad.lo.u32(cpA, cp_idx, cp.stream_size, cpA)
with block('Increment CP index, load num_samples (unless in fuse)'):
reg.pred('p_not_in_fuse')
op.setp.ge.s32(p_not_in_fuse, num_samples, 0)
op.add.u32(cp_idx, cp_idx, 1, ifp=p_not_in_fuse)
cp_stream_get(cpA, num_samples, 'samples_per_thread',
cp.get(cpA, num_samples, 'samples_per_thread',
ifp=p_not_in_fuse)
label('fuse_loop_start')
@ -127,8 +127,8 @@ class IterThread(PTXTest):
label('all_cps_done')
# TODO this is for testing, move it to a debug statement
store_per_thread(g_num_rounds, num_rounds)
store_per_thread(g_num_writes, num_writes)
std.store_per_thread(g_num_rounds, num_rounds)
std.store_per_thread(g_num_writes, num_writes)
def upload_cp_stream(self, ctx, cp_stream, num_cps):
cp_array_dp, cp_array_l = ctx.mod.get_global('g_cp_array')
@ -152,6 +152,8 @@ class IterThread(PTXTest):
print "Writes:", writes
class MWCRNG(PTXFragment):
shortname = "mwc"
def __init__(self):
self.rand = np.random
self.threads_ready = 0
@ -171,7 +173,7 @@ class MWCRNG(PTXFragment):
reg.u32('mwc_st mwc_mult mwc_car')
with block('Load MWC multipliers and states'):
reg.u32('mwc_off mwc_addr')
get_gtid(mwc_off)
std.get_gtid(mwc_off)
op.mov.u32(mwc_addr, mwc_rng_mults)
op.mad.lo.u32(mwc_addr, mwc_off, 4, mwc_addr)
op.ld.global_.u32(mwc_mult, addr(mwc_addr))
@ -184,7 +186,7 @@ class MWCRNG(PTXFragment):
def entry_teardown(self):
with block('Save MWC states'):
reg.u32('mwc_off mwc_addr')
get_gtid(mwc_off)
std.get_gtid(mwc_off)
op.mov.u32(mwc_addr, mwc_rng_state)
op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr)
op.st.global_.v2.u32(addr(mwc_addr), vec(mwc_st, mwc_car))
@ -219,11 +221,6 @@ class MWCRNG(PTXFragment):
op.cvt.rn.f32.s32(dst_reg, mwc_st)
op.mul.lo.f32(dst_reg, dst_reg, '0f00000030') # 1./(1<<31)
def to_inject(self):
return dict(mwc_next_b32=self.next_b32,
mwc_next_f32_01=self.next_f32_01,
mwc_next_f32_11=self.next_f32_11)
def device_init(self, ctx):
if self.threads_ready >= ctx.threads:
# Already set up enough random states, don't push again
@ -275,7 +272,7 @@ class MWCRNGTest(PTXTest):
reg.pred('p')
op.mov.u32(loopct, self.rounds)
label('loopstart')
mwc_next_b32(addend)
mwc.next_b32(addend)
op.cvt.u64.u32(addl, addend)
op.add.u64(sum, sum, addl)
op.sub.u32(loopct, loopct, 1)
@ -284,7 +281,7 @@ class MWCRNGTest(PTXTest):
with block('Store sum and state'):
reg.u32('adr offset')
get_gtid(offset)
std.get_gtid(offset)
op.mov.u32(adr, mwc_rng_test_sums)
op.mad.lo.u32(adr, offset, 8, adr)
op.st.global_.u64(addr(adr), sum)
@ -331,5 +328,5 @@ class CameraCoordTransform(PTXFragment):
class CPDataStream(DataStream):
"""DataStream which stores the control points."""
prefix = 'cp'
shortname = 'cp'

View File

@ -493,22 +493,14 @@ class Comment(object):
class PTXFragment(object):
"""
An object containing PTX DSL functions.
In cuflame, several different versions of a given function may be
regenerated in rapid succession
The final compilation pass is guaranteed to have all "tuned" values fixed
in their final values for the stream.
Template code will be processed recursively until all "{{" instances have
been replaced, using the same namespace each time.
Note that any method which does not depend on 'ctx' can be replaced with
an instance of the appropriate return type. So, for example, the 'deps'
property can be a flat list instead of a function.
An object containing PTX DSL functions. The object, and all its
dependencies, will be instantiated by a PTX module. Each object will be
bound to the name given by ``shortname`` in the DSL namespace.
"""
# Name under which to make this code available in ptx_funcs
shortname = None
def deps(self):
"""
Returns a list of PTXFragment types on which this object depends
@ -517,15 +509,6 @@ class PTXFragment(object):
"""
return [_PTXStdLib]
def to_inject(self):
"""
Returns a dict of items to add to the DSL namespace. The namespace will
be assembled in dependency order before any ptx_funcs are called.
This is only called once per PTXModule (== once per instance).
"""
return {}
def module_setup(self):
"""
PTX function to declare things at module scope. It's a PTX syntax error
@ -624,6 +607,7 @@ class PTXTest(PTXEntryPoint):
pass
class _PTXStdLib(PTXFragment):
shortname = "std"
def __init__(self, block):
# Only module that gets the privilege of seeing 'block' directly.
self.block = block
@ -639,7 +623,7 @@ class _PTXStdLib(PTXFragment):
self.block.code(prefix='.target sm_20', semi=False)
@ptx_func
def _get_gtid(self, dst):
def get_gtid(self, dst):
"""
Get the global thread ID (the position of this thread in a grid of
blocks of threads). Notably, this assumes that both grid and block are
@ -661,16 +645,17 @@ class _PTXStdLib(PTXFragment):
op.mov.b32(dst, gtid)
@ptx_func
def _store_per_thread(self, base, val):
def store_per_thread(self, base, val):
"""Store b32 at `base+gtid*4`. Super-common debug pattern."""
with block("Per-thread store of %s" % str(val)):
reg.u32('spt_base spt_offset')
op.mov.u32(spt_base, base)
get_gtid(spt_offset)
self.get_gtid(spt_offset)
op.mad.lo.u32(spt_base, spt_offset, 4, spt_base)
op.st.b32(addr(spt_base), val)
def to_inject(self):
# Set up the initial namespace
return dict(
_block=self.block,
block=Block(self.block),
@ -680,9 +665,7 @@ class _PTXStdLib(PTXFragment):
addr=Mem.addr,
vec=Mem.vec,
label=_LabelFactory(self.block),
comment=Comment(self.block),
get_gtid=self._get_gtid,
store_per_thread=self._store_per_thread)
comment=Comment(self.block))
class PTXModule(object):
"""
@ -713,9 +696,11 @@ class PTXModule(object):
self.tests = tests
inject = dict(inject)
self._safeupdate(inject, {'module': self})
inject.update(insts[_PTXStdLib].to_inject())
self._safeupdate(inject, 'module', self)
for inst in all_deps:
self._safeupdate(inject, inst.to_inject())
if inst.shortname:
self._safeupdate(inject, inst.shortname, inst)
[block.inject(k, v) for k, v in inject.items()]
self.__needs_recompilation = True
@ -749,11 +734,9 @@ class PTXModule(object):
map(rec, unsorted_instances)
return sorted(unsorted_instances, key=seen.get)
def _safeupdate(self, dst, src):
"""dst.update(src), but no duplicates allowed"""
non_uniq = [k for k in src if k in dst]
if non_uniq: raise KeyError("Duplicate keys: %s" % ','.join(key))
dst.update(src)
def _safeupdate(self, dst, k, v):
if k in dst: raise KeyError("Duplicate key %s" % k)
dst[k] = v
def deptrace(self, block, entries, build_tests):
instances = {_PTXStdLib: _PTXStdLib(block)}
@ -796,6 +779,7 @@ class PTXModule(object):
def assemble(self, block, all_deps, entry_deps):
# Rebind to local namespace to allow proper retrieval
_block = block
for inst in all_deps:
inst.module_setup()
@ -886,11 +870,11 @@ class DataStream(PTXFragment):
variable positions determined at runtime. The resulting structure had to be
read strictly sequentially to be parsed, hence the name "stream".)
Subclass this and give it a prefix, then depend on the subclass in your PTX
fragments.
Subclass this and give it a shortname, then depend on the subclass in your
PTX fragments. An instance-based approach is under consideration.
>>> class ExampleDataStream(DataStream):
>>> prefix = 'ex'
>>> shortname = "ex"
Inside DSL functions, you can "retrieve" arbitrary Python expressions from
the data stream.
@ -901,7 +885,7 @@ class DataStream(PTXFragment):
>>> op.mov.u32(regA, some_device_allocation_base_address)
>>> # From the structure at the base address in 'regA', load the value
>>> # of 'ctx.nthreads' into reg1
>>> ex_stream_get(regA, reg1, 'ctx.nthreads')
>>> ex.get(regA, reg1, 'ctx.nthreads')
The expressions will be stored as strings and mapped to particular
positions in the struct. Later, the expressions will be evaluated and
@ -922,32 +906,32 @@ class DataStream(PTXFragment):
>>> def example_func_2():
>>> reg.u32('reg1 reg2')
>>> reg.f32('regf')
>>> ex_stream_get(regA, reg1, 'ctx.nthreads * 2')
>>> ex.get(regA, reg1, 'ctx.nthreads * 2')
>>> # Same expression, so load comes from same memory location
>>> ex_stream_get(regA, reg2, 'ctx.nthreads * 2')
>>> ex.get(regA, reg2, 'ctx.nthreads * 2')
>>> # Vector loads are pre-coerced, so you can mix types
>>> ex_stream_get_v2(regA, reg1, '4', regf, '5.5')
>>> ex.get_v2(regA, reg1, '4', regf, '5.5')
You can even do device allocations in the file, using the post-finalized
variable '[prefix]_stream_size'. It's a StrVar, so if you do any operations
to it make sure you write them as a list of strings for PTX to handle (I
know, it's a drag; it might be fixed later):
variable '[prefix]_stream_size'. It's a DelayVar; simple things like
multiplying by a number work (as long as the DelayVar comes first), but
fancy things like multiplying two DelayVars aren't implemented yet.
>>> class Whatever(PTXFragment):
>>> @ptx_func
>>> def module_setup(self):
>>> mem.global_.u32('ex_streams', [1000, '*', ex_stream_size])
>>> mem.global_.u32('ex_streams', ex.stream_size*1000)
"""
# Must be at least as large as the largest load (.v4.u32 = 16)
alignment = 16
prefix = 'Subclass this'
def __init__(self):
self.texp_map = {}
self.cells = []
self.stream_size = 0
self._size = 0
self.free = {}
self.size_delayvars = []
self.finalized = False
_types = dict(s8='b', u8='B', s16='h', u16='H', s32='i', u32='I', f32='f',
s64='l', u64='L', f64='d')
@ -973,8 +957,8 @@ class DataStream(PTXFragment):
# No aligned free cells, allocate a new `align`-byte free cell
assert alloc not in self.free
self.free[alloc] = idx = len(self.cells)
self.cells.append(_DataCell(self.stream_size, alloc, None))
self.stream_size += alloc
self.cells.append(_DataCell(self._size, alloc, None))
self._size += alloc
# Overwrite the free cell at `idx` with texp
assert self.cells[idx].texp is None
offset = self.cells[idx].offset
@ -1015,43 +999,40 @@ class DataStream(PTXFragment):
op._call(opname, dregs, addr(areg, offset), ifp=ifp, ifnotp=ifnotp)
@ptx_func
def _stream_get(self, areg, dreg, expr, ifp=None, ifnotp=None):
def get(self, areg, dreg, expr, ifp=None, ifnotp=None):
self._stream_get_internal(areg, [dreg], [expr], ifp, ifnotp)
@ptx_func
def _stream_get_v2(self, areg, dreg1, expr1, dreg2, expr2,
ifp=None, ifnotp=None):
def get_v2(self, areg, dreg1, expr1, dreg2, expr2, ifp=None, ifnotp=None):
self._stream_get_internal(areg, [dreg1, dreg2], [expr1, expr2],
ifp, ifnotp)
# The interleaved signature makes calls easier to read
@ptx_func
def _stream_get_v4(self, areg, d1, e1, d2, e2, d3, e3, d4, e4,
def get_v4(self, areg, d1, e1, d2, e2, d3, e3, d4, e4,
ifp=None, ifnotp=None):
self._stream_get_internal(areg, [d1, d2, d3, d4], [e1, e2, e3, e4],
ifp, ifnotp)
@property
def _stream_size(self):
def stream_size(self):
if self.finalized:
return self._size
x = DelayVar("not_yet_determined")
self.size_delayvars.append(x)
return x
def finalize_code(self):
self.finalized = True
for dv in self.size_delayvars:
dv.val = self.stream_size
def to_inject(self):
return {self.prefix + '_stream_get': self._stream_get,
self.prefix + '_stream_get_v2': self._stream_get_v2,
self.prefix + '_stream_get_v4': self._stream_get_v4,
self.prefix + '_stream_size': self._stream_size}
dv.val = self._size
def pack(self, _out_file_ = None, **kwargs):
"""
Evaluates all statements in the context of **kwargs. Take this code,
presumably inside a PTX func::
>>> ex_stream_get(regA, reg1, 'sum([x+frob for x in xyz.things])')
>>> ex.get(regA, reg1, 'sum([x+frob for x in xyz.things])')
To pack this into a struct, call this method on an instance:
@ -1093,4 +1074,3 @@ class DataStream(PTXFragment):
for exp in cell.texp.exprlist[1:]:
print '%12s %s' % ('', exp)