mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 11:40:04 -05:00
Refactor call() to be more elegant
This commit is contained in:
parent
fb4e5b75e9
commit
4552589b35
35
bench.py
35
bench.py
@ -9,7 +9,7 @@ import pycuda.autoinit
|
||||
import pycuda.driver as cuda
|
||||
from cuburnlib.ptx import PTXFragment, PTXTest, ptx_func, instmethod
|
||||
from cuburnlib.cuda import LaunchContext
|
||||
from cuburnlib.device_code import MWCRNG
|
||||
from cuburnlib.device_code import MWCRNG, MWCRNGTest
|
||||
|
||||
class L2WriteCombining(PTXTest):
|
||||
"""
|
||||
@ -104,26 +104,18 @@ class L2WriteCombining(PTXTest):
|
||||
op.setp.ge.u32(p_done, x, 2)
|
||||
op.bra.uni(l2_restart, ifnotp=p_done)
|
||||
|
||||
@instmethod
|
||||
def call(self, ctx):
|
||||
scratch = np.zeros(self.block_size*ctx.ctas/4, np.uint64)
|
||||
times_bytes = np.zeros((4, ctx.threads), np.uint64, 'F')
|
||||
func = ctx.mod.get_function(self.entry_name)
|
||||
dtime = func(cuda.InOut(times_bytes), cuda.InOut(scratch),
|
||||
block=ctx.block, grid=ctx.grid, time_kernel=True)
|
||||
def _call(self, ctx, func):
|
||||
self.scratch = np.zeros(self.block_size*ctx.ctas/4, np.uint64)
|
||||
self.times_bytes = np.zeros((4, ctx.threads), np.uint64, 'F')
|
||||
super(L2WriteCombining, self)._call(ctx, func,
|
||||
cuda.InOut(self.scratch), cuda.InOut(self.times_bytes))
|
||||
|
||||
#printover(times_bytes[0], 6, 32)
|
||||
#printover(times_bytes[1], 6)
|
||||
#printover(times_bytes[2], 6, 32)
|
||||
#printover(times_bytes[3], 6)
|
||||
#printover(scratch[i:i+16], 8)
|
||||
|
||||
print "\nTotal time was %g seconds" % dtime
|
||||
def call_teardown(self, ctx):
|
||||
pm = lambda a: (np.mean(a), np.std(a) / np.sqrt(len(a)))
|
||||
print "Clks for coa was %g ± %g" % pm(times_bytes[0])
|
||||
print "Bytes for coa was %g ± %g" % pm(times_bytes[1])
|
||||
print "Clks for uncoa was %g ± %g" % pm(times_bytes[2])
|
||||
print "Bytes for uncoa was %g ± %g" % pm(times_bytes[3])
|
||||
print "Clks for coa was %g ± %g" % pm(self.times_bytes[0])
|
||||
print "Bytes for coa was %g ± %g" % pm(self.times_bytes[1])
|
||||
print "Clks for uncoa was %g ± %g" % pm(self.times_bytes[2])
|
||||
print "Bytes for uncoa was %g ± %g" % pm(self.times_bytes[3])
|
||||
print ''
|
||||
|
||||
def printover(a, r, s=1):
|
||||
@ -134,9 +126,10 @@ def printover(a, r, s=1):
|
||||
|
||||
def main():
|
||||
# TODO: block/grid auto-optimization
|
||||
ctx = LaunchContext([L2WriteCombining], block=(128,1,1), grid=(7*8,1),
|
||||
tests=True)
|
||||
ctx = LaunchContext([L2WriteCombining, MWCRNGTest],
|
||||
block=(128,1,1), grid=(7*8,1), tests=True)
|
||||
ctx.compile(verbose=3)
|
||||
ctx.run_tests()
|
||||
L2WriteCombining.call(ctx)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -10,7 +10,7 @@ import pycuda.gl.autoinit
|
||||
|
||||
import numpy as np
|
||||
|
||||
from cuburnlib.ptx import PTXModule
|
||||
from cuburnlib.ptx import PTXModule, PTXTest, PTXTestFailure
|
||||
|
||||
class LaunchContext(object):
|
||||
"""
|
||||
@ -72,29 +72,34 @@ class LaunchContext(object):
|
||||
entry.entry_name, func.num_regs,
|
||||
func.shared_size_bytes, func.local_size_bytes)
|
||||
|
||||
def set_up(self):
|
||||
for inst in self.ptx.deporder(self.ptx.instances.values(),
|
||||
self.ptx.instances):
|
||||
inst.device_init(self)
|
||||
def call_setup(self, entry_inst):
|
||||
for inst in self.ptx.entry_deps[type(entry_inst)]:
|
||||
inst.call_setup(self)
|
||||
|
||||
def run(self):
|
||||
if not self.setup_done: self.set_up()
|
||||
|
||||
def run_test(self, test_type):
|
||||
if not self.setup_done: self.set_up()
|
||||
inst = self.ptx.instances[test_type]
|
||||
print "Running test: %s... " % inst.name
|
||||
try:
|
||||
cuda.Context.synchronize()
|
||||
if inst.call(self):
|
||||
print "Test %s passed." % inst.name
|
||||
def call_teardown(self, entry_inst):
|
||||
okay = True
|
||||
for inst in reversed(self.ptx.entry_deps[type(entry_inst)]):
|
||||
if inst is entry_inst and isinstance(entry_inst, PTXTest):
|
||||
try:
|
||||
inst.call_teardown(self)
|
||||
except PTXTestFailure, e:
|
||||
print "PTX Test %s failed!" % inst.entry_name, e
|
||||
okay = False
|
||||
else:
|
||||
print "Test %s FAILED." % inst.name
|
||||
except Exception, e:
|
||||
print "Test %s FAILED (exception thrown)." % inst.name
|
||||
raise e
|
||||
inst.call_teardown(self)
|
||||
return okay
|
||||
|
||||
def run_tests(self):
|
||||
map(self.run_test, self.ptx.tests)
|
||||
|
||||
if not self.ptx.tests:
|
||||
print "No tests to run."
|
||||
return True
|
||||
all_okay = True
|
||||
for test in self.ptx.tests:
|
||||
cuda.Context.synchronize()
|
||||
if test.call(self):
|
||||
print "Test %s passed." % test.entry_name
|
||||
else:
|
||||
print "Test %s FAILED." % test.entry_name
|
||||
all_okay = False
|
||||
return all_okay
|
||||
|
||||
|
@ -148,18 +148,19 @@ class IterThread(PTXEntryPoint):
|
||||
CPDataStream.print_record(ctx, cp_stream, 5)
|
||||
self.cps_uploaded = True
|
||||
|
||||
@instmethod
|
||||
def call(self, ctx):
|
||||
def call_setup(self, ctx):
|
||||
if not self.cps_uploaded:
|
||||
raise Error("Cannot call IterThread before uploading CPs")
|
||||
num_cps_st_dp, num_cps_st_l = ctx.mod.get_global('g_num_cps_started')
|
||||
cuda.memset_d32(num_cps_st_dp, 0, 1)
|
||||
|
||||
func = ctx.mod.get_function('iter_thread')
|
||||
def _call(self, ctx, func):
|
||||
# Get texture reference from the Palette
|
||||
# TODO: more elegant method than reaching into ctx.ptx?
|
||||
tr = ctx.ptx.instances[PaletteLookup].texref
|
||||
dtime = func(block=ctx.block, grid=ctx.grid, time_kernel=True,
|
||||
texrefs=[tr])
|
||||
super(IterThread, self)._call(ctx, func, texrefs=[tr])
|
||||
|
||||
def call_teardown(self, ctx):
|
||||
shape = (ctx.grid[0], ctx.block[0]/32, 32)
|
||||
num_rounds_dp, num_rounds_l = ctx.mod.get_global('g_num_rounds')
|
||||
num_writes_dp, num_writes_l = ctx.mod.get_global('g_num_writes')
|
||||
@ -325,7 +326,7 @@ class PaletteLookup(PTXFragment):
|
||||
self.texref.set_address_mode(1, cuda.address_mode.CLAMP)
|
||||
self.texref.set_array(dev_array)
|
||||
|
||||
def device_init(self, ctx):
|
||||
def call_setup(self, ctx):
|
||||
assert self.texref, "Must upload palette texture before launch!"
|
||||
|
||||
class HistScatter(PTXFragment):
|
||||
@ -368,7 +369,7 @@ class HistScatter(PTXFragment):
|
||||
op.red.add.f32(addr(hist_bin_addr,12), a)
|
||||
|
||||
|
||||
def device_init(self, ctx):
|
||||
def call_setup(self, ctx):
|
||||
hist_bins_dp, hist_bins_l = ctx.mod.get_global('g_hist_bins')
|
||||
cuda.memset_d32(hist_bins_dp, 0, hist_bins_l/4)
|
||||
|
||||
@ -383,14 +384,10 @@ class MWCRNG(PTXFragment):
|
||||
shortname = "mwc"
|
||||
|
||||
def __init__(self):
|
||||
self.rand = np.random
|
||||
self.threads_ready = 0
|
||||
if not os.path.isfile('primes.bin'):
|
||||
raise EnvironmentError('primes.bin not found')
|
||||
|
||||
def set_seed(self, seed):
|
||||
self.rand = np.random.mtrand.RandomState(seed)
|
||||
|
||||
@ptx_func
|
||||
def module_setup(self):
|
||||
mem.global_.u32('mwc_rng_mults', ctx.threads)
|
||||
@ -450,11 +447,12 @@ class MWCRNG(PTXFragment):
|
||||
op.cvt.rn.f32.s32(dst_reg, mwc_st)
|
||||
op.mul.f32(dst_reg, dst_reg, '0f30000000') # 1./(1<<31)
|
||||
|
||||
def device_init(self, ctx):
|
||||
if self.threads_ready >= ctx.threads:
|
||||
# Already set up enough random states, don't push again
|
||||
return
|
||||
|
||||
@instmethod
|
||||
def seed(self, ctx, rand=np.random):
|
||||
"""
|
||||
Seed the random number generators with values taken from a
|
||||
``np.random`` instance.
|
||||
"""
|
||||
# Load raw big-endian u32 multipliers from primes.bin.
|
||||
with open('primes.bin') as primefp:
|
||||
dt = np.dtype(np.uint32).newbyteorder('B')
|
||||
@ -463,18 +461,22 @@ class MWCRNG(PTXFragment):
|
||||
# Randomness in choosing multipliers is good, but larger multipliers
|
||||
# have longer periods, which is also good. This is a compromise.
|
||||
mults = np.array(mults[:ctx.threads*4])
|
||||
self.rand.shuffle(mults)
|
||||
rand.shuffle(mults)
|
||||
# Copy multipliers and seeds to the device
|
||||
multdp, multl = ctx.mod.get_global('mwc_rng_mults')
|
||||
cuda.memcpy_htod_async(multdp, mults.tostring()[:multl])
|
||||
# Intentionally excludes both 0 and (2^32-1), as they can lead to
|
||||
# degenerate sequences of period 0
|
||||
states = np.array(self.rand.randint(1, 0xffffffff, size=2*ctx.threads),
|
||||
states = np.array(rand.randint(1, 0xffffffff, size=2*ctx.threads),
|
||||
dtype=np.uint32)
|
||||
statedp, statel = ctx.mod.get_global('mwc_rng_state')
|
||||
cuda.memcpy_htod_async(statedp, states.tostring())
|
||||
self.threads_ready = ctx.threads
|
||||
|
||||
def call_setup(self, ctx):
|
||||
if self.threads_ready < ctx.threads:
|
||||
self.seed(ctx)
|
||||
|
||||
def tests(self):
|
||||
return [MWCRNGTest]
|
||||
|
||||
@ -515,7 +517,7 @@ class MWCRNGTest(PTXTest):
|
||||
op.mad.lo.u32(adr, offset, 8, adr)
|
||||
op.st.global_.u64(addr(adr), sum)
|
||||
|
||||
def call(self, ctx):
|
||||
def call_setup(self, ctx):
|
||||
# Get current multipliers and seeds from the device
|
||||
multdp, multl = ctx.mod.get_global('mwc_rng_mults')
|
||||
mults = cuda.from_device(multdp, ctx.threads, np.uint32)
|
||||
@ -533,15 +535,13 @@ class MWCRNGTest(PTXTest):
|
||||
ctime = time.time() - ctime
|
||||
print "Done on host, took %g seconds" % ctime
|
||||
|
||||
func = ctx.mod.get_function('MWC_RNG_test')
|
||||
dtime = func(block=ctx.block, grid=ctx.grid, time_kernel=True)
|
||||
print "Done on device, took %g seconds (%gx)" % (dtime, ctime/dtime)
|
||||
def call_teardown(self, ctx):
|
||||
dfullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
|
||||
if not (dfullstates == fullstates).all():
|
||||
print "State discrepancy"
|
||||
print dfullstates
|
||||
print fullstates
|
||||
return False
|
||||
raise PTXTestFailure("MWC RNG state discrepancy")
|
||||
|
||||
sumdp, suml = ctx.mod.get_global('mwc_rng_test_sums')
|
||||
dsums = cuda.from_device(sumdp, ctx.threads, np.uint64)
|
||||
@ -549,11 +549,7 @@ class MWCRNGTest(PTXTest):
|
||||
print "Sum discrepancy"
|
||||
print dsums
|
||||
print sums
|
||||
return False
|
||||
return True
|
||||
|
||||
class CameraCoordTransform(PTXFragment):
|
||||
pass
|
||||
raise PTXTestFailure("MWC RNG sum discrepancy")
|
||||
|
||||
class CPDataStream(DataStream):
|
||||
"""DataStream which stores the control points."""
|
||||
|
@ -578,13 +578,23 @@ class PTXFragment(object):
|
||||
"""
|
||||
return []
|
||||
|
||||
def device_init(self, ctx):
|
||||
def call_setup(self, ctx):
|
||||
"""
|
||||
Do stuff on the host to prepare the device for execution. 'ctx' is a
|
||||
LaunchContext or similar. This will get called (in dependency order, of
|
||||
course) *either* before any entry point invocation, or before *each*
|
||||
invocation, I'm not sure which yet. (For now it's "each".)
|
||||
course) before each function invocation.
|
||||
"""
|
||||
# I haven't found a good way to get outside context in for this method.
|
||||
# As a result, this is usually just a check to see if some other
|
||||
# necessary method has been called before trying to launch.
|
||||
pass
|
||||
|
||||
def call_teardown(self, ctx):
|
||||
"""
|
||||
As with ``call_setup``, but after a call and in reverse order.
|
||||
"""
|
||||
# Exceptions raised here will propagate from the invocation in Python,
|
||||
# so this is a good place to do error checking.
|
||||
pass
|
||||
|
||||
def instmethod(func):
|
||||
@ -599,8 +609,6 @@ def instmethod(func):
|
||||
return classmethod(wrap)
|
||||
|
||||
class PTXEntryPoint(PTXFragment):
|
||||
# Human-readable entry point name
|
||||
name = ""
|
||||
# Device code entry name
|
||||
entry_name = ""
|
||||
# List of (type, name) pairs for entry params, e.g. [('u32', 'thing')]
|
||||
@ -615,28 +623,44 @@ class PTXEntryPoint(PTXFragment):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _call(self, ctx, func, *args, **kwargs):
|
||||
"""
|
||||
Override this if you need to change how a function is called.
|
||||
"""
|
||||
# TODO: global debugging / verbosity
|
||||
print "Invoking PTX function '%s' on device" % self.entry_name
|
||||
kwargs.setdefault('block', ctx.block)
|
||||
kwargs.setdefault('grid', ctx.grid)
|
||||
dtime = func(time_kernel=True, *args, **kwargs)
|
||||
print "'%s' completed in %gs" % (self.entry_name, dtime)
|
||||
|
||||
@instmethod
|
||||
def call(self, ctx):
|
||||
def call(self, ctx, *args, **kwargs):
|
||||
"""
|
||||
Calls the entry point on the device. Haven't worked out the details
|
||||
of this one yet.
|
||||
Calls the entry point on the device, performing any setup and teardown
|
||||
needed.
|
||||
"""
|
||||
pass
|
||||
ctx.call_setup(self)
|
||||
func = ctx.mod.get_function(self.entry_name)
|
||||
self._call(ctx, func, *args, **kwargs)
|
||||
return ctx.call_teardown(self)
|
||||
|
||||
class PTXTestFailure(Exception): pass
|
||||
|
||||
class PTXTest(PTXEntryPoint):
|
||||
"""PTXTests are semantically equivalent to PTXEntryPoints, but they
|
||||
differ slightly in use. In particular:
|
||||
"""PTXTests are semantically equivalent to PTXEntryPoints, but they differ
|
||||
slightly in the way they are invoked:
|
||||
|
||||
* The "name" property should describe the test being performed,
|
||||
* ctx.stream will be synchronized before 'call' is run, and should be
|
||||
synchronized afterwards (i.e. sync it yourself or don't use it),
|
||||
* call() should return True to indicate that a test passed, or
|
||||
False (or raise an exception) if it failed.
|
||||
* The active context will be synchronized before each call,
|
||||
* call_teardown() should raise ``PTXTestFailure`` if a test failed.
|
||||
This exception will be caught and cleanup will be completed
|
||||
(unless another exception is raised).
|
||||
"""
|
||||
pass
|
||||
|
||||
class _PTXStdLib(PTXFragment):
|
||||
shortname = "std"
|
||||
|
||||
def __init__(self, block):
|
||||
# Only module that gets the privilege of seeing 'block' directly.
|
||||
self.block = block
|
||||
@ -728,6 +752,7 @@ class PTXModule(object):
|
||||
insts, tests, all_deps, entry_deps = (
|
||||
self.deptrace(block, entries, build_tests))
|
||||
self.instances = insts
|
||||
self.entry_deps = entry_deps
|
||||
self.tests = tests
|
||||
|
||||
inject = dict(inject)
|
||||
|
@ -130,7 +130,6 @@ class Animation(object):
|
||||
# TODO: allow animation-long override of certain parameters (size, etc)
|
||||
frame = Frame(self._frame, time)
|
||||
frame.upload_data(self.ctx, self.filters, time)
|
||||
self.ctx.set_up()
|
||||
IterThread.call(self.ctx)
|
||||
return HistScatter.get_bins(self.ctx, self.features)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user