mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 11:40:04 -05:00
Switch from to_inject() to object insertion. One less kludge to deal with.
This commit is contained in:
parent
ada0fe20c7
commit
e03f20392d
1
TODO
1
TODO
@ -39,6 +39,7 @@ Things to do (rather severely incomplete):
|
||||
- Test effects on quality by masking off writes on all but one lane and
|
||||
boosting the sample density to compensate (muuuuuch later on)
|
||||
- DE
|
||||
- Clean up code (particularly DSL stuff incl. injector)
|
||||
|
||||
Things to test:
|
||||
|
||||
|
@ -24,7 +24,7 @@ class IterThread(PTXTest):
|
||||
@ptx_func
|
||||
def module_setup(self):
|
||||
mem.global_.u32('g_cp_array',
|
||||
cp_stream_size*features.max_ntemporal_samples)
|
||||
cp.stream_size*features.max_ntemporal_samples)
|
||||
mem.global_.u32('g_num_cps')
|
||||
# TODO move into debug statement
|
||||
mem.global_.u32('g_num_rounds', ctx.threads)
|
||||
@ -40,10 +40,10 @@ class IterThread(PTXTest):
|
||||
op.mov.u32(num_writes, 0)
|
||||
|
||||
# TODO: MWC float output types
|
||||
mwc_next_f32_01(x_coord)
|
||||
mwc_next_f32_01(y_coord)
|
||||
mwc_next_f32_01(color_coord)
|
||||
mwc_next_f32_01(alpha_coord)
|
||||
mwc.next_f32_01(x_coord)
|
||||
mwc.next_f32_01(y_coord)
|
||||
mwc.next_f32_01(color_coord)
|
||||
mwc.next_f32_01(alpha_coord)
|
||||
|
||||
# Registers are hard to come by. To avoid having to track both the count
|
||||
# of samples processed and the number of samples to generate,
|
||||
@ -81,13 +81,13 @@ class IterThread(PTXTest):
|
||||
|
||||
with block('Load CP address'):
|
||||
op.mov.u32(cpA, g_cp_array)
|
||||
op.mad.lo.u32(cpA, cp_idx, cp_stream_size, cpA)
|
||||
op.mad.lo.u32(cpA, cp_idx, cp.stream_size, cpA)
|
||||
|
||||
with block('Increment CP index, load num_samples (unless in fuse)'):
|
||||
reg.pred('p_not_in_fuse')
|
||||
op.setp.ge.s32(p_not_in_fuse, num_samples, 0)
|
||||
op.add.u32(cp_idx, cp_idx, 1, ifp=p_not_in_fuse)
|
||||
cp_stream_get(cpA, num_samples, 'samples_per_thread',
|
||||
cp.get(cpA, num_samples, 'samples_per_thread',
|
||||
ifp=p_not_in_fuse)
|
||||
|
||||
label('fuse_loop_start')
|
||||
@ -127,8 +127,8 @@ class IterThread(PTXTest):
|
||||
|
||||
label('all_cps_done')
|
||||
# TODO this is for testing, move it to a debug statement
|
||||
store_per_thread(g_num_rounds, num_rounds)
|
||||
store_per_thread(g_num_writes, num_writes)
|
||||
std.store_per_thread(g_num_rounds, num_rounds)
|
||||
std.store_per_thread(g_num_writes, num_writes)
|
||||
|
||||
def upload_cp_stream(self, ctx, cp_stream, num_cps):
|
||||
cp_array_dp, cp_array_l = ctx.mod.get_global('g_cp_array')
|
||||
@ -152,6 +152,8 @@ class IterThread(PTXTest):
|
||||
print "Writes:", writes
|
||||
|
||||
class MWCRNG(PTXFragment):
|
||||
shortname = "mwc"
|
||||
|
||||
def __init__(self):
|
||||
self.rand = np.random
|
||||
self.threads_ready = 0
|
||||
@ -171,7 +173,7 @@ class MWCRNG(PTXFragment):
|
||||
reg.u32('mwc_st mwc_mult mwc_car')
|
||||
with block('Load MWC multipliers and states'):
|
||||
reg.u32('mwc_off mwc_addr')
|
||||
get_gtid(mwc_off)
|
||||
std.get_gtid(mwc_off)
|
||||
op.mov.u32(mwc_addr, mwc_rng_mults)
|
||||
op.mad.lo.u32(mwc_addr, mwc_off, 4, mwc_addr)
|
||||
op.ld.global_.u32(mwc_mult, addr(mwc_addr))
|
||||
@ -184,7 +186,7 @@ class MWCRNG(PTXFragment):
|
||||
def entry_teardown(self):
|
||||
with block('Save MWC states'):
|
||||
reg.u32('mwc_off mwc_addr')
|
||||
get_gtid(mwc_off)
|
||||
std.get_gtid(mwc_off)
|
||||
op.mov.u32(mwc_addr, mwc_rng_state)
|
||||
op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr)
|
||||
op.st.global_.v2.u32(addr(mwc_addr), vec(mwc_st, mwc_car))
|
||||
@ -219,11 +221,6 @@ class MWCRNG(PTXFragment):
|
||||
op.cvt.rn.f32.s32(dst_reg, mwc_st)
|
||||
op.mul.lo.f32(dst_reg, dst_reg, '0f00000030') # 1./(1<<31)
|
||||
|
||||
def to_inject(self):
|
||||
return dict(mwc_next_b32=self.next_b32,
|
||||
mwc_next_f32_01=self.next_f32_01,
|
||||
mwc_next_f32_11=self.next_f32_11)
|
||||
|
||||
def device_init(self, ctx):
|
||||
if self.threads_ready >= ctx.threads:
|
||||
# Already set up enough random states, don't push again
|
||||
@ -275,7 +272,7 @@ class MWCRNGTest(PTXTest):
|
||||
reg.pred('p')
|
||||
op.mov.u32(loopct, self.rounds)
|
||||
label('loopstart')
|
||||
mwc_next_b32(addend)
|
||||
mwc.next_b32(addend)
|
||||
op.cvt.u64.u32(addl, addend)
|
||||
op.add.u64(sum, sum, addl)
|
||||
op.sub.u32(loopct, loopct, 1)
|
||||
@ -284,7 +281,7 @@ class MWCRNGTest(PTXTest):
|
||||
|
||||
with block('Store sum and state'):
|
||||
reg.u32('adr offset')
|
||||
get_gtid(offset)
|
||||
std.get_gtid(offset)
|
||||
op.mov.u32(adr, mwc_rng_test_sums)
|
||||
op.mad.lo.u32(adr, offset, 8, adr)
|
||||
op.st.global_.u64(addr(adr), sum)
|
||||
@ -331,5 +328,5 @@ class CameraCoordTransform(PTXFragment):
|
||||
|
||||
class CPDataStream(DataStream):
|
||||
"""DataStream which stores the control points."""
|
||||
prefix = 'cp'
|
||||
shortname = 'cp'
|
||||
|
||||
|
112
cuburnlib/ptx.py
112
cuburnlib/ptx.py
@ -493,22 +493,14 @@ class Comment(object):
|
||||
|
||||
class PTXFragment(object):
|
||||
"""
|
||||
An object containing PTX DSL functions.
|
||||
|
||||
In cuflame, several different versions of a given function may be
|
||||
regenerated in rapid succession
|
||||
|
||||
The final compilation pass is guaranteed to have all "tuned" values fixed
|
||||
in their final values for the stream.
|
||||
|
||||
Template code will be processed recursively until all "{{" instances have
|
||||
been replaced, using the same namespace each time.
|
||||
|
||||
Note that any method which does not depend on 'ctx' can be replaced with
|
||||
an instance of the appropriate return type. So, for example, the 'deps'
|
||||
property can be a flat list instead of a function.
|
||||
An object containing PTX DSL functions. The object, and all its
|
||||
dependencies, will be instantiated by a PTX module. Each object will be
|
||||
bound to the name given by ``shortname`` in the DSL namespace.
|
||||
"""
|
||||
|
||||
# Name under which to make this code available in ptx_funcs
|
||||
shortname = None
|
||||
|
||||
def deps(self):
|
||||
"""
|
||||
Returns a list of PTXFragment types on which this object depends
|
||||
@ -517,15 +509,6 @@ class PTXFragment(object):
|
||||
"""
|
||||
return [_PTXStdLib]
|
||||
|
||||
def to_inject(self):
|
||||
"""
|
||||
Returns a dict of items to add to the DSL namespace. The namespace will
|
||||
be assembled in dependency order before any ptx_funcs are called.
|
||||
|
||||
This is only called once per PTXModule (== once per instance).
|
||||
"""
|
||||
return {}
|
||||
|
||||
def module_setup(self):
|
||||
"""
|
||||
PTX function to declare things at module scope. It's a PTX syntax error
|
||||
@ -624,6 +607,7 @@ class PTXTest(PTXEntryPoint):
|
||||
pass
|
||||
|
||||
class _PTXStdLib(PTXFragment):
|
||||
shortname = "std"
|
||||
def __init__(self, block):
|
||||
# Only module that gets the privilege of seeing 'block' directly.
|
||||
self.block = block
|
||||
@ -639,7 +623,7 @@ class _PTXStdLib(PTXFragment):
|
||||
self.block.code(prefix='.target sm_20', semi=False)
|
||||
|
||||
@ptx_func
|
||||
def _get_gtid(self, dst):
|
||||
def get_gtid(self, dst):
|
||||
"""
|
||||
Get the global thread ID (the position of this thread in a grid of
|
||||
blocks of threads). Notably, this assumes that both grid and block are
|
||||
@ -661,16 +645,17 @@ class _PTXStdLib(PTXFragment):
|
||||
op.mov.b32(dst, gtid)
|
||||
|
||||
@ptx_func
|
||||
def _store_per_thread(self, base, val):
|
||||
def store_per_thread(self, base, val):
|
||||
"""Store b32 at `base+gtid*4`. Super-common debug pattern."""
|
||||
with block("Per-thread store of %s" % str(val)):
|
||||
reg.u32('spt_base spt_offset')
|
||||
op.mov.u32(spt_base, base)
|
||||
get_gtid(spt_offset)
|
||||
self.get_gtid(spt_offset)
|
||||
op.mad.lo.u32(spt_base, spt_offset, 4, spt_base)
|
||||
op.st.b32(addr(spt_base), val)
|
||||
|
||||
def to_inject(self):
|
||||
# Set up the initial namespace
|
||||
return dict(
|
||||
_block=self.block,
|
||||
block=Block(self.block),
|
||||
@ -680,9 +665,7 @@ class _PTXStdLib(PTXFragment):
|
||||
addr=Mem.addr,
|
||||
vec=Mem.vec,
|
||||
label=_LabelFactory(self.block),
|
||||
comment=Comment(self.block),
|
||||
get_gtid=self._get_gtid,
|
||||
store_per_thread=self._store_per_thread)
|
||||
comment=Comment(self.block))
|
||||
|
||||
class PTXModule(object):
|
||||
"""
|
||||
@ -713,9 +696,11 @@ class PTXModule(object):
|
||||
self.tests = tests
|
||||
|
||||
inject = dict(inject)
|
||||
self._safeupdate(inject, {'module': self})
|
||||
inject.update(insts[_PTXStdLib].to_inject())
|
||||
self._safeupdate(inject, 'module', self)
|
||||
for inst in all_deps:
|
||||
self._safeupdate(inject, inst.to_inject())
|
||||
if inst.shortname:
|
||||
self._safeupdate(inject, inst.shortname, inst)
|
||||
[block.inject(k, v) for k, v in inject.items()]
|
||||
|
||||
self.__needs_recompilation = True
|
||||
@ -749,11 +734,9 @@ class PTXModule(object):
|
||||
map(rec, unsorted_instances)
|
||||
return sorted(unsorted_instances, key=seen.get)
|
||||
|
||||
def _safeupdate(self, dst, src):
|
||||
"""dst.update(src), but no duplicates allowed"""
|
||||
non_uniq = [k for k in src if k in dst]
|
||||
if non_uniq: raise KeyError("Duplicate keys: %s" % ','.join(key))
|
||||
dst.update(src)
|
||||
def _safeupdate(self, dst, k, v):
|
||||
if k in dst: raise KeyError("Duplicate key %s" % k)
|
||||
dst[k] = v
|
||||
|
||||
def deptrace(self, block, entries, build_tests):
|
||||
instances = {_PTXStdLib: _PTXStdLib(block)}
|
||||
@ -796,6 +779,7 @@ class PTXModule(object):
|
||||
def assemble(self, block, all_deps, entry_deps):
|
||||
# Rebind to local namespace to allow proper retrieval
|
||||
_block = block
|
||||
|
||||
for inst in all_deps:
|
||||
inst.module_setup()
|
||||
|
||||
@ -886,11 +870,11 @@ class DataStream(PTXFragment):
|
||||
variable positions determined at runtime. The resulting structure had to be
|
||||
read strictly sequentially to be parsed, hence the name "stream".)
|
||||
|
||||
Subclass this and give it a prefix, then depend on the subclass in your PTX
|
||||
fragments.
|
||||
Subclass this and give it a shortname, then depend on the subclass in your
|
||||
PTX fragments. An instance-based approach is under consideration.
|
||||
|
||||
>>> class ExampleDataStream(DataStream):
|
||||
>>> prefix = 'ex'
|
||||
>>> shortname = "ex"
|
||||
|
||||
Inside DSL functions, you can "retrieve" arbitrary Python expressions from
|
||||
the data stream.
|
||||
@ -901,7 +885,7 @@ class DataStream(PTXFragment):
|
||||
>>> op.mov.u32(regA, some_device_allocation_base_address)
|
||||
>>> # From the structure at the base address in 'regA', load the value
|
||||
>>> # of 'ctx.nthreads' into reg1
|
||||
>>> ex_stream_get(regA, reg1, 'ctx.nthreads')
|
||||
>>> ex.get(regA, reg1, 'ctx.nthreads')
|
||||
|
||||
The expressions will be stored as strings and mapped to particular
|
||||
positions in the struct. Later, the expressions will be evaluated and
|
||||
@ -922,32 +906,32 @@ class DataStream(PTXFragment):
|
||||
>>> def example_func_2():
|
||||
>>> reg.u32('reg1 reg2')
|
||||
>>> reg.f32('regf')
|
||||
>>> ex_stream_get(regA, reg1, 'ctx.nthreads * 2')
|
||||
>>> ex.get(regA, reg1, 'ctx.nthreads * 2')
|
||||
>>> # Same expression, so load comes from same memory location
|
||||
>>> ex_stream_get(regA, reg2, 'ctx.nthreads * 2')
|
||||
>>> ex.get(regA, reg2, 'ctx.nthreads * 2')
|
||||
>>> # Vector loads are pre-coerced, so you can mix types
|
||||
>>> ex_stream_get_v2(regA, reg1, '4', regf, '5.5')
|
||||
>>> ex.get_v2(regA, reg1, '4', regf, '5.5')
|
||||
|
||||
You can even do device allocations in the file, using the post-finalized
|
||||
variable '[prefix]_stream_size'. It's a StrVar, so if you do any operations
|
||||
to it make sure you write them as a list of strings for PTX to handle (I
|
||||
know, it's a drag; it might be fixed later):
|
||||
variable '[shortname].stream_size'. It's a DelayVar; simple things like
|
||||
multiplying by a number work (as long as the DelayVar comes first), but
|
||||
fancy things like multiplying two DelayVars aren't implemented yet.
|
||||
|
||||
>>> class Whatever(PTXFragment):
|
||||
>>> @ptx_func
|
||||
>>> def module_setup(self):
|
||||
>>> mem.global_.u32('ex_streams', [1000, '*', ex_stream_size])
|
||||
>>> mem.global_.u32('ex_streams', ex.stream_size*1000)
|
||||
"""
|
||||
# Must be at least as large as the largest load (.v4.u32 = 16)
|
||||
alignment = 16
|
||||
prefix = 'Subclass this'
|
||||
|
||||
def __init__(self):
|
||||
self.texp_map = {}
|
||||
self.cells = []
|
||||
self.stream_size = 0
|
||||
self._size = 0
|
||||
self.free = {}
|
||||
self.size_delayvars = []
|
||||
self.finalized = False
|
||||
|
||||
_types = dict(s8='b', u8='B', s16='h', u16='H', s32='i', u32='I', f32='f',
|
||||
s64='l', u64='L', f64='d')
|
||||
@ -973,8 +957,8 @@ class DataStream(PTXFragment):
|
||||
# No aligned free cells, allocate a new `align`-byte free cell
|
||||
assert alloc not in self.free
|
||||
self.free[alloc] = idx = len(self.cells)
|
||||
self.cells.append(_DataCell(self.stream_size, alloc, None))
|
||||
self.stream_size += alloc
|
||||
self.cells.append(_DataCell(self._size, alloc, None))
|
||||
self._size += alloc
|
||||
# Overwrite the free cell at `idx` with texp
|
||||
assert self.cells[idx].texp is None
|
||||
offset = self.cells[idx].offset
|
||||
@ -1015,43 +999,40 @@ class DataStream(PTXFragment):
|
||||
op._call(opname, dregs, addr(areg, offset), ifp=ifp, ifnotp=ifnotp)
|
||||
|
||||
@ptx_func
|
||||
def _stream_get(self, areg, dreg, expr, ifp=None, ifnotp=None):
|
||||
def get(self, areg, dreg, expr, ifp=None, ifnotp=None):
|
||||
self._stream_get_internal(areg, [dreg], [expr], ifp, ifnotp)
|
||||
|
||||
@ptx_func
|
||||
def _stream_get_v2(self, areg, dreg1, expr1, dreg2, expr2,
|
||||
ifp=None, ifnotp=None):
|
||||
def get_v2(self, areg, dreg1, expr1, dreg2, expr2, ifp=None, ifnotp=None):
|
||||
self._stream_get_internal(areg, [dreg1, dreg2], [expr1, expr2],
|
||||
ifp, ifnotp)
|
||||
|
||||
# The interleaved signature makes calls easier to read
|
||||
@ptx_func
|
||||
def _stream_get_v4(self, areg, d1, e1, d2, e2, d3, e3, d4, e4,
|
||||
ifp=None, ifnotp=None):
|
||||
def get_v4(self, areg, d1, e1, d2, e2, d3, e3, d4, e4,
|
||||
ifp=None, ifnotp=None):
|
||||
self._stream_get_internal(areg, [d1, d2, d3, d4], [e1, e2, e3, e4],
|
||||
ifp, ifnotp)
|
||||
|
||||
@property
|
||||
def _stream_size(self):
|
||||
def stream_size(self):
|
||||
if self.finalized:
|
||||
return self._size
|
||||
x = DelayVar("not_yet_determined")
|
||||
self.size_delayvars.append(x)
|
||||
return x
|
||||
|
||||
def finalize_code(self):
|
||||
self.finalized = True
|
||||
for dv in self.size_delayvars:
|
||||
dv.val = self.stream_size
|
||||
|
||||
def to_inject(self):
|
||||
return {self.prefix + '_stream_get': self._stream_get,
|
||||
self.prefix + '_stream_get_v2': self._stream_get_v2,
|
||||
self.prefix + '_stream_get_v4': self._stream_get_v4,
|
||||
self.prefix + '_stream_size': self._stream_size}
|
||||
dv.val = self._size
|
||||
|
||||
def pack(self, _out_file_ = None, **kwargs):
|
||||
"""
|
||||
Evaluates all statements in the context of **kwargs. Take this code,
|
||||
presumably inside a PTX func::
|
||||
|
||||
>>> ex_stream_get(regA, reg1, 'sum([x+frob for x in xyz.things])')
|
||||
>>> ex.get(regA, reg1, 'sum([x+frob for x in xyz.things])')
|
||||
|
||||
To pack this into a struct, call this method on an instance:
|
||||
|
||||
@ -1093,4 +1074,3 @@ class DataStream(PTXFragment):
|
||||
for exp in cell.texp.exprlist[1:]:
|
||||
print '%12s %s' % ('', exp)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user