Use shared memory for iter_count and have each CP processed by only one CTA.

Slower, but the code is a bit simpler conceptually, and the difference will be
more than accounted for by better scheduling towards the end of the process.
This commit is contained in:
Steven Robertson 2010-09-07 14:54:50 -04:00
parent aa065dc25d
commit 094890c324
4 changed files with 79 additions and 60 deletions

View File

@ -46,6 +46,10 @@ class LaunchContext(object):
def ctas(self):
return self.grid[0] * self.grid[1]
@property
def threads_per_cta(self):
return self.block[0] * self.block[1] * self.block[2]
def compile(self, verbose=False, **kwargs):
kwargs['ctx'] = self
self.ptx = PTXModule(self.entry_types, kwargs, self.build_tests)

View File

@ -26,82 +26,66 @@ class IterThread(PTXTest):
mem.global_.u32('g_cp_array',
cp.stream_size*features.max_ntemporal_samples)
mem.global_.u32('g_num_cps')
mem.global_.u32('g_num_cps_started')
# TODO move into debug statement
mem.global_.u32('g_num_rounds', ctx.threads)
mem.global_.u32('g_num_writes', ctx.threads)
@ptx_func
def entry(self):
reg.f32('x_coord y_coord color_coord alpha_coord')
# For now, we indulge in the luxury of shared memory.
# Index number of current CP, shared across CTA
mem.shared.u32('s_cp_idx')
# Number of samples that have been generated so far in this CTA
# If this number is negative, we're still fusing points, so this
# behaves slightly differently (see ``fuse_loop_start``)
mem.shared.u32('s_num_samples')
op.st.shared.u32(addr(s_num_samples), -(features.num_fuse_samples+1))
# TODO: temporary, for testing
reg.u32('num_rounds num_writes')
op.mov.u32(num_rounds, 0)
op.mov.u32(num_writes, 0)
# TODO: MWC float output types
mwc.next_f32_01(x_coord)
mwc.next_f32_01(y_coord)
reg.f32('x_coord y_coord color_coord')
mwc.next_f32_11(x_coord)
mwc.next_f32_11(y_coord)
mwc.next_f32_01(color_coord)
mwc.next_f32_01(alpha_coord)
# Registers are hard to come by. To avoid having to track both the count
# of samples processed and the number of samples to generate,
# 'num_samples' counts *down* from the CP's desired sample count.
# When it hits 0, we move on to the next CP.
#
# FUSE complicates things. To track it, we store the *negative* number
# of points we have left to fuse before we start to store the results.
# When it hits -1, we're done fusing, and can move on to the real
# thread. The execution flow between 'cp_loop', 'fuse_start', and
# 'iter_loop_start' is therefore tricky, and bears close inspection.
#
# In summary:
# num_samples == 0: Load next CP, set num_samples from that
# num_samples > 0: Iterate, store the result, decrement num_samples
# num_samples < -1: Iterate, don't store, increment num_samples
# num_samples == -1: Done fusing, enter normal flow
# TODO: move this to qlocal storage
reg.s32('num_samples')
op.mov.s32(num_samples, -(features.num_fuse_samples+1))
# TODO: Move cp_num to qlocal storage (or spill it, rarely accessed)
reg.u32('cp_idx cpA')
op.mov.u32(cp_idx, 0)
label('cp_loop_start')
comment("Ensure all init is done")
op.bar.sync(0)
with block('Check to see if this is the last CP'):
label('cp_loop_start')
reg.u32('cp_idx cpA')
with block("Claim a CP"):
std.set_is_first_thread(reg.pred('p_is_first'))
op.atom.inc.u32(cp_idx, addr(g_num_cps_started), 1, ifp=p_is_first)
op.st.shared.u32(addr(s_cp_idx), cp_idx, ifp=p_is_first)
comment("Load the CP index in all threads")
op.bar.sync(0)
op.ld.shared.u32(cp_idx, addr(s_cp_idx))
with block("Check to see if this CP is valid (if not, we're done"):
reg.u32('num_cps')
reg.pred('p_last_cp')
op.ldu.u32(num_cps, addr(g_num_cps))
op.setp.ge.u32(p_last_cp, cp_idx, num_cps)
op.setp.ge.u32(p_last_cp, cp_idx, 1)
op.bra.uni('all_cps_done', ifp=p_last_cp)
with block('Load CP address'):
op.mov.u32(cpA, g_cp_array)
op.mad.lo.u32(cpA, cp_idx, cp.stream_size, cpA)
with block('Increment CP index, load num_samples (unless in fuse)'):
reg.pred('p_not_in_fuse')
op.setp.ge.s32(p_not_in_fuse, num_samples, 0)
op.add.u32(cp_idx, cp_idx, 1, ifp=p_not_in_fuse)
cp.get(cpA, num_samples, 'samples_per_thread',
ifp=p_not_in_fuse)
label('fuse_loop_start')
with block('FUSE-specific stuff'):
reg.pred('p_fuse')
comment('If num_samples == -1, set it to 0 and jump back up')
comment('This will start the normal CP loading machinery')
op.setp.eq.s32(p_fuse, num_samples, -1)
op.mov.s32(num_samples, 0, ifp=p_fuse)
op.bra.uni(cp_loop_start, ifp=p_fuse)
comment('If num_samples < -1, still fusing, so increment')
op.setp.lt.s32(p_fuse, num_samples, -1)
op.add.s32(num_samples, num_samples, 1, ifp=p_fuse)
# When fusing, num_samples holds the (negative) number of iterations
# left across the CP, rather than the number of samples in total.
with block("If still fusing, increment count unconditionally"):
std.set_is_first_thread(reg.pred('p_is_first'))
op.red.shared.add.s32(addr(s_num_samples), 1, ifp=p_is_first)
op.bar.sync(0)
label('iter_loop_start')
@ -110,17 +94,33 @@ class IterThread(PTXTest):
op.add.u32(num_rounds, num_rounds, 1)
with block("Test if we're still in FUSE"):
reg.s32('num_samples')
reg.pred('p_in_fuse')
op.ld.shared.u32(num_samples, addr(s_num_samples))
op.setp.lt.s32(p_in_fuse, num_samples, 0)
op.bra.uni(fuse_loop_start, ifp=p_in_fuse)
with block("Ordinarily, we'd write the result here"):
op.add.u32(num_writes, num_writes, 1)
# For testing, declare and clear p_badval
reg.pred('p_goodval')
op.setp.eq.u32(p_goodval, 1, 1)
with block("Increment number of samples by number of good values"):
reg.b32('good_samples')
op.vote.ballot.b32(good_samples, p_goodval)
op.popc.b32(good_samples, good_samples)
std.set_is_first_thread(reg.pred('p_is_first'))
op.red.shared.add.s32(addr(s_num_samples), good_samples,
ifp=p_is_first)
with block("Check to see if we're done with this CP"):
reg.pred('p_cp_done')
op.add.s32(num_samples, num_samples, -1)
op.setp.eq.s32(p_cp_done, num_samples, 0)
reg.s32('num_samples num_samples_needed')
op.ld.shared.s32(num_samples, addr(s_num_samples))
cp.get(cpA, num_samples_needed, 'cp.nsamples')
op.setp.ge.s32(p_cp_done, num_samples, num_samples_needed)
op.bra.uni(cp_loop_start, ifp=p_cp_done)
op.bra.uni(iter_loop_start)
@ -134,13 +134,17 @@ class IterThread(PTXTest):
cp_array_dp, cp_array_l = ctx.mod.get_global('g_cp_array')
assert len(cp_stream) <= cp_array_l, "Stream too big!"
cuda.memcpy_htod_async(cp_array_dp, cp_stream)
num_cps_dp, num_cps_l = ctx.mod.get_global('g_num_cps')
cuda.memcpy_htod_async(num_cps_dp, struct.pack('i', num_cps))
cuda.memset_d32(num_cps_dp, num_cps, 1)
self.cps_uploaded = True
def call(self, ctx):
if not self.cps_uploaded:
raise Error("Cannot call IterThread before uploading CPs")
num_cps_st_dp, num_cps_st_l = ctx.mod.get_global('g_num_cps_started')
cuda.memset_d32(num_cps_st_dp, 0, 1)
func = ctx.mod.get_function('iter_thread')
dtime = func(block=ctx.block, grid=ctx.grid, time_kernel=True)
@ -219,7 +223,7 @@ class MWCRNG(PTXFragment):
with block('Load random float [-1,1) into ' + dst_reg.name):
self._next()
op.cvt.rn.f32.s32(dst_reg, mwc_st)
op.mul.lo.f32(dst_reg, dst_reg, '0f00000030') # 1./(1<<31)
op.mul.f32(dst_reg, dst_reg, '0f00000030') # 1./(1<<31)
def device_init(self, ctx):
if self.threads_ready >= ctx.threads:

View File

@ -650,6 +650,16 @@ class _PTXStdLib(PTXFragment):
op.mad.lo.u32(spt_base, spt_offset, 4, spt_base)
op.st.b32(addr(spt_base), val)
@ptx_func
def set_is_first_thread(self, p_dst):
with block("Set %s if this is thread 0 in the CTA" % p_dst.name):
reg.u32('tid')
op.mov.u32(tid, '%tid.x')
op.setp.eq.u32(p_dst, tid, 0)
def not_(self, pred):
return ['!', pred]
def to_inject(self):
# Set up the initial namespace
return dict(

View File

@ -30,6 +30,9 @@ class Frame(pyflam3.Frame):
rw = cp.spatial_oversample * cp.width + 2 * self.filt.gutter
rh = cp.spatial_oversample * cp.height + 2 * self.filt.gutter
if cp.nbatches * cp.ntemporal_samples < ctx.ctas:
raise NotImplementedError(
"Distribution of a CP across multiple CTAs not yet done")
# Interpolate each time step, calculate per-step variables, and pack
# into the stream
cp_streamer = ctx.ptx.instances[CPDataStream]
@ -44,16 +47,14 @@ class Frame(pyflam3.Frame):
self.interpolate(time, tcp)
tcp.camera = Camera(self, tcp, self.filt)
# TODO: figure out which object to pack this into
nsamples = ((tcp.camera.sample_density * cp.width * cp.height) /
(cp.nbatches * cp.ntemporal_samples))
samples_per_thread = nsamples / ctx.threads + 15
tcp.nsamples = (tcp.camera.sample_density *
cp.width * cp.height) / (
cp.nbatches * cp.ntemporal_samples)
cp_streamer.pack_into(stream,
frame=self,
cp=tcp,
cp_idx=idx,
samples_per_thread=samples_per_thread)
cp_idx=idx)
stream.seek(0)
return (stream.read(), cp.nbatches * cp.ntemporal_samples)