Experiments with larger CTAs for IterThread

This commit is contained in:
Steven Robertson 2010-09-12 02:01:03 -04:00
parent e2b1c161cf
commit c13f6a06cf
3 changed files with 36 additions and 16 deletions

View File

@@ -90,7 +90,7 @@ class IterThread(PTXEntryPoint):
reg.pred('p_last_cp') reg.pred('p_last_cp')
op.ldu.u32(num_cps, addr(g_num_cps)) op.ldu.u32(num_cps, addr(g_num_cps))
op.setp.ge.u32(p_last_cp, cp_idx, num_cps) op.setp.ge.u32(p_last_cp, cp_idx, num_cps)
op.bra.uni('all_cps_done', ifp=p_last_cp) op.bra('all_cps_done', ifp=p_last_cp)
with block('Load CP address'): with block('Load CP address'):
op.mov.u32(cpA, g_cp_array) op.mov.u32(cpA, g_cp_array)
@@ -149,7 +149,7 @@ class IterThread(PTXEntryPoint):
for xf in features.xforms: for xf in features.xforms:
label('XFORM_%d' % xf.id) label('XFORM_%d' % xf.id)
variations.apply_xform(xo, yo, coloro, xi, yi, colori, xf.id) variations.apply_xform(xo, yo, coloro, xi, yi, colori, xf.id)
op.bra.uni("xform_done") op.bra("xform_done")
label("xform_done") label("xform_done")
with block("Test if we're still in FUSE"): with block("Test if we're still in FUSE"):
@@ -161,7 +161,7 @@ class IterThread(PTXEntryPoint):
reg.pred('p_point_is_valid') reg.pred('p_point_is_valid')
with block("Write the result"): with block("Write the result"):
hist.scatter(xo, yo, coloro, 0, p_point_is_valid) hist.scatter(xo, yo, coloro, 0, p_point_is_valid, 'ldst')
with block(): with block():
reg.u32('num_writes') reg.u32('num_writes')
op.ld.local.u32(num_writes, addr(l_num_writes)) op.ld.local.u32(num_writes, addr(l_num_writes))
@@ -212,16 +212,15 @@ class IterThread(PTXEntryPoint):
comment('Shuffle points between threads') comment('Shuffle points between threads')
shuf.shuffle(xi, yi, colori, consec_bad) shuf.shuffle(xi, yi, colori, consec_bad)
with block("If first warp, pick new thread offset"): with block("If in first warp, pick new offset"):
reg.u32('warpid') reg.u32('tid')
reg.pred('first_warp') reg.pred('first_warp')
op.mov.u32(warpid, '%tid.x') op.mov.u32(tid, '%tid.x')
op.shr.b32(warpid, warpid, 5) assert ctx.warps_per_cta <= 32, \
op.setp.eq.u32(first_warp, warpid, 0) "Special-case for CTAs with >1024 threads not implemented"
#std.asrt("Looks like we're not the first warp", notp=first_warp, op.setp.lo.u32(first_warp, tid, 32)
#ret=True) op.bra(iter_loop_choose_xform, ifp=first_warp)
op.bra.uni(iter_loop_choose_xform, ifp=first_warp) op.bra(iter_loop_start)
op.bra.uni(iter_loop_start)
label('all_cps_done') label('all_cps_done')
# TODO this is for testing, move it to a debug statement # TODO this is for testing, move it to a debug statement
@@ -258,14 +257,15 @@ class IterThread(PTXEntryPoint):
super(IterThread, self)._call(ctx, func, texrefs=[tr]) super(IterThread, self)._call(ctx, func, texrefs=[tr])
def call_teardown(self, ctx): def call_teardown(self, ctx):
shape = (ctx.grid[0], ctx.block[0]/32, 32) w = ctx.warps_per_cta
shape = (ctx.grid[0], w, 32)
def print_thing(s, a): def print_thing(s, a):
print '%s:' % s print '%s:' % s
for i, r in enumerate(a): for i, r in enumerate(a):
for j in range(0,len(r),8): for j in range(0,len(r),w):
print '%2d\t%s' % (i, print '%2d\t%s' % (i,
'\t'.join(['%g '%np.mean(r[k]) for k in range(j,j+8)])) '\t'.join(['%g '%np.mean(r[k]) for k in range(j,j+w)]))
num_rounds_dp, num_rounds_l = ctx.mod.get_global('g_num_rounds') num_rounds_dp, num_rounds_l = ctx.mod.get_global('g_num_rounds')
num_writes_dp, num_writes_l = ctx.mod.get_global('g_num_writes') num_writes_dp, num_writes_l = ctx.mod.get_global('g_num_writes')
@@ -484,6 +484,11 @@ class HistScatter(PTXFragment):
op.mov.u32(hist_bin_addr, g_hist_bins) op.mov.u32(hist_bin_addr, g_hist_bins)
op.mad.lo.u32(hist_bin_addr, hist_index, 16, hist_bin_addr) op.mad.lo.u32(hist_bin_addr, hist_index, 16, hist_bin_addr)
if type == 'fake_notex':
op.st.local.u32(addr(l_scatter_fake_adr), hist_bin_addr)
op.st.local.f32(addr(l_scatter_fake_alpha), color)
return
reg.f32('r g b a norm_time') reg.f32('r g b a norm_time')
cp.get(cpA, norm_time, 'cp.norm_time') cp.get(cpA, norm_time, 'cp.norm_time')
palette.look_up(r, g, b, a, color, norm_time) palette.look_up(r, g, b, a, color, norm_time)

View File

@@ -154,7 +154,7 @@ class Animation(object):
the active device. the active device.
""" """
# TODO: user-configurable test control # TODO: user-configurable test control
self.ctx = LaunchContext([IterThread], block=(256,1,1), grid=(28,1), self.ctx = LaunchContext([IterThread], block=(512,1,1), grid=(28,1),
tests=True) tests=True)
# TODO: user-configurable verbosity control # TODO: user-configurable verbosity control
self.ctx.compile(verbose=3, anim=self, features=self.features) self.ctx.compile(verbose=3, anim=self, features=self.features)

View File

@@ -280,6 +280,21 @@ def shuf_better(a):
print ' With better shuffle: %g' % monte(make(), shuf_better, 1000, 32) print ' With better shuffle: %g' % monte(make(), shuf_better, 1000, 32)
print 'For 32*16:'
t = 512
print ' With no shuffle: %g' % monte(make(), shuf_none, 1000, 32)
print ' With full shuffle: %g' % monte(make(), shuf_all, 1000, 32)
print ' With simple shuffle: %g' % monte(make(), shuf_simple, 1000, 32)
print ' With better shuffle: %g' % monte(make(), shuf_better, 1000, 32)
print 'For 32*32:'
t = 1024
print ' With no shuffle: %g' % monte(make(), shuf_none, 1000, 32)
print ' With full shuffle: %g' % monte(make(), shuf_all, 1000, 32)
print ' With simple shuffle: %g' % monte(make(), shuf_simple, 1000, 32)
print ' With better shuffle: %g' % monte(make(), shuf_better, 1000, 32)
print """ print """
Okay I actually intended this to be a blog post but I started writing before Okay I actually intended this to be a blog post but I started writing before
having done any of the math. Actually the simple shuffle looks like it's having done any of the math. Actually the simple shuffle looks like it's