LaunchContext.get_per_thread

Steven Robertson 2010-09-12 13:45:55 -04:00
parent 3265982fec
commit 6ed8907fcb
3 changed files with 26 additions and 34 deletions


@@ -255,29 +255,20 @@ class IterThread(PTXEntryPoint):
         super(IterThread, self)._call(ctx, func, texrefs=[tr])

     def call_teardown(self, ctx):
-        w = ctx.warps_per_cta
-        shape = (ctx.grid[0], w, 32)
         def print_thing(s, a):
             print '%s:' % s
             for i, r in enumerate(a):
-                for j in range(0,len(r),w):
+                for j in range(0,len(r),ctx.warps_per_cta):
                     print '%2d' % i,
-                    for k in range(j,j+w,8):
+                    for k in range(j,j+ctx.warps_per_cta,8):
                         print '\t' + ' '.join(
                                 ['%8g'%np.mean(r[l]) for l in range(k,k+8)])

-        num_rounds_dp, num_rounds_l = ctx.mod.get_global('g_num_rounds')
-        num_writes_dp, num_writes_l = ctx.mod.get_global('g_num_writes')
-        whatever_dp, whatever_l = ctx.mod.get_global('g_whatever')
-        rounds = cuda.from_device(num_rounds_dp, shape, np.int32)
-        writes = cuda.from_device(num_writes_dp, shape, np.int32)
-        whatever = cuda.from_device(whatever_dp, shape, np.int32)
+        rounds = ctx.get_per_thread('g_num_rounds', np.int32, shaped=True)
+        writes = ctx.get_per_thread('g_num_writes', np.int32, shaped=True)

         print_thing("Rounds", rounds)
         print_thing("Writes", writes)
-        #print_thing("Whatever", whatever)
-        print np.sum(rounds)
+        print "Total number of rounds:", np.sum(rounds)

         dp, l = ctx.mod.get_global('g_num_cps_started')
         cps_started = cuda.from_device(dp, 1, np.uint32)
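Note: the LaunchContext.get_per_thread helper these call sites now use is not itself shown in this commit view. Going only by the calls it replaces, a minimal host-side sketch of equivalent behaviour might look like the following; the free-function form, the exact signature, and the shaped layout of (ctas, warps_per_cta, 32) are assumptions inferred from the deleted code, not the committed implementation.

    import numpy as np
    import pycuda.driver as cuda

    def get_per_thread(ctx, name, dtype, shaped=False):
        # Hypothetical stand-in for ctx.get_per_thread: copy a per-thread
        # global array back from the device, optionally shaped by CTA and warp.
        dp, size = ctx.mod.get_global(name)
        if shaped:
            shape = (ctx.grid[0], ctx.warps_per_cta, 32)
        else:
            shape = ctx.nthreads
        return cuda.from_device(dp, shape, dtype)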
@@ -641,13 +632,13 @@ class MWCRNG(PTXFragment):
         rand.shuffle(mults)
         # Copy multipliers and seeds to the device
         multdp, multl = ctx.mod.get_global('mwc_rng_mults')
-        cuda.memcpy_htod_async(multdp, mults.tostring()[:multl])
+        cuda.memcpy_htod(multdp, mults.tostring()[:multl])
         # Intentionally excludes both 0 and (2^32-1), as they can lead to
         # degenerate sequences of period 0
         states = np.array(rand.randint(1, 0xffffffff, size=2*ctx.nthreads),
                           dtype=np.uint32)
         statedp, statel = ctx.mod.get_global('mwc_rng_state')
-        cuda.memcpy_htod_async(statedp, states.tostring())
+        cuda.memcpy_htod(statedp, states.tostring())
         self.threads_ready = ctx.nthreads

     def call_setup(self, ctx):
@@ -696,10 +687,8 @@ class MWCRNGTest(PTXTest):

     def call_setup(self, ctx):
         # Get current multipliers and seeds from the device
-        multdp, multl = ctx.mod.get_global('mwc_rng_mults')
-        self.mults = cuda.from_device(multdp, ctx.nthreads, np.uint32)
-        statedp, statel = ctx.mod.get_global('mwc_rng_state')
-        self.fullstates = cuda.from_device(statedp, ctx.nthreads, np.uint64)
+        self.mults = ctx.get_per_thread('mwc_rng_mults', np.uint32)
+        self.fullstates = ctx.get_per_thread('mwc_rng_states', np.uint64)
         self.sums = np.zeros(ctx.nthreads, np.uint64)

         print "Running %d states forward %d rounds" % \
@@ -714,18 +703,15 @@ class MWCRNGTest(PTXTest):
         print "Done on host, took %g seconds" % ctime

     def call_teardown(self, ctx):
-        multdp, multl = ctx.mod.get_global('mwc_rng_mults')
-        statedp, statel = ctx.mod.get_global('mwc_rng_state')
-        dfullstates = cuda.from_device(statedp, ctx.nthreads, np.uint64)
+        dfullstates = ctx.get_per_thread('mwc_rng_states', np.uint64)
         if not (dfullstates == self.fullstates).all():
             print "State discrepancy"
             print dfullstates
             print self.fullstates
             raise PTXTestFailure("MWC RNG state discrepancy")

-        sumdp, suml = ctx.mod.get_global('mwc_rng_test_sums')
-        dsums = cuda.from_device(sumdp, ctx.nthreads, np.uint64)
+        dsums = ctx.get_per_thread('mwc_rng_test_sums', np.uint64)
         if not (dsums == self.sums).all():
             print "Sum discrepancy"
             print dsums
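The host-side loop announced by "Running %d states forward %d rounds" falls outside these hunks, so the reference computation that call_teardown compares against is not visible here. For orientation only, a NumPy sketch of the standard lag-1 multiply-with-carry step such a loop would perform is given below; packing the carry in the high 32 bits of the uint64 state word and accumulating the low word into the sums are assumptions inferred from the uint64 fullstates and sums used above, not code from this commit.

    import numpy as np

    def mwc_advance(mults, fullstates, sums, nrounds):
        # Advance each multiply-with-carry stream nrounds times on the host.
        # fullstates packs (carry << 32) | state; one step is mult*state + carry,
        # whose low word is the new state and whose high word is the new carry.
        mults = mults.astype(np.uint64)
        lo = np.uint64(0xffffffff)
        for _ in range(nrounds):
            state = fullstates & lo
            carry = fullstates >> np.uint64(32)
            fullstates = mults * state + carry
            sums += fullstates & lo
        return fullstates, sums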
@@ -794,9 +780,8 @@ class MWCRNGFloatsTest(PTXTest):
             ]

         for fkind, rkind, exp, lim in tests:
-            dp, l = ctx.mod.get_global(
-                    'mwc_rng_float_%s_test_%s' % (fkind, rkind))
-            vals = cuda.from_device(dp, ctx.nthreads, np.float32)
+            name = 'mwc_rng_float_%s_test_%s' % (fkind, rkind)
+            vals = ctx.get_per_thread(name, np.float32)
             avg = np.mean(vals)
             if np.abs(avg - exp) > tol:
                 raise PTXTestFailure("%s %s %g too far from %g" %


@@ -696,21 +696,28 @@ class _PTXStdLib(PTXFragment):

     @ptx_func
     def store_per_thread(self, *args):
-        """Store b32 at `base+gtid*4`. Super-common debug pattern."""
+        """For each pair of arguments ``addr, val``, write ``val`` to the
+        address given by ``addr+sizeof(val)*gtid``. If ``val`` is not a
+        register, size will be taken from ``addr``; if ``addr`` is not a Mem
+        instance, size defaults to 4."""
         with block("Per-thread storing values"):
             reg.u32('spt_base spt_offset')
             self.get_gtid(spt_offset)
-            op.mul.lo.u32(spt_offset, spt_offset, 4)
             for i in range(0, len(args), 2):
                 base, val = args[i], args[i+1]
+                width = 4
+                if isinstance(base, Mem):
+                    width = int(base.type[-1][-2:])/8
+                if isinstance(val, Reg):
+                    width = int(val.type[-2:])/8
                 op.mov.u32(spt_base, base)
-                op.add.u32(spt_base, spt_base, spt_offset)
+                op.mad.lo.u32(spt_base, spt_offset, width, spt_base)
                 if isinstance(val, float):
                     # Turn a constant float into the big-endian PTX binary f32
                     # representation, 0fXXXXXXXX (where XX is hex byte)
                     val = '0f%x%x%x%x' % reversed(map(ord,
                                                       struct.pack('f', val)))
-                op.st.b32(addr(spt_base), val)
+                op._call(['st', 'b%d' % (width*4)], addr(spt_base), val)

     @ptx_func
     def set_is_first_thread(self, p_dst):
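The unchanged comment above describes encoding a Python float as PTX's 0fXXXXXXXX immediate, built from the IEEE 754 single-precision bit pattern with the most significant byte first. As a side note, the bit pattern can also be obtained with a single struct round-trip; the helper below is only an illustration of that encoding, not code from this commit.

    import struct

    def f32_to_ptx_literal(val):
        # PTX immediate form of a float: '0f' followed by the 8 hex digits of
        # the IEEE 754 single-precision bit pattern, most significant first.
        bits, = struct.unpack('<I', struct.pack('<f', val))
        return '0f%08x' % bits

    # f32_to_ptx_literal(1.0) -> '0f3f800000'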


@@ -87,7 +87,7 @@ class Frame(object):
         center = self.center_cp
         ncps = center.nbatches * center.ntemporal_samples

-        if ncps < ctx.ctas:
+        if ncps < ctx.nctas:
             raise NotImplementedError(
                 "Distribution of a CP across multiple CTAs not yet done")