Shuffle points between threads of a CTA

This commit is contained in:
Steven Robertson
2010-09-12 00:17:18 -04:00
parent 40a5ceafde
commit f368a99a16
3 changed files with 355 additions and 27 deletions

View File

@ -15,13 +15,13 @@ from cuburn.variations import Variations
class IterThread(PTXEntryPoint):
entry_name = 'iter_thread'
entry_params = []
maxnreg = 16
def __init__(self):
self.cps_uploaded = False
def deps(self):
return [MWCRNG, CPDataStream, HistScatter, Variations, Timeouter]
return [MWCRNG, CPDataStream, HistScatter, Variations, ShufflePoints,
Timeouter]
@ptx_func
def module_setup(self):
@ -48,7 +48,7 @@ class IterThread(PTXEntryPoint):
mem.shared.f32('s_xf_sel', ctx.warps_per_cta)
std.store_per_thread(g_whatever, 1234)
#std.store_per_thread(g_whatever, 1234)
# TODO: temporary, for testing
mem.local.u32('l_num_rounds')
@ -56,13 +56,11 @@ class IterThread(PTXEntryPoint):
op.st.local.u32(addr(l_num_rounds), 0)
op.st.local.u32(addr(l_num_writes), 0)
mem.local.f32('l_consec')
op.st.local.f32(addr(l_consec), 0.)
reg.f32('x_coord y_coord color_coord')
mwc.next_f32_11(x_coord)
mwc.next_f32_11(y_coord)
mwc.next_f32_01(color_coord)
reg.f32('xi xo yi yo colori coloro consec_bad')
mwc.next_f32_11(xi)
mwc.next_f32_11(yi)
mwc.next_f32_01(colori)
op.mov.f32(consec_bad, 0.)
comment("Ensure all init is done")
op.bar.sync(0)
@ -104,6 +102,7 @@ class IterThread(PTXEntryPoint):
# When fusing, num_samples holds the (negative) number of iterations
# left across the CP, rather than the number of samples in total.
with block("If still fusing, increment count unconditionally"):
op.bar.sync(0)
std.set_is_first_thread(reg.pred('p_is_first'))
op.red.shared.add.s32(addr(s_num_samples), 1, ifp=p_is_first)
@ -133,8 +132,6 @@ class IterThread(PTXEntryPoint):
op.add.u32(num_rounds, num_rounds, 1)
op.st.local.u32(addr(l_num_rounds), num_rounds)
with block("Select an xform"):
reg.f32('xf_sel')
reg.u32('warp_offset xf_sel_addr')
@ -154,12 +151,9 @@ class IterThread(PTXEntryPoint):
for xf in features.xforms:
label('XFORM_%d' % xf.id)
variations.apply_xform(x_coord, y_coord, color_coord,
x_coord, y_coord, color_coord, xf.id)
variations.apply_xform(xo, yo, coloro, xi, yi, colori, xf.id)
op.bra.uni("xform_done")
label("xform_done")
with block("Test if we're still in FUSE"):
reg.s32('num_samples')
@ -170,7 +164,7 @@ class IterThread(PTXEntryPoint):
reg.pred('p_point_is_valid')
with block("Write the result"):
hist.scatter(x_coord, y_coord, color_coord, 0, p_point_is_valid)
hist.scatter(xo, yo, coloro, 0, p_point_is_valid)
with block():
reg.u32('num_writes')
op.ld.local.u32(num_writes, addr(l_num_writes))
@ -180,17 +174,23 @@ class IterThread(PTXEntryPoint):
with block("If the result was invalid, handle badvals"):
reg.f32('consec')
reg.pred('need_new_point')
op.ld.local.f32(consec, addr(l_consec))
op.mov.f32(consec, 0., ifp=p_point_is_valid)
comment('If point is good, move new coords and reset consec_bad')
op.mov.f32(xi, xo, ifp=p_point_is_valid)
op.mov.f32(yi, yo, ifp=p_point_is_valid)
op.mov.f32(colori, coloro, ifp=p_point_is_valid)
op.mov.f32(consec_bad, 0., ifp=p_point_is_valid)
comment('Otherwise, add 1 to consec_bad')
op.add.f32(consec, consec, 1., ifnotp=p_point_is_valid)
op.setp.ge.f32(need_new_point, consec, 5.)
op.bra('badval_done', ifnotp=need_new_point)
mwc.next_f32_11(x_coord)
mwc.next_f32_11(y_coord)
mwc.next_f32_01(color_coord)
comment('If consec_bad > 5, pick a new random point')
mwc.next_f32_11(xi)
mwc.next_f32_11(yi)
mwc.next_f32_01(colori)
op.mov.f32(consec, 0.)
label('badval_done')
op.st.local.f32(addr(l_consec), consec)
with block("Increment number of samples by number of good values"):
reg.b32('good_samples laneid')
@ -205,11 +205,16 @@ class IterThread(PTXEntryPoint):
with block("Check to see if we're done with this CP"):
reg.pred('p_cp_done')
reg.s32('num_samples num_samples_needed')
comment('Sync before making decision to prevent divergence')
op.bar.sync(3)
op.ld.shared.s32(num_samples, addr(s_num_samples))
cp.get(cpA, num_samples_needed, 'cp.nsamples')
op.setp.ge.s32(p_cp_done, num_samples, num_samples_needed)
op.bra.uni(cp_loop_start, ifp=p_cp_done)
comment('Shuffle points between threads')
shuf.shuffle(xi, yi, colori, consec_bad)
with block("If first warp, pick new thread offset"):
reg.u32('warpid')
reg.pred('first_warp')
@ -273,7 +278,7 @@ class IterThread(PTXEntryPoint):
whatever = cuda.from_device(whatever_dp, shape, np.int32)
print_thing("Rounds", rounds)
print_thing("Writes", writes)
print_thing("Whatever", whatever)
#print_thing("Whatever", whatever)
print np.sum(rounds)
@ -495,6 +500,41 @@ class HistScatter(PTXFragment):
dtype=np.float32)
class ShufflePoints(PTXFragment):
"""
Shuffle points in shared memory. See helpers/shuf.py for details.

Each thread stores a register value into its own slot of the shared
``s_shuf_data`` buffer and then loads the value back from a different
slot, so that values are exchanged between the threads of a CTA.
"""
shortname = "shuf"
@ptx_func
def module_setup(self):
# Declares the shared f32 buffer 's_shuf_data', sized by
# ctx.threads_per_cta (presumably one 32-bit slot per thread of the CTA;
# the exact meaning of the size argument is defined by the DSL).
# TODO: if needed, merge this shared memory block with others
mem.shared.f32('s_shuf_data', ctx.threads_per_cta)
@ptx_func
def shuffle(self, *args, **kwargs):
"""
Shuffle the data from each register in args across threads. Keyword
argument ``bar`` specifies which barrier to use.

Every register passed in ``args`` is overwritten in place with the
value read back from shared memory. ``bar`` defaults to 8. Only ``bar``
is taken out of ``kwargs``; this method does not look at any other
keyword argument.
"""
bar = kwargs.pop('bar', 8)
with block("Shuffle across threads"):
reg.u32('shuf_read shuf_write')
with block("Calculate read and write offsets"):
reg.u32('shuf_off shuf_laneid')
# Write address: s_shuf_data + %tid.x * 4 (4 bytes per 32-bit slot).
op.mov.u32(shuf_off, '%tid.x')
op.mov.u32(shuf_write, s_shuf_data)
op.mad.lo.u32(shuf_write, shuf_off, 4, shuf_write)
# Read index: (%laneid * 32 + %tid.x) & (threads_per_cta - 1).
# NOTE(review): the 32 is presumably the warp size, so each lane
# reads from a different warp's slot -- confirm against
# helpers/shuf.py.
op.mov.u32(shuf_laneid, '%laneid')
op.mad.lo.u32(shuf_off, shuf_laneid, 32, shuf_off)
# NOTE(review): this mask wraps the index like a modulo only when
# ctx.threads_per_cta is a power of two -- confirm that is guaranteed.
op.and_.b32(shuf_off, shuf_off, ctx.threads_per_cta - 1)
# Read address: s_shuf_data + read index * 4.
op.mov.u32(shuf_read, s_shuf_data)
op.mad.lo.u32(shuf_read, shuf_off, 4, shuf_read)
# One barrier / store / barrier / load sequence per register; the single
# shared buffer is reused for every argument.
for var in args:
# Presumably keeps a thread from overwriting its slot while another
# thread is still loading the previous argument's value from it.
op.bar.sync(bar)
op.st.shared.b32(addr(shuf_write), var)
# Presumably ensures every thread's store has completed before any
# thread loads from another thread's slot.
op.bar.sync(bar)
op.ld.shared.b32(var, addr(shuf_read))
class MWCRNG(PTXFragment):
shortname = "mwc"