mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 11:40:04 -05:00
LaunchContext.get_per_thread
This commit is contained in:
parent
3265982fec
commit
6ed8907fcb
@ -255,29 +255,20 @@ class IterThread(PTXEntryPoint):
|
||||
super(IterThread, self)._call(ctx, func, texrefs=[tr])
|
||||
|
||||
def call_teardown(self, ctx):
|
||||
w = ctx.warps_per_cta
|
||||
shape = (ctx.grid[0], w, 32)
|
||||
|
||||
def print_thing(s, a):
|
||||
print '%s:' % s
|
||||
for i, r in enumerate(a):
|
||||
for j in range(0,len(r),w):
|
||||
for j in range(0,len(r),ctx.warps_per_cta):
|
||||
print '%2d' % i,
|
||||
for k in range(j,j+w,8):
|
||||
for k in range(j,j+ctx.warps_per_cta,8):
|
||||
print '\t' + ' '.join(
|
||||
['%8g'%np.mean(r[l]) for l in range(k,k+8)])
|
||||
|
||||
num_rounds_dp, num_rounds_l = ctx.mod.get_global('g_num_rounds')
|
||||
num_writes_dp, num_writes_l = ctx.mod.get_global('g_num_writes')
|
||||
whatever_dp, whatever_l = ctx.mod.get_global('g_whatever')
|
||||
rounds = cuda.from_device(num_rounds_dp, shape, np.int32)
|
||||
writes = cuda.from_device(num_writes_dp, shape, np.int32)
|
||||
whatever = cuda.from_device(whatever_dp, shape, np.int32)
|
||||
rounds = ctx.get_per_thread('g_num_rounds', np.int32, shaped=True)
|
||||
writes = ctx.get_per_thread('g_num_writes', np.int32, shaped=True)
|
||||
print_thing("Rounds", rounds)
|
||||
print_thing("Writes", writes)
|
||||
#print_thing("Whatever", whatever)
|
||||
|
||||
print np.sum(rounds)
|
||||
print "Total number of rounds:", np.sum(rounds)
|
||||
|
||||
dp, l = ctx.mod.get_global('g_num_cps_started')
|
||||
cps_started = cuda.from_device(dp, 1, np.uint32)
|
||||
@ -641,13 +632,13 @@ class MWCRNG(PTXFragment):
|
||||
rand.shuffle(mults)
|
||||
# Copy multipliers and seeds to the device
|
||||
multdp, multl = ctx.mod.get_global('mwc_rng_mults')
|
||||
cuda.memcpy_htod_async(multdp, mults.tostring()[:multl])
|
||||
cuda.memcpy_htod(multdp, mults.tostring()[:multl])
|
||||
# Intentionally excludes both 0 and (2^32-1), as they can lead to
|
||||
# degenerate sequences of period 0
|
||||
states = np.array(rand.randint(1, 0xffffffff, size=2*ctx.nthreads),
|
||||
dtype=np.uint32)
|
||||
statedp, statel = ctx.mod.get_global('mwc_rng_state')
|
||||
cuda.memcpy_htod_async(statedp, states.tostring())
|
||||
cuda.memcpy_htod(statedp, states.tostring())
|
||||
self.threads_ready = ctx.nthreads
|
||||
|
||||
def call_setup(self, ctx):
|
||||
@ -696,10 +687,8 @@ class MWCRNGTest(PTXTest):
|
||||
|
||||
def call_setup(self, ctx):
|
||||
# Get current multipliers and seeds from the device
|
||||
multdp, multl = ctx.mod.get_global('mwc_rng_mults')
|
||||
self.mults = cuda.from_device(multdp, ctx.nthreads, np.uint32)
|
||||
statedp, statel = ctx.mod.get_global('mwc_rng_state')
|
||||
self.fullstates = cuda.from_device(statedp, ctx.nthreads, np.uint64)
|
||||
self.mults = ctx.get_per_thread('mwc_rng_mults', np.uint32)
|
||||
self.fullstates = ctx.get_per_thread('mwc_rng_states', np.uint64)
|
||||
self.sums = np.zeros(ctx.nthreads, np.uint64)
|
||||
|
||||
print "Running %d states forward %d rounds" % \
|
||||
@ -714,18 +703,15 @@ class MWCRNGTest(PTXTest):
|
||||
print "Done on host, took %g seconds" % ctime
|
||||
|
||||
def call_teardown(self, ctx):
|
||||
multdp, multl = ctx.mod.get_global('mwc_rng_mults')
|
||||
statedp, statel = ctx.mod.get_global('mwc_rng_state')
|
||||
|
||||
dfullstates = cuda.from_device(statedp, ctx.nthreads, np.uint64)
|
||||
dfullstates = ctx.get_per_thread('mwc_rng_states', np.uint64)
|
||||
if not (dfullstates == self.fullstates).all():
|
||||
print "State discrepancy"
|
||||
print dfullstates
|
||||
print self.fullstates
|
||||
raise PTXTestFailure("MWC RNG state discrepancy")
|
||||
|
||||
sumdp, suml = ctx.mod.get_global('mwc_rng_test_sums')
|
||||
dsums = cuda.from_device(sumdp, ctx.nthreads, np.uint64)
|
||||
|
||||
dsums = ctx.get_per_thread('mwc_rng_test_sums', np.uint64)
|
||||
if not (dsums == self.sums).all():
|
||||
print "Sum discrepancy"
|
||||
print dsums
|
||||
@ -794,9 +780,8 @@ class MWCRNGFloatsTest(PTXTest):
|
||||
]
|
||||
|
||||
for fkind, rkind, exp, lim in tests:
|
||||
dp, l = ctx.mod.get_global(
|
||||
'mwc_rng_float_%s_test_%s' % (fkind, rkind))
|
||||
vals = cuda.from_device(dp, ctx.nthreads, np.float32)
|
||||
name = 'mwc_rng_float_%s_test_%s' % (fkind, rkind)
|
||||
vals = ctx.get_per_thread(name, np.float32)
|
||||
avg = np.mean(vals)
|
||||
if np.abs(avg - exp) > tol:
|
||||
raise PTXTestFailure("%s %s %g too far from %g" %
|
||||
|
@ -696,21 +696,28 @@ class _PTXStdLib(PTXFragment):
|
||||
|
||||
@ptx_func
|
||||
def store_per_thread(self, *args):
|
||||
"""Store b32 at `base+gtid*4`. Super-common debug pattern."""
|
||||
"""For each pair of arguments ``addr, val``, write ``val`` to the
|
||||
address given by ``addr+sizeof(val)*gtid``. If ``val`` is not a
|
||||
register, size will be taken from ``addr``; if ``addr`` is not a Mem
|
||||
instance, size defaults to 4."""
|
||||
with block("Per-thread storing values"):
|
||||
reg.u32('spt_base spt_offset')
|
||||
self.get_gtid(spt_offset)
|
||||
op.mul.lo.u32(spt_offset, spt_offset, 4)
|
||||
for i in range(0, len(args), 2):
|
||||
base, val = args[i], args[i+1]
|
||||
width = 4
|
||||
if isinstance(base, Mem):
|
||||
width = int(base.type[-1][-2:])/8
|
||||
if isinstance(val, Reg):
|
||||
width = int(val.type[-2:])/8
|
||||
op.mov.u32(spt_base, base)
|
||||
op.add.u32(spt_base, spt_base, spt_offset)
|
||||
op.mad.lo.u32(spt_base, spt_offset, width, spt_base)
|
||||
if isinstance(val, float):
|
||||
# Turn a constant float into the big-endian PTX binary f32
|
||||
# representation, 0fXXXXXXXX (where XX is hex byte)
|
||||
val = '0f%x%x%x%x' % reversed(map(ord,
|
||||
struct.pack('f', val)))
|
||||
op.st.b32(addr(spt_base), val)
|
||||
op._call(['st', 'b%d' % (width*4)], addr(spt_base), val)
|
||||
|
||||
@ptx_func
|
||||
def set_is_first_thread(self, p_dst):
|
||||
|
@ -87,7 +87,7 @@ class Frame(object):
|
||||
center = self.center_cp
|
||||
ncps = center.nbatches * center.ntemporal_samples
|
||||
|
||||
if ncps < ctx.ctas:
|
||||
if ncps < ctx.nctas:
|
||||
raise NotImplementedError(
|
||||
"Distribution of a CP across multiple CTAs not yet done")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user