Refactor call() to be more elegant

This commit is contained in:
Steven Robertson 2010-09-10 14:43:20 -04:00
parent fb4e5b75e9
commit 4552589b35
5 changed files with 106 additions and 88 deletions

View File

@ -9,7 +9,7 @@ import pycuda.autoinit
import pycuda.driver as cuda import pycuda.driver as cuda
from cuburnlib.ptx import PTXFragment, PTXTest, ptx_func, instmethod from cuburnlib.ptx import PTXFragment, PTXTest, ptx_func, instmethod
from cuburnlib.cuda import LaunchContext from cuburnlib.cuda import LaunchContext
from cuburnlib.device_code import MWCRNG from cuburnlib.device_code import MWCRNG, MWCRNGTest
class L2WriteCombining(PTXTest): class L2WriteCombining(PTXTest):
""" """
@ -104,26 +104,18 @@ class L2WriteCombining(PTXTest):
op.setp.ge.u32(p_done, x, 2) op.setp.ge.u32(p_done, x, 2)
op.bra.uni(l2_restart, ifnotp=p_done) op.bra.uni(l2_restart, ifnotp=p_done)
@instmethod def _call(self, ctx, func):
def call(self, ctx): self.scratch = np.zeros(self.block_size*ctx.ctas/4, np.uint64)
scratch = np.zeros(self.block_size*ctx.ctas/4, np.uint64) self.times_bytes = np.zeros((4, ctx.threads), np.uint64, 'F')
times_bytes = np.zeros((4, ctx.threads), np.uint64, 'F') super(L2WriteCombining, self)._call(ctx, func,
func = ctx.mod.get_function(self.entry_name) cuda.InOut(self.scratch), cuda.InOut(self.times_bytes))
dtime = func(cuda.InOut(times_bytes), cuda.InOut(scratch),
block=ctx.block, grid=ctx.grid, time_kernel=True)
#printover(times_bytes[0], 6, 32) def call_teardown(self, ctx):
#printover(times_bytes[1], 6)
#printover(times_bytes[2], 6, 32)
#printover(times_bytes[3], 6)
#printover(scratch[i:i+16], 8)
print "\nTotal time was %g seconds" % dtime
pm = lambda a: (np.mean(a), np.std(a) / np.sqrt(len(a))) pm = lambda a: (np.mean(a), np.std(a) / np.sqrt(len(a)))
print "Clks for coa was %g ± %g" % pm(times_bytes[0]) print "Clks for coa was %g ± %g" % pm(self.times_bytes[0])
print "Bytes for coa was %g ± %g" % pm(times_bytes[1]) print "Bytes for coa was %g ± %g" % pm(self.times_bytes[1])
print "Clks for uncoa was %g ± %g" % pm(times_bytes[2]) print "Clks for uncoa was %g ± %g" % pm(self.times_bytes[2])
print "Bytes for uncoa was %g ± %g" % pm(times_bytes[3]) print "Bytes for uncoa was %g ± %g" % pm(self.times_bytes[3])
print '' print ''
def printover(a, r, s=1): def printover(a, r, s=1):
@ -134,9 +126,10 @@ def printover(a, r, s=1):
def main(): def main():
# TODO: block/grid auto-optimization # TODO: block/grid auto-optimization
ctx = LaunchContext([L2WriteCombining], block=(128,1,1), grid=(7*8,1), ctx = LaunchContext([L2WriteCombining, MWCRNGTest],
tests=True) block=(128,1,1), grid=(7*8,1), tests=True)
ctx.compile(verbose=3) ctx.compile(verbose=3)
ctx.run_tests()
L2WriteCombining.call(ctx) L2WriteCombining.call(ctx)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -10,7 +10,7 @@ import pycuda.gl.autoinit
import numpy as np import numpy as np
from cuburnlib.ptx import PTXModule from cuburnlib.ptx import PTXModule, PTXTest, PTXTestFailure
class LaunchContext(object): class LaunchContext(object):
""" """
@ -72,29 +72,34 @@ class LaunchContext(object):
entry.entry_name, func.num_regs, entry.entry_name, func.num_regs,
func.shared_size_bytes, func.local_size_bytes) func.shared_size_bytes, func.local_size_bytes)
def set_up(self): def call_setup(self, entry_inst):
for inst in self.ptx.deporder(self.ptx.instances.values(), for inst in self.ptx.entry_deps[type(entry_inst)]:
self.ptx.instances): inst.call_setup(self)
inst.device_init(self)
def run(self): def call_teardown(self, entry_inst):
if not self.setup_done: self.set_up() okay = True
for inst in reversed(self.ptx.entry_deps[type(entry_inst)]):
def run_test(self, test_type): if inst is entry_inst and isinstance(entry_inst, PTXTest):
if not self.setup_done: self.set_up() try:
inst = self.ptx.instances[test_type] inst.call_teardown(self)
print "Running test: %s... " % inst.name except PTXTestFailure, e:
try: print "PTX Test %s failed!" % inst.entry_name, e
cuda.Context.synchronize() okay = False
if inst.call(self):
print "Test %s passed." % inst.name
else: else:
print "Test %s FAILED." % inst.name inst.call_teardown(self)
except Exception, e: return okay
print "Test %s FAILED (exception thrown)." % inst.name
raise e
def run_tests(self): def run_tests(self):
map(self.run_test, self.ptx.tests) if not self.ptx.tests:
print "No tests to run."
return True
all_okay = True
for test in self.ptx.tests:
cuda.Context.synchronize()
if test.call(self):
print "Test %s passed." % test.entry_name
else:
print "Test %s FAILED." % test.entry_name
all_okay = False
return all_okay

View File

@ -148,18 +148,19 @@ class IterThread(PTXEntryPoint):
CPDataStream.print_record(ctx, cp_stream, 5) CPDataStream.print_record(ctx, cp_stream, 5)
self.cps_uploaded = True self.cps_uploaded = True
@instmethod def call_setup(self, ctx):
def call(self, ctx):
if not self.cps_uploaded: if not self.cps_uploaded:
raise Error("Cannot call IterThread before uploading CPs") raise Error("Cannot call IterThread before uploading CPs")
num_cps_st_dp, num_cps_st_l = ctx.mod.get_global('g_num_cps_started') num_cps_st_dp, num_cps_st_l = ctx.mod.get_global('g_num_cps_started')
cuda.memset_d32(num_cps_st_dp, 0, 1) cuda.memset_d32(num_cps_st_dp, 0, 1)
func = ctx.mod.get_function('iter_thread') def _call(self, ctx, func):
# Get texture reference from the Palette
# TODO: more elegant method than reaching into ctx.ptx?
tr = ctx.ptx.instances[PaletteLookup].texref tr = ctx.ptx.instances[PaletteLookup].texref
dtime = func(block=ctx.block, grid=ctx.grid, time_kernel=True, super(IterThread, self)._call(ctx, func, texrefs=[tr])
texrefs=[tr])
def call_teardown(self, ctx):
shape = (ctx.grid[0], ctx.block[0]/32, 32) shape = (ctx.grid[0], ctx.block[0]/32, 32)
num_rounds_dp, num_rounds_l = ctx.mod.get_global('g_num_rounds') num_rounds_dp, num_rounds_l = ctx.mod.get_global('g_num_rounds')
num_writes_dp, num_writes_l = ctx.mod.get_global('g_num_writes') num_writes_dp, num_writes_l = ctx.mod.get_global('g_num_writes')
@ -325,7 +326,7 @@ class PaletteLookup(PTXFragment):
self.texref.set_address_mode(1, cuda.address_mode.CLAMP) self.texref.set_address_mode(1, cuda.address_mode.CLAMP)
self.texref.set_array(dev_array) self.texref.set_array(dev_array)
def device_init(self, ctx): def call_setup(self, ctx):
assert self.texref, "Must upload palette texture before launch!" assert self.texref, "Must upload palette texture before launch!"
class HistScatter(PTXFragment): class HistScatter(PTXFragment):
@ -368,7 +369,7 @@ class HistScatter(PTXFragment):
op.red.add.f32(addr(hist_bin_addr,12), a) op.red.add.f32(addr(hist_bin_addr,12), a)
def device_init(self, ctx): def call_setup(self, ctx):
hist_bins_dp, hist_bins_l = ctx.mod.get_global('g_hist_bins') hist_bins_dp, hist_bins_l = ctx.mod.get_global('g_hist_bins')
cuda.memset_d32(hist_bins_dp, 0, hist_bins_l/4) cuda.memset_d32(hist_bins_dp, 0, hist_bins_l/4)
@ -383,14 +384,10 @@ class MWCRNG(PTXFragment):
shortname = "mwc" shortname = "mwc"
def __init__(self): def __init__(self):
self.rand = np.random
self.threads_ready = 0 self.threads_ready = 0
if not os.path.isfile('primes.bin'): if not os.path.isfile('primes.bin'):
raise EnvironmentError('primes.bin not found') raise EnvironmentError('primes.bin not found')
def set_seed(self, seed):
self.rand = np.random.mtrand.RandomState(seed)
@ptx_func @ptx_func
def module_setup(self): def module_setup(self):
mem.global_.u32('mwc_rng_mults', ctx.threads) mem.global_.u32('mwc_rng_mults', ctx.threads)
@ -450,11 +447,12 @@ class MWCRNG(PTXFragment):
op.cvt.rn.f32.s32(dst_reg, mwc_st) op.cvt.rn.f32.s32(dst_reg, mwc_st)
op.mul.f32(dst_reg, dst_reg, '0f30000000') # 1./(1<<31) op.mul.f32(dst_reg, dst_reg, '0f30000000') # 1./(1<<31)
def device_init(self, ctx): @instmethod
if self.threads_ready >= ctx.threads: def seed(self, ctx, rand=np.random):
# Already set up enough random states, don't push again """
return Seed the random number generators with values taken from a
``np.random`` instance.
"""
# Load raw big-endian u32 multipliers from primes.bin. # Load raw big-endian u32 multipliers from primes.bin.
with open('primes.bin') as primefp: with open('primes.bin') as primefp:
dt = np.dtype(np.uint32).newbyteorder('B') dt = np.dtype(np.uint32).newbyteorder('B')
@ -463,18 +461,22 @@ class MWCRNG(PTXFragment):
# Randomness in choosing multipliers is good, but larger multipliers # Randomness in choosing multipliers is good, but larger multipliers
# have longer periods, which is also good. This is a compromise. # have longer periods, which is also good. This is a compromise.
mults = np.array(mults[:ctx.threads*4]) mults = np.array(mults[:ctx.threads*4])
self.rand.shuffle(mults) rand.shuffle(mults)
# Copy multipliers and seeds to the device # Copy multipliers and seeds to the device
multdp, multl = ctx.mod.get_global('mwc_rng_mults') multdp, multl = ctx.mod.get_global('mwc_rng_mults')
cuda.memcpy_htod_async(multdp, mults.tostring()[:multl]) cuda.memcpy_htod_async(multdp, mults.tostring()[:multl])
# Intentionally excludes both 0 and (2^32-1), as they can lead to # Intentionally excludes both 0 and (2^32-1), as they can lead to
# degenerate sequences of period 0 # degenerate sequences of period 0
states = np.array(self.rand.randint(1, 0xffffffff, size=2*ctx.threads), states = np.array(rand.randint(1, 0xffffffff, size=2*ctx.threads),
dtype=np.uint32) dtype=np.uint32)
statedp, statel = ctx.mod.get_global('mwc_rng_state') statedp, statel = ctx.mod.get_global('mwc_rng_state')
cuda.memcpy_htod_async(statedp, states.tostring()) cuda.memcpy_htod_async(statedp, states.tostring())
self.threads_ready = ctx.threads self.threads_ready = ctx.threads
def call_setup(self, ctx):
if self.threads_ready < ctx.threads:
self.seed(ctx)
def tests(self): def tests(self):
return [MWCRNGTest] return [MWCRNGTest]
@ -515,7 +517,7 @@ class MWCRNGTest(PTXTest):
op.mad.lo.u32(adr, offset, 8, adr) op.mad.lo.u32(adr, offset, 8, adr)
op.st.global_.u64(addr(adr), sum) op.st.global_.u64(addr(adr), sum)
def call(self, ctx): def call_setup(self, ctx):
# Get current multipliers and seeds from the device # Get current multipliers and seeds from the device
multdp, multl = ctx.mod.get_global('mwc_rng_mults') multdp, multl = ctx.mod.get_global('mwc_rng_mults')
mults = cuda.from_device(multdp, ctx.threads, np.uint32) mults = cuda.from_device(multdp, ctx.threads, np.uint32)
@ -533,15 +535,13 @@ class MWCRNGTest(PTXTest):
ctime = time.time() - ctime ctime = time.time() - ctime
print "Done on host, took %g seconds" % ctime print "Done on host, took %g seconds" % ctime
func = ctx.mod.get_function('MWC_RNG_test') def call_teardown(self, ctx):
dtime = func(block=ctx.block, grid=ctx.grid, time_kernel=True)
print "Done on device, took %g seconds (%gx)" % (dtime, ctime/dtime)
dfullstates = cuda.from_device(statedp, ctx.threads, np.uint64) dfullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
if not (dfullstates == fullstates).all(): if not (dfullstates == fullstates).all():
print "State discrepancy" print "State discrepancy"
print dfullstates print dfullstates
print fullstates print fullstates
return False raise PTXTestFailure("MWC RNG state discrepancy")
sumdp, suml = ctx.mod.get_global('mwc_rng_test_sums') sumdp, suml = ctx.mod.get_global('mwc_rng_test_sums')
dsums = cuda.from_device(sumdp, ctx.threads, np.uint64) dsums = cuda.from_device(sumdp, ctx.threads, np.uint64)
@ -549,11 +549,7 @@ class MWCRNGTest(PTXTest):
print "Sum discrepancy" print "Sum discrepancy"
print dsums print dsums
print sums print sums
return False raise PTXTestFailure("MWC RNG sum discrepancy")
return True
class CameraCoordTransform(PTXFragment):
pass
class CPDataStream(DataStream): class CPDataStream(DataStream):
"""DataStream which stores the control points.""" """DataStream which stores the control points."""

View File

@ -578,13 +578,23 @@ class PTXFragment(object):
""" """
return [] return []
def device_init(self, ctx): def call_setup(self, ctx):
""" """
Do stuff on the host to prepare the device for execution. 'ctx' is a Do stuff on the host to prepare the device for execution. 'ctx' is a
LaunchContext or similar. This will get called (in dependency order, of LaunchContext or similar. This will get called (in dependency order, of
course) *either* before any entry point invocation, or before *each* course) before each function invocation.
invocation, I'm not sure which yet. (For now it's "each".)
""" """
# I haven't found a good way to get outside context in for this method.
# As a result, this is usually just a check to see if some other
# necessary method has been called before trying to launch.
pass
def call_teardown(self, ctx):
"""
As with ``call_setup``, but after a call and in reverse order.
"""
# Exceptions raised here will propagate from the invocation in Python,
# so this is a good place to do error checking.
pass pass
def instmethod(func): def instmethod(func):
@ -599,8 +609,6 @@ def instmethod(func):
return classmethod(wrap) return classmethod(wrap)
class PTXEntryPoint(PTXFragment): class PTXEntryPoint(PTXFragment):
# Human-readable entry point name
name = ""
# Device code entry name # Device code entry name
entry_name = "" entry_name = ""
# List of (type, name) pairs for entry params, e.g. [('u32', 'thing')] # List of (type, name) pairs for entry params, e.g. [('u32', 'thing')]
@ -615,28 +623,44 @@ class PTXEntryPoint(PTXFragment):
""" """
raise NotImplementedError raise NotImplementedError
def _call(self, ctx, func, *args, **kwargs):
"""
Override this if you need to change how a function is called.
"""
# TODO: global debugging / verbosity
print "Invoking PTX function '%s' on device" % self.entry_name
kwargs.setdefault('block', ctx.block)
kwargs.setdefault('grid', ctx.grid)
dtime = func(time_kernel=True, *args, **kwargs)
print "'%s' completed in %gs" % (self.entry_name, dtime)
@instmethod @instmethod
def call(self, ctx): def call(self, ctx, *args, **kwargs):
""" """
Calls the entry point on the device. Haven't worked out the details Calls the entry point on the device, performing any setup and teardown
of this one yet. needed.
""" """
pass ctx.call_setup(self)
func = ctx.mod.get_function(self.entry_name)
self._call(ctx, func, *args, **kwargs)
return ctx.call_teardown(self)
class PTXTestFailure(Exception): pass
class PTXTest(PTXEntryPoint): class PTXTest(PTXEntryPoint):
"""PTXTests are semantically equivalent to PTXEntryPoints, but they """PTXTests are semantically equivalent to PTXEntryPoints, but they differ
differ slightly in use. In particular: slightly in the way they are invoked:
* The "name" property should describe the test being performed, * The active context will be synchronized before each call,
* ctx.stream will be synchronized before 'call' is run, and should be * call_teardown() should raise ``PTXTestFailure`` if a test failed.
synchronized afterwards (i.e. sync it yourself or don't use it), This exception will be caught and cleanup will be completed
* call() should return True to indicate that a test passed, or (unless another exception is raised).
False (or raise an exception) if it failed.
""" """
pass pass
class _PTXStdLib(PTXFragment): class _PTXStdLib(PTXFragment):
shortname = "std" shortname = "std"
def __init__(self, block): def __init__(self, block):
# Only module that gets the privilege of seeing 'block' directly. # Only module that gets the privilege of seeing 'block' directly.
self.block = block self.block = block
@ -728,6 +752,7 @@ class PTXModule(object):
insts, tests, all_deps, entry_deps = ( insts, tests, all_deps, entry_deps = (
self.deptrace(block, entries, build_tests)) self.deptrace(block, entries, build_tests))
self.instances = insts self.instances = insts
self.entry_deps = entry_deps
self.tests = tests self.tests = tests
inject = dict(inject) inject = dict(inject)

View File

@ -130,7 +130,6 @@ class Animation(object):
# TODO: allow animation-long override of certain parameters (size, etc) # TODO: allow animation-long override of certain parameters (size, etc)
frame = Frame(self._frame, time) frame = Frame(self._frame, time)
frame.upload_data(self.ctx, self.filters, time) frame.upload_data(self.ctx, self.filters, time)
self.ctx.set_up()
IterThread.call(self.ctx) IterThread.call(self.ctx)
return HistScatter.get_bins(self.ctx, self.features) return HistScatter.get_bins(self.ctx, self.features)