instmethod decorator: another hack (to get around ctx.ptx.instances[])

Steven Robertson 2010-09-08 13:12:46 -04:00
parent 094890c324
commit 1f7b00b61e
4 changed files with 45 additions and 29 deletions
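
In rough terms, the decorator turns an instance method into a classmethod that looks the compiled instance up in ctx.ptx.instances, so call sites can write Cls.method(ctx, ...) instead of ctx.ptx.instances[Cls].method(ctx, ...). A minimal standalone sketch of that pattern follows; FakePTX, FakeContext, and Fragment are illustrative stand-ins, not code from this repository.

# Minimal model of the instmethod pattern; FakePTX/FakeContext stand in
# for the real PTXModule/LaunchContext machinery in cuburnlib.
class FakePTX(object):
    def __init__(self, instances):
        self.instances = instances      # maps fragment class -> compiled instance

class FakeContext(object):
    def __init__(self, instances):
        self.ptx = FakePTX(instances)

def instmethod(func):
    # Wrap an instance method as a classmethod that fetches the instance
    # bound to this class from the active context before delegating.
    def wrap(cls, ctx, *args, **kwargs):
        inst = ctx.ptx.instances[cls]
        return func(inst, ctx, *args, **kwargs)
    return classmethod(wrap)

class Fragment(object):
    def __init__(self, value):
        self.value = value

    @instmethod
    def call(self, ctx):
        return self.value

ctx = FakeContext({Fragment: Fragment(42)})
# Old style: ctx.ptx.instances[Fragment].call(ctx)
# New style: the class-level call does the instance lookup itself.
assert Fragment.call(ctx) == 42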

View File

@@ -7,7 +7,7 @@ Various micro-benchmarks and other experiments.
 import numpy as np
 import pycuda.autoinit
 import pycuda.driver as cuda
-from cuburnlib.ptx import PTXFragment, PTXTest, ptx_func
+from cuburnlib.ptx import PTXFragment, PTXTest, ptx_func, instmethod
 from cuburnlib.cuda import LaunchContext
 from cuburnlib.device_code import MWCRNG
@@ -104,7 +104,7 @@ class L2WriteCombining(PTXTest):
         op.setp.ge.u32(p_done, x, 2)
         op.bra.uni(l2_restart, ifnotp=p_done)

+    @instmethod
     def call(self, ctx):
         scratch = np.zeros(self.block_size*ctx.ctas/4, np.uint64)
         times_bytes = np.zeros((4, ctx.threads), np.uint64, 'F')
@@ -137,7 +137,7 @@ def main():
     ctx = LaunchContext([L2WriteCombining], block=(128,1,1), grid=(7*8,1),
                         tests=True)
     ctx.compile(verbose=3)
-    ctx.ptx.instances[L2WriteCombining].call(ctx)
+    L2WriteCombining.call(ctx)

 if __name__ == "__main__":
     main()

View File

@@ -130,6 +130,7 @@ class IterThread(PTXTest):
         std.store_per_thread(g_num_rounds, num_rounds)
         std.store_per_thread(g_num_writes, num_writes)

+    @instmethod
     def upload_cp_stream(self, ctx, cp_stream, num_cps):
         cp_array_dp, cp_array_l = ctx.mod.get_global('g_cp_array')
         assert len(cp_stream) <= cp_array_l, "Stream too big!"
@@ -139,6 +140,7 @@ class IterThread(PTXTest):
         cuda.memset_d32(num_cps_dp, num_cps, 1)
         self.cps_uploaded = True

+    @instmethod
     def call(self, ctx):
         if not self.cps_uploaded:
             raise Error("Cannot call IterThread before uploading CPs")

View File

@@ -500,6 +500,9 @@ class PTXFragment(object):
     An object containing PTX DSL functions. The object, and all its
     dependencies, will be instantiated by a PTX module. Each object will be
     bound to the name given by ``shortname`` in the DSL namespace.
+
+    Because of the instantiation weirdness, use the instmethod decorator on
+    instance methods that will be called from regular Python code.
     """

     # Name under which to make this code available in ptx_funcs
@@ -575,6 +578,17 @@ class PTXFragment(object):
         """
         pass

+def instmethod(func):
+    """
+    Wrapper to allow instances to be retrieved from an active context. Use it
+    on methods which depend on state created during a compilation phase, but
+    are intended to be called from normal Python code.
+    """
+    def wrap(cls, ctx, *args, **kwargs):
+        inst = ctx.ptx.instances[cls]
+        return func(inst, ctx, *args, **kwargs)
+    return classmethod(wrap)
+
 class PTXEntryPoint(PTXFragment):
     # Human-readable entry point name
     name = ""
@@ -591,6 +605,7 @@ class PTXEntryPoint(PTXFragment):
         """
         raise NotImplementedError

+    @instmethod
     def call(self, ctx):
         """
         Calls the entry point on the device. Haven't worked out the details
@@ -819,7 +834,6 @@ class PTXModule(object):
             print '\n'.join(["%03d %s" % (i+1, l) for (i, l) in
                              enumerate(self.source.split('\n'))])

-
 def _flatten(val):
     if isinstance(val, (list, tuple)):
         return ''.join(map(_flatten, val))
@@ -883,7 +897,7 @@ class DataStream(PTXFragment):
     >>> class ExampleDataStream(DataStream):
     >>>     shortname = "ex"

-    Inside DSL functions, you can "retrieve" arbitrary Python expressions from
+    Inside DSL functions, you can retrieve arbitrary Python expressions from
     the data stream.

     >>> @ptx_func
@@ -892,22 +906,17 @@ class DataStream(PTXFragment):
     >>>     op.mov.u32(regA, some_device_allocation_base_address)
     >>>     # From the structure at the base address in 'regA', load the value
     >>>     # of 'ctx.nthreads' into reg1
-    >>>     ex.get(regA, reg1, 'ctx.nthreads')
+    >>>     ex.get(regA, reg1, 'ctx.nthreads+padding')

     The expressions will be stored as strings and mapped to particular
     positions in the struct. Later, the expressions will be evaluated and
     coerced into a type matching the destination register:

-    >>> # Fish the instance holding the data stream from the compiled module
-    >>> ex_stream = launch_context.ptx.instances[ExampleDataStream]
-    >>> # Evaluate the expressions in the current namespace, augmented with the
-    >>> # supplied objects
-    >>> data = ex_stream.pack(ctx=launch_context)
+    >>> data = ExampleDataStream.pack(ctx, padding=4)

     Expressions will be aligned and may be reused in such a way as to minimize
     access times when taking device caching into account. This also implies
-    that the evaluated expressions should not modify any state, but that should
-    be obvious, no?
+    that the evaluated expressions should not modify any state.

     >>> @ptx_func
     >>> def example_func_2():
@@ -1034,7 +1043,8 @@ class DataStream(PTXFragment):
        for dv in self.size_delayvars:
            dv.val = self._size

-    def pack(self, _out_file_ = None, **kwargs):
+    @instmethod
+    def pack(self, ctx, _out_file_ = None, **kwargs):
        """
        Evaluates all statements in the context of **kwargs. Take this code,
        presumably inside a PTX func::
@@ -1043,25 +1053,31 @@ class DataStream(PTXFragment):

        To pack this into a struct, call this method on an instance:

-        >>> ex_stream = launch_context.ptx.instances[ExampleDataStream]
-        >>> data = ex_stream.pack(frob=4, xyz=xyz)
+        >>> data = ExampleDataStream.pack(ctx, frob=4, xyz=xyz)

        This evaluates each Python expression from the stream with the provided
        arguments as locals, coerces it to the appropriate type, and returns
        the resulting structure as a string.
+
+        The supplied LaunchContext is added to the namespace as ``ctx`` by
+        default. To suppress this, override ``ctx`` in the keyword arguments:
+
+        >>> data = ExampleDataStream.pack(ctx, frob=5, xyz=xyz, ctx=None)
        """
        out = StringIO()
-        self.pack_into(out, kwargs)
+        self.pack_into(ctx, out, **kwargs)
        return out.read()

-    def pack_into(self, outfile, **kwargs):
+    @instmethod
+    def pack_into(self, ctx, outfile, **kwargs):
        """
        Like pack(), but write data to a file-like object at the file's current
        offset instead of returning it as a string.

-        >>> ex_stream.pack_into(strio_inst, frob=4, xyz=thing)
-        >>> ex_stream.pack_into(strio_inst, frob=6, xyz=another_thing)
+        >>> ex_stream.pack_into(ctx, strio_inst, frob=4, xyz=thing)
+        >>> ex_stream.pack_into(ctx, strio_inst, frob=6, xyz=another_thing)
        """
+        kwargs.setdefault('ctx', ctx)
        for offset, size, texp in self.cells:
            if texp:
                type = texp.type
@@ -1071,7 +1087,8 @@ class DataStream(PTXFragment):
                vals = []
            outfile.write(struct.pack(type, *vals))

-    def print_record(self):
+    @instmethod
+    def print_record(self, ctx):
        for cell in self.cells:
            if cell.texp is None:
                print '%3d %2d --' % (cell.offset, cell.size)
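
As a usage-level illustration of the revised pack()/pack_into() flow above, here is a toy, self-contained model of how recorded expressions get evaluated at pack time with the LaunchContext injected into the namespace as ctx. ToyStream and Ctx are invented stand-ins; the real DataStream records expressions through ex.get(...) during compilation and handles typing and alignment.

# Toy model of expression packing; not the cuburnlib implementation.
import struct
from io import BytesIO

class ToyStream(object):
    def __init__(self):
        self.exprs = []                     # expression strings recorded at "compile" time

    def get(self, expr):
        self.exprs.append(expr)             # e.g. 'ctx.nthreads+padding'

    def pack_into(self, ctx, outfile, **kwargs):
        kwargs.setdefault('ctx', ctx)       # same default-injection trick as above
        for expr in self.exprs:
            val = eval(expr, {}, kwargs)    # kwargs acts as the local namespace
            outfile.write(struct.pack('<I', int(val)))  # toy: every cell is a u32

    def pack(self, ctx, **kwargs):
        out = BytesIO()
        self.pack_into(ctx, out, **kwargs)
        return out.getvalue()

class Ctx(object):
    nthreads = 128                          # stand-in for LaunchContext.nthreads

stream = ToyStream()
stream.get('ctx.nthreads+padding')
data = stream.pack(Ctx(), padding=4)
assert struct.unpack('<I', data) == (132,)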

View File

@@ -35,10 +35,9 @@ class Frame(pyflam3.Frame):
                 "Distribution of a CP across multiple CTAs not yet done")
         # Interpolate each time step, calculate per-step variables, and pack
         # into the stream
-        cp_streamer = ctx.ptx.instances[CPDataStream]
         stream = StringIO()
         print "Data stream contents:"
-        cp_streamer.print_record()
+        CPDataStream.print_record(ctx)
         tcp = BaseGenome()
         for batch_idx in range(cp.nbatches):
             for time_idx in range(cp.ntemporal_samples):
@@ -51,10 +50,8 @@ class Frame(pyflam3.Frame):
                        cp.width * cp.height) / (
                        cp.nbatches * cp.ntemporal_samples)

-                cp_streamer.pack_into(stream,
-                                      frame=self,
-                                      cp=tcp,
-                                      cp_idx=idx)
+                CPDataStream.pack_into(ctx, stream,
+                                       frame=self, cp=tcp, cp_idx=idx)

         stream.seek(0)
         return (stream.read(), cp.nbatches * cp.ntemporal_samples)
@@ -108,8 +105,8 @@ class Animation(object):
         # TODO: allow animation-long override of certain parameters (size, etc)
         cp_stream, num_cps = self.frame.pack_stream(self.ctx, time)
         iter_thread = self.ctx.ptx.instances[IterThread]
-        iter_thread.upload_cp_stream(self.ctx, cp_stream, num_cps)
-        iter_thread.call(self.ctx)
+        IterThread.upload_cp_stream(self.ctx, cp_stream, num_cps)
+        IterThread.call(self.ctx)

 class Features(object):
     """