Refactor call() to be more elegant

Steven Robertson 2010-09-10 14:43:20 -04:00
parent fb4e5b75e9
commit 4552589b35
5 changed files with 106 additions and 88 deletions

View File

@@ -9,7 +9,7 @@ import pycuda.autoinit
import pycuda.driver as cuda
from cuburnlib.ptx import PTXFragment, PTXTest, ptx_func, instmethod
from cuburnlib.cuda import LaunchContext
from cuburnlib.device_code import MWCRNG
from cuburnlib.device_code import MWCRNG, MWCRNGTest
class L2WriteCombining(PTXTest):
"""
@@ -104,26 +104,18 @@ class L2WriteCombining(PTXTest):
op.setp.ge.u32(p_done, x, 2)
op.bra.uni(l2_restart, ifnotp=p_done)
@instmethod
def call(self, ctx):
scratch = np.zeros(self.block_size*ctx.ctas/4, np.uint64)
times_bytes = np.zeros((4, ctx.threads), np.uint64, 'F')
func = ctx.mod.get_function(self.entry_name)
dtime = func(cuda.InOut(times_bytes), cuda.InOut(scratch),
block=ctx.block, grid=ctx.grid, time_kernel=True)
def _call(self, ctx, func):
self.scratch = np.zeros(self.block_size*ctx.ctas/4, np.uint64)
self.times_bytes = np.zeros((4, ctx.threads), np.uint64, 'F')
super(L2WriteCombining, self)._call(ctx, func,
cuda.InOut(self.scratch), cuda.InOut(self.times_bytes))
#printover(times_bytes[0], 6, 32)
#printover(times_bytes[1], 6)
#printover(times_bytes[2], 6, 32)
#printover(times_bytes[3], 6)
#printover(scratch[i:i+16], 8)
print "\nTotal time was %g seconds" % dtime
def call_teardown(self, ctx):
pm = lambda a: (np.mean(a), np.std(a) / np.sqrt(len(a)))
print "Clks for coa was %g ± %g" % pm(times_bytes[0])
print "Bytes for coa was %g ± %g" % pm(times_bytes[1])
print "Clks for uncoa was %g ± %g" % pm(times_bytes[2])
print "Bytes for uncoa was %g ± %g" % pm(times_bytes[3])
print "Clks for coa was %g ± %g" % pm(self.times_bytes[0])
print "Bytes for coa was %g ± %g" % pm(self.times_bytes[1])
print "Clks for uncoa was %g ± %g" % pm(self.times_bytes[2])
print "Bytes for uncoa was %g ± %g" % pm(self.times_bytes[3])
print ''
def printover(a, r, s=1):
@@ -134,9 +126,10 @@ def printover(a, r, s=1):
def main():
# TODO: block/grid auto-optimization
ctx = LaunchContext([L2WriteCombining], block=(128,1,1), grid=(7*8,1),
tests=True)
ctx = LaunchContext([L2WriteCombining, MWCRNGTest],
block=(128,1,1), grid=(7*8,1), tests=True)
ctx.compile(verbose=3)
ctx.run_tests()
L2WriteCombining.call(ctx)
if __name__ == "__main__":
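
The refactor in this file follows one pattern worth spelling out: _call stores its working buffers on the instance so that call_teardown, which the launch machinery runs after the kernel finishes, can read them back for reporting. A self-contained sketch of that pattern (not repo code; names and sizes are invented), including the mean / standard-error helper used in call_teardown above:

import numpy as np

def pm(a):
    # mean and standard error of the mean, as in call_teardown above
    return np.mean(a), np.std(a) / np.sqrt(len(a))

class MiniBench(object):
    def _call(self, threads):
        # the real test fills this buffer from the CUDA kernel
        self.times_bytes = np.zeros((4, threads), np.uint64, 'F')

    def call_teardown(self):
        print("Clks for coa was %g +/- %g" % pm(self.times_bytes[0]))

bench = MiniBench()
bench._call(128)
bench.call_teardown()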

View File

@@ -10,7 +10,7 @@ import pycuda.gl.autoinit
import numpy as np
from cuburnlib.ptx import PTXModule
from cuburnlib.ptx import PTXModule, PTXTest, PTXTestFailure
class LaunchContext(object):
"""
@@ -72,29 +72,34 @@ class LaunchContext(object):
entry.entry_name, func.num_regs,
func.shared_size_bytes, func.local_size_bytes)
def set_up(self):
for inst in self.ptx.deporder(self.ptx.instances.values(),
self.ptx.instances):
inst.device_init(self)
def call_setup(self, entry_inst):
for inst in self.ptx.entry_deps[type(entry_inst)]:
inst.call_setup(self)
def run(self):
if not self.setup_done: self.set_up()
def run_test(self, test_type):
if not self.setup_done: self.set_up()
inst = self.ptx.instances[test_type]
print "Running test: %s... " % inst.name
def call_teardown(self, entry_inst):
okay = True
for inst in reversed(self.ptx.entry_deps[type(entry_inst)]):
if inst is entry_inst and isinstance(entry_inst, PTXTest):
try:
cuda.Context.synchronize()
if inst.call(self):
print "Test %s passed." % inst.name
inst.call_teardown(self)
except PTXTestFailure, e:
print "PTX Test %s failed!" % inst.entry_name, e
okay = False
else:
print "Test %s FAILED." % inst.name
except Exception, e:
print "Test %s FAILED (exception thrown)." % inst.name
raise e
inst.call_teardown(self)
return okay
def run_tests(self):
map(self.run_test, self.ptx.tests)
if not self.ptx.tests:
print "No tests to run."
return True
all_okay = True
for test in self.ptx.tests:
cuda.Context.synchronize()
if test.call(self):
print "Test %s passed." % test.entry_name
else:
print "Test %s FAILED." % test.entry_name
all_okay = False
return all_okay
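
A stripped-down model (not repo code) of the ordering above: dependencies get call_setup in dependency order and call_teardown in reverse, and only the entry instance's teardown is wrapped, so a PTXTestFailure marks the test as failed without skipping the remaining teardowns:

class PTXTestFailure(Exception): pass

class Frag(object):
    def __init__(self, name): self.name = name
    def call_setup(self, ctx): print("setup    " + self.name)
    def call_teardown(self, ctx): print("teardown " + self.name)

class FailingEntry(Frag):
    def call_teardown(self, ctx):
        raise PTXTestFailure("result mismatch")

class MiniContext(object):
    def __init__(self, deps): self.deps = deps   # dependency-ordered fragments

    def call_setup(self, entry):
        for inst in self.deps:
            inst.call_setup(self)

    def call_teardown(self, entry):
        okay = True
        for inst in reversed(self.deps):
            if inst is entry:
                try:
                    inst.call_teardown(self)
                except PTXTestFailure as e:
                    print("PTX Test %s failed! %s" % (inst.name, e))
                    okay = False
            else:
                inst.call_teardown(self)
        return okay

ctx = MiniContext([Frag("mwc"), FailingEntry("l2test")])
ctx.call_setup(ctx.deps[-1])
print(ctx.call_teardown(ctx.deps[-1]))   # False: the entry's teardown raised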

View File

@@ -148,18 +148,19 @@ class IterThread(PTXEntryPoint):
CPDataStream.print_record(ctx, cp_stream, 5)
self.cps_uploaded = True
@instmethod
def call(self, ctx):
def call_setup(self, ctx):
if not self.cps_uploaded:
raise Error("Cannot call IterThread before uploading CPs")
num_cps_st_dp, num_cps_st_l = ctx.mod.get_global('g_num_cps_started')
cuda.memset_d32(num_cps_st_dp, 0, 1)
func = ctx.mod.get_function('iter_thread')
def _call(self, ctx, func):
# Get texture reference from the Palette
# TODO: more elegant method than reaching into ctx.ptx?
tr = ctx.ptx.instances[PaletteLookup].texref
dtime = func(block=ctx.block, grid=ctx.grid, time_kernel=True,
texrefs=[tr])
super(IterThread, self)._call(ctx, func, texrefs=[tr])
def call_teardown(self, ctx):
shape = (ctx.grid[0], ctx.block[0]/32, 32)
num_rounds_dp, num_rounds_l = ctx.mod.get_global('g_num_rounds')
num_writes_dp, num_writes_l = ctx.mod.get_global('g_num_writes')
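
The counter reset in call_setup above uses a common PyCUDA idiom: look up the device address of a module-level global and zero it with memset_d32 before each launch. A standalone sketch of that idiom (not repo code; the kernel and variable names are invented, and it assumes a CUDA toolchain and device):

import numpy as np
import pycuda.autoinit
import pycuda.driver as cuda
from pycuda.compiler import SourceModule

mod = SourceModule("""
__device__ unsigned int g_counter;
__global__ void bump() { atomicAdd(&g_counter, 1u); }
""")

ptr, nbytes = mod.get_global('g_counter')
cuda.memset_d32(ptr, 0, nbytes // 4)        # start each call from zero
mod.get_function('bump')(block=(32, 1, 1), grid=(4, 1))
print(cuda.from_device(ptr, 1, np.uint32))  # [128]
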
@@ -325,7 +326,7 @@ class PaletteLookup(PTXFragment):
self.texref.set_address_mode(1, cuda.address_mode.CLAMP)
self.texref.set_array(dev_array)
def device_init(self, ctx):
def call_setup(self, ctx):
assert self.texref, "Must upload palette texture before launch!"
class HistScatter(PTXFragment):
@@ -368,7 +369,7 @@ class HistScatter(PTXFragment):
op.red.add.f32(addr(hist_bin_addr,12), a)
def device_init(self, ctx):
def call_setup(self, ctx):
hist_bins_dp, hist_bins_l = ctx.mod.get_global('g_hist_bins')
cuda.memset_d32(hist_bins_dp, 0, hist_bins_l/4)
@@ -383,14 +384,10 @@ class MWCRNG(PTXFragment):
shortname = "mwc"
def __init__(self):
self.rand = np.random
self.threads_ready = 0
if not os.path.isfile('primes.bin'):
raise EnvironmentError('primes.bin not found')
def set_seed(self, seed):
self.rand = np.random.mtrand.RandomState(seed)
@ptx_func
def module_setup(self):
mem.global_.u32('mwc_rng_mults', ctx.threads)
@@ -450,11 +447,12 @@ class MWCRNG(PTXFragment):
op.cvt.rn.f32.s32(dst_reg, mwc_st)
op.mul.f32(dst_reg, dst_reg, '0f30000000') # 1./(1<<31)
def device_init(self, ctx):
if self.threads_ready >= ctx.threads:
# Already set up enough random states, don't push again
return
@instmethod
def seed(self, ctx, rand=np.random):
"""
Seed the random number generators with values taken from a
``np.random`` instance.
"""
# Load raw big-endian u32 multipliers from primes.bin.
with open('primes.bin') as primefp:
dt = np.dtype(np.uint32).newbyteorder('B')
@@ -463,18 +461,22 @@
# Randomness in choosing multipliers is good, but larger multipliers
# have longer periods, which is also good. This is a compromise.
mults = np.array(mults[:ctx.threads*4])
self.rand.shuffle(mults)
rand.shuffle(mults)
# Copy multipliers and seeds to the device
multdp, multl = ctx.mod.get_global('mwc_rng_mults')
cuda.memcpy_htod_async(multdp, mults.tostring()[:multl])
# Intentionally excludes both 0 and (2^32-1), as they can lead to
# degenerate sequences of period 0
states = np.array(self.rand.randint(1, 0xffffffff, size=2*ctx.threads),
states = np.array(rand.randint(1, 0xffffffff, size=2*ctx.threads),
dtype=np.uint32)
statedp, statel = ctx.mod.get_global('mwc_rng_state')
cuda.memcpy_htod_async(statedp, states.tostring())
self.threads_ready = ctx.threads
def call_setup(self, ctx):
if self.threads_ready < ctx.threads:
self.seed(ctx)
def tests(self):
return [MWCRNGTest]
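
The threads_ready guard makes seeding lazy and idempotent: an explicit, deterministic seed() survives later launches, and call_setup only reseeds when more thread states are needed. A self-contained sketch of that pattern (not repo code; names are simplified):

import numpy as np

class MiniRNG(object):
    def __init__(self):
        self.threads_ready = 0

    def seed(self, threads, rand=np.random):
        # randint(1, 0xffffffff) excludes 0 and 2**32-1, as above
        self.states = np.array(rand.randint(1, 0xffffffff, size=2 * threads),
                               dtype=np.uint32)
        self.threads_ready = threads

    def call_setup(self, threads):
        if self.threads_ready < threads:
            self.seed(threads)

rng = MiniRNG()
rng.seed(64, np.random.mtrand.RandomState(42))  # reproducible stream
rng.call_setup(64)                              # no-op: already seeded
rng.call_setup(256)                             # reseeds for more threads
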
@@ -515,7 +517,7 @@ class MWCRNGTest(PTXTest):
op.mad.lo.u32(adr, offset, 8, adr)
op.st.global_.u64(addr(adr), sum)
def call(self, ctx):
def call_setup(self, ctx):
# Get current multipliers and seeds from the device
multdp, multl = ctx.mod.get_global('mwc_rng_mults')
mults = cuda.from_device(multdp, ctx.threads, np.uint32)
@@ -533,15 +535,13 @@ class MWCRNGTest(PTXTest):
ctime = time.time() - ctime
print "Done on host, took %g seconds" % ctime
func = ctx.mod.get_function('MWC_RNG_test')
dtime = func(block=ctx.block, grid=ctx.grid, time_kernel=True)
print "Done on device, took %g seconds (%gx)" % (dtime, ctime/dtime)
def call_teardown(self, ctx):
dfullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
if not (dfullstates == fullstates).all():
print "State discrepancy"
print dfullstates
print fullstates
return False
raise PTXTestFailure("MWC RNG state discrepancy")
sumdp, suml = ctx.mod.get_global('mwc_rng_test_sums')
dsums = cuda.from_device(sumdp, ctx.threads, np.uint64)
@@ -549,11 +549,7 @@ class MWCRNGTest(PTXTest):
print "Sum discrepancy"
print dsums
print sums
return False
return True
class CameraCoordTransform(PTXFragment):
pass
raise PTXTestFailure("MWC RNG sum discrepancy")
class CPDataStream(DataStream):
"""DataStream which stores the control points."""

View File

@@ -578,13 +578,23 @@ class PTXFragment(object):
"""
return []
def device_init(self, ctx):
def call_setup(self, ctx):
"""
Do stuff on the host to prepare the device for execution. 'ctx' is a
LaunchContext or similar. This will get called (in dependency order, of
course) *either* before any entry point invocation, or before *each*
invocation, I'm not sure which yet. (For now it's "each".)
course) before each function invocation.
"""
# I haven't found a good way to pass outside context in to this method.
# As a result, this is usually just a check to see if some other
# necessary method has been called before trying to launch.
pass
def call_teardown(self, ctx):
"""
As with ``call_setup``, but after a call and in reverse order.
"""
# Exceptions raised here will propagate from the invocation in Python,
# so this is a good place to do error checking.
pass
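
A minimal illustration (not from the repo) of how a fragment might use these two hooks: call_setup guards a precondition, and call_teardown validates results, letting any exception propagate from the invocation:

class MiniFragment(object):
    def __init__(self):
        self.palette_uploaded = False
        self.results_ok = True

    def call_setup(self, ctx):
        # precondition check, in the spirit of PaletteLookup.call_setup above
        assert self.palette_uploaded, "Must upload palette texture before launch!"

    def call_teardown(self, ctx):
        # error checking after the call; exceptions propagate to the caller
        if not self.results_ok:
            raise RuntimeError("device results failed validation")
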
def instmethod(func):
@@ -599,8 +609,6 @@ def instmethod(func):
return classmethod(wrap)
class PTXEntryPoint(PTXFragment):
# Human-readable entry point name
name = ""
# Device code entry name
entry_name = ""
# List of (type, name) pairs for entry params, e.g. [('u32', 'thing')]
@@ -615,28 +623,44 @@ class PTXEntryPoint(PTXFragment):
"""
raise NotImplementedError
def _call(self, ctx, func, *args, **kwargs):
"""
Override this if you need to change how a function is called.
"""
# TODO: global debugging / verbosity
print "Invoking PTX function '%s' on device" % self.entry_name
kwargs.setdefault('block', ctx.block)
kwargs.setdefault('grid', ctx.grid)
dtime = func(time_kernel=True, *args, **kwargs)
print "'%s' completed in %gs" % (self.entry_name, dtime)
@instmethod
def call(self, ctx):
def call(self, ctx, *args, **kwargs):
"""
Calls the entry point on the device. Haven't worked out the details
of this one yet.
Calls the entry point on the device, performing any setup and teardown
needed.
"""
pass
ctx.call_setup(self)
func = ctx.mod.get_function(self.entry_name)
self._call(ctx, func, *args, **kwargs)
return ctx.call_teardown(self)
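
Seen as a template method, call() owns the setup/launch/teardown sequence while subclasses customize only _call, as IterThread does above to inject texture references. A toy, GPU-free model of that structure (not repo code):

class NullCtx(object):
    def call_setup(self, entry): pass
    def call_teardown(self, entry): return True

class MiniEntry(object):
    entry_name = 'mini'

    def _call(self, ctx, func, *args, **kwargs):
        kwargs.setdefault('block', (128, 1, 1))
        kwargs.setdefault('grid', (56, 1))
        print("launching %s with %r %r" % (self.entry_name, args, kwargs))

    def call(self, ctx, *args, **kwargs):
        ctx.call_setup(self)
        func = None   # the real code fetches this from ctx.mod
        self._call(ctx, func, *args, **kwargs)
        return ctx.call_teardown(self)

class TexturedEntry(MiniEntry):
    def _call(self, ctx, func, *args, **kwargs):
        # inject extra launch arguments, then defer to the base launch
        kwargs['texrefs'] = ['palette_texref']
        super(TexturedEntry, self)._call(ctx, func, *args, **kwargs)

print(TexturedEntry().call(NullCtx()))   # prints the launch line, then True
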
class PTXTestFailure(Exception): pass
class PTXTest(PTXEntryPoint):
"""PTXTests are semantically equivalent to PTXEntryPoints, but they
differ slightly in use. In particular:
"""PTXTests are semantically equivalent to PTXEntryPoints, but they differ
slightly in the way they are invoked:
* The "name" property should describe the test being performed,
* ctx.stream will be synchronized before 'call' is run, and should be
synchronized afterwards (i.e. sync it yourself or don't use it),
* call() should return True to indicate that a test passed, or
False (or raise an exception) if it failed.
* The active context will be synchronized before each call,
* call_teardown() should raise ``PTXTestFailure`` if a test failed.
This exception will be caught and cleanup will be completed
(unless another exception is raised).
"""
pass
class _PTXStdLib(PTXFragment):
shortname = "std"
def __init__(self, block):
# Only module that gets the privilege of seeing 'block' directly.
self.block = block
@@ -728,6 +752,7 @@ class PTXModule(object):
insts, tests, all_deps, entry_deps = (
self.deptrace(block, entries, build_tests))
self.instances = insts
self.entry_deps = entry_deps
self.tests = tests
inject = dict(inject)

View File

@@ -130,7 +130,6 @@ class Animation(object):
# TODO: allow animation-long override of certain parameters (size, etc)
frame = Frame(self._frame, time)
frame.upload_data(self.ctx, self.filters, time)
self.ctx.set_up()
IterThread.call(self.ctx)
return HistScatter.get_bins(self.ctx, self.features)