Mirror of https://github.com/stevenrobertson/cuburn.git
instmethod decorator: another hack (to get around ctx.ptx.instances[])
parent 094890c324
commit 1f7b00b61e

bench.py: 6 changes
@@ -7,7 +7,7 @@ Various micro-benchmarks and other experiments.
 import numpy as np
 import pycuda.autoinit
 import pycuda.driver as cuda
-from cuburnlib.ptx import PTXFragment, PTXTest, ptx_func
+from cuburnlib.ptx import PTXFragment, PTXTest, ptx_func, instmethod
 from cuburnlib.cuda import LaunchContext
 from cuburnlib.device_code import MWCRNG
 
@@ -104,7 +104,7 @@ class L2WriteCombining(PTXTest):
         op.setp.ge.u32(p_done, x, 2)
         op.bra.uni(l2_restart, ifnotp=p_done)
 
+    @instmethod
     def call(self, ctx):
         scratch = np.zeros(self.block_size*ctx.ctas/4, np.uint64)
         times_bytes = np.zeros((4, ctx.threads), np.uint64, 'F')

@@ -137,7 +137,7 @@ def main():
     ctx = LaunchContext([L2WriteCombining], block=(128,1,1), grid=(7*8,1),
                         tests=True)
     ctx.compile(verbose=3)
-    ctx.ptx.instances[L2WriteCombining].call(ctx)
+    L2WriteCombining.call(ctx)
 
 if __name__ == "__main__":
     main()
@@ -130,6 +130,7 @@ class IterThread(PTXTest):
         std.store_per_thread(g_num_rounds, num_rounds)
         std.store_per_thread(g_num_writes, num_writes)
 
+    @instmethod
     def upload_cp_stream(self, ctx, cp_stream, num_cps):
         cp_array_dp, cp_array_l = ctx.mod.get_global('g_cp_array')
         assert len(cp_stream) <= cp_array_l, "Stream too big!"

@@ -139,6 +140,7 @@ class IterThread(PTXTest):
         cuda.memset_d32(num_cps_dp, num_cps, 1)
         self.cps_uploaded = True
 
+    @instmethod
     def call(self, ctx):
         if not self.cps_uploaded:
             raise Error("Cannot call IterThread before uploading CPs")

@@ -500,6 +500,9 @@ class PTXFragment(object):
     An object containing PTX DSL functions. The object, and all its
     dependencies, will be instantiated by a PTX module. Each object will be
     bound to the name given by ``shortname`` in the DSL namespace.
+
+    Because of the instantiation weirdness, use the instmethod decorator on
+    instance methods that will be called from regular Python code.
     """
 
     # Name under which to make this code available in ptx_funcs

@@ -575,6 +578,17 @@ class PTXFragment(object):
         """
         pass
 
+def instmethod(func):
+    """
+    Wrapper to allow instances to be retrieved from an active context. Use it
+    on methods which depend on state created during a compilation phase, but
+    are intended to be called from normal Python code.
+    """
+    def wrap(cls, ctx, *args, **kwargs):
+        inst = ctx.ptx.instances[cls]
+        func(inst, ctx, *args, **kwargs)
+    return classmethod(wrap)
+
 class PTXEntryPoint(PTXFragment):
     # Human-readable entry point name
     name = ""
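
In effect, instmethod turns an instance method into a classmethod that looks up the compiled instance on the supplied LaunchContext and delegates to it. Below is a minimal, self-contained sketch of that pattern; FakePTX and FakeContext are hypothetical stand-ins for PTXModule and LaunchContext, used only so the sketch runs without CUDA or the rest of cuburnlib:

    def instmethod(func):
        def wrap(cls, ctx, *args, **kwargs):
            # Fetch the instance the module built for this class, then call
            # the undecorated method on it. As committed, wrap discards the
            # return value of func.
            inst = ctx.ptx.instances[cls]
            func(inst, ctx, *args, **kwargs)
        return classmethod(wrap)

    class Fragment(object):
        @instmethod
        def call(self, ctx):
            print "called %r with ctx %r" % (self, ctx)

    class FakePTX(object):
        def __init__(self, classes):
            # Maps each fragment class to its built instance, standing in for
            # the instances mapping the real PTX module exposes.
            self.instances = dict((c, c()) for c in classes)

    class FakeContext(object):
        def __init__(self, classes):
            self.ptx = FakePTX(classes)

    ctx = FakeContext([Fragment])
    Fragment.call(ctx)   # resolves to ctx.ptx.instances[Fragment].call(ctx)

Because wrap does not propagate func's return value, class-level calls through the decorator return None as written here; that suits call()-style entry points, but is worth keeping in mind for methods that compute a result.
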
@@ -591,6 +605,7 @@ class PTXEntryPoint(PTXFragment):
         """
         raise NotImplementedError
 
+    @instmethod
     def call(self, ctx):
         """
         Calls the entry point on the device. Haven't worked out the details

@@ -819,7 +834,6 @@ class PTXModule(object):
            print '\n'.join(["%03d %s" % (i+1, l) for (i, l) in
                             enumerate(self.source.split('\n'))])
 
-
 def _flatten(val):
     if isinstance(val, (list, tuple)):
         return ''.join(map(_flatten, val))

@@ -883,7 +897,7 @@ class DataStream(PTXFragment):
     >>> class ExampleDataStream(DataStream):
     >>>     shortname = "ex"
 
-    Inside DSL functions, you can "retrieve" arbitrary Python expressions from
+    Inside DSL functions, you can retrieve arbitrary Python expressions from
     the data stream.
 
     >>> @ptx_func

@@ -892,22 +906,17 @@ class DataStream(PTXFragment):
     >>> op.mov.u32(regA, some_device_allocation_base_address)
     >>> # From the structure at the base address in 'regA', load the value
     >>> # of 'ctx.nthreads' into reg1
-    >>> ex.get(regA, reg1, 'ctx.nthreads')
+    >>> ex.get(regA, reg1, 'ctx.nthreads+padding')
 
     The expressions will be stored as strings and mapped to particular
     positions in the struct. Later, the expressions will be evaluated and
     coerced into a type matching the destination register:
 
-    >>> # Fish the instance holding the data stream from the compiled module
-    >>> ex_stream = launch_context.ptx.instances[ExampleDataStream]
-    >>> # Evaluate the expressions in the current namespace, augmented with the
-    >>> # supplied objects
-    >>> data = ex_stream.pack(ctx=launch_context)
+    >>> data = ExampleDataStream.pack(ctx, padding=4)
 
     Expressions will be aligned and may be reused in such a way as to minimize
     access times when taking device caching into account. This also implies
-    that the evaluated expressions should not modify any state, but that should
-    be obvious, no?
+    that the evaluated expressions should not modify any state.
 
     >>> @ptx_func
     >>> def example_func_2():

@@ -1034,7 +1043,8 @@ class DataStream(PTXFragment):
         for dv in self.size_delayvars:
             dv.val = self._size
 
-    def pack(self, _out_file_ = None, **kwargs):
+    @instmethod
+    def pack(self, ctx, _out_file_ = None, **kwargs):
         """
         Evaluates all statements in the context of **kwargs. Take this code,
         presumably inside a PTX func::

@@ -1043,25 +1053,31 @@ class DataStream(PTXFragment):
 
         To pack this into a struct, call this method on an instance:
 
-        >>> ex_stream = launch_context.ptx.instances[ExampleDataStream]
-        >>> data = ex_stream.pack(frob=4, xyz=xyz)
+        >>> data = ExampleDataStream.pack(ctx, frob=4, xyz=xyz)
 
         This evaluates each Python expression from the stream with the provided
         arguments as locals, coerces it to the appropriate type, and returns
         the resulting structure as a string.
 
+        The supplied LaunchContext is added to the namespace as ``ctx`` by
+        default. To supress, this, override ``ctx`` in the keyword arguments:
+
+        >>> data = ExampleDataStream.pack(ctx, frob=5, xyz=xyz, ctx=None)
         """
         out = StringIO()
-        self.pack_into(out, kwargs)
+        cls.pack_into(out, kwargs)
         return out.read()
 
-    def pack_into(self, outfile, **kwargs):
+    @instmethod
+    def pack_into(self, ctx, outfile, **kwargs):
         """
         Like pack(), but write data to a file-like object at the file's current
         offset instead of returning it as a string.
 
-        >>> ex_stream.pack_into(strio_inst, frob=4, xyz=thing)
-        >>> ex_stream.pack_into(strio_inst, frob=6, xyz=another_thing)
+        >>> ex_stream.pack_into(ctx, strio_inst, frob=4, xyz=thing)
+        >>> ex_stream.pack_into(ctx, strio_inst, frob=6, xyz=another_thing)
         """
+        kwargs.setdefault('ctx', ctx)
         for offset, size, texp in self.cells:
             if texp:
                 type = texp.type

@@ -1071,7 +1087,8 @@ class DataStream(PTXFragment):
                 vals = []
             outfile.write(struct.pack(type, *vals))
 
-    def print_record(self):
+    @instmethod
+    def print_record(self, ctx):
         for cell in self.cells:
             if cell.texp is None:
                 print '%3d %2d --' % (cell.offset, cell.size)
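
The pack()/pack_into() docstrings above describe evaluating stored expression strings against the caller's keyword arguments, with the launch context injected as ``ctx`` unless overridden, and coercing each result with struct. A rough stand-alone illustration of that behaviour follows; MiniStream and FakeContext are hypothetical toys rather than the real DataStream and LaunchContext, and the real class also handles cell alignment and expression reuse that this sketch ignores:

    import struct
    from StringIO import StringIO

    class MiniStream(object):
        # (offset, struct format, expression) triples, analogous to the
        # (offset, size, texp) cells iterated by pack_into() above.
        cells = [(0, 'i', 'ctx.nthreads + padding'),
                 (4, 'f', 'frob * 2.0')]

        def pack_into(self, ctx, outfile, **kwargs):
            kwargs.setdefault('ctx', ctx)       # the new default namespace entry
            for offset, fmt, expr in self.cells:
                val = eval(expr, {}, kwargs)    # evaluate against the kwargs
                outfile.write(struct.pack(fmt, val))

    class FakeContext(object):
        nthreads = 128 * 56

    out = StringIO()
    MiniStream().pack_into(FakeContext(), out, padding=4, frob=3.0)
    print repr(out.getvalue())   # 8 bytes: a packed int32 followed by a float32
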
@@ -35,10 +35,9 @@ class Frame(pyflam3.Frame):
                 "Distribution of a CP across multiple CTAs not yet done")
         # Interpolate each time step, calculate per-step variables, and pack
         # into the stream
-        cp_streamer = ctx.ptx.instances[CPDataStream]
         stream = StringIO()
         print "Data stream contents:"
-        cp_streamer.print_record()
+        CPDataStream.print_record(ctx)
         tcp = BaseGenome()
         for batch_idx in range(cp.nbatches):
             for time_idx in range(cp.ntemporal_samples):

@@ -51,10 +50,8 @@ class Frame(pyflam3.Frame):
                         cp.width * cp.height) / (
                         cp.nbatches * cp.ntemporal_samples)
 
-                cp_streamer.pack_into(stream,
-                                      frame=self,
-                                      cp=tcp,
-                                      cp_idx=idx)
+                CPDataStream.pack_into(ctx, stream,
+                                       frame=self, cp=tcp, cp_idx=idx)
         stream.seek(0)
         return (stream.read(), cp.nbatches * cp.ntemporal_samples)
 
@@ -108,8 +105,8 @@ class Animation(object):
         # TODO: allow animation-long override of certain parameters (size, etc)
         cp_stream, num_cps = self.frame.pack_stream(self.ctx, time)
         iter_thread = self.ctx.ptx.instances[IterThread]
-        iter_thread.upload_cp_stream(self.ctx, cp_stream, num_cps)
-        iter_thread.call(self.ctx)
+        IterThread.upload_cp_stream(self.ctx, cp_stream, num_cps)
+        IterThread.call(self.ctx)
 
 class Features(object):
     """
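
Taken together, the render-side call sites above now drive everything through the class-level wrappers. A hedged sketch of that flow, assuming a compiled LaunchContext ``ctx`` plus the frame, tcp, idx, and num_cps values produced by the surrounding (elided) interpolation code:

    from StringIO import StringIO

    stream = StringIO()
    CPDataStream.print_record(ctx)        # dump the stream's offset/size layout
    CPDataStream.pack_into(ctx, stream, frame=frame, cp=tcp, cp_idx=idx)
    stream.seek(0)
    cp_stream = stream.read()

    IterThread.upload_cp_stream(ctx, cp_stream, num_cps)
    IterThread.call(ctx)
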