Use shared memory for iter_count and have each CP processed by only one CTA.

Slower, but the code is a bit simpler conceptually, and the difference will be
more than accounted for by better scheduling towards the end of the process.
This commit is contained in:
Steven Robertson 2010-09-07 14:54:50 -04:00
parent aa065dc25d
commit 094890c324
4 changed files with 79 additions and 60 deletions

View File

@ -46,6 +46,10 @@ class LaunchContext(object):
def ctas(self): def ctas(self):
return self.grid[0] * self.grid[1] return self.grid[0] * self.grid[1]
@property
def threads_per_cta(self):
return self.block[0] * self.block[1] * self.block[2]
def compile(self, verbose=False, **kwargs): def compile(self, verbose=False, **kwargs):
kwargs['ctx'] = self kwargs['ctx'] = self
self.ptx = PTXModule(self.entry_types, kwargs, self.build_tests) self.ptx = PTXModule(self.entry_types, kwargs, self.build_tests)

View File

@ -26,82 +26,66 @@ class IterThread(PTXTest):
mem.global_.u32('g_cp_array', mem.global_.u32('g_cp_array',
cp.stream_size*features.max_ntemporal_samples) cp.stream_size*features.max_ntemporal_samples)
mem.global_.u32('g_num_cps') mem.global_.u32('g_num_cps')
mem.global_.u32('g_num_cps_started')
# TODO move into debug statement # TODO move into debug statement
mem.global_.u32('g_num_rounds', ctx.threads) mem.global_.u32('g_num_rounds', ctx.threads)
mem.global_.u32('g_num_writes', ctx.threads) mem.global_.u32('g_num_writes', ctx.threads)
@ptx_func @ptx_func
def entry(self): def entry(self):
reg.f32('x_coord y_coord color_coord alpha_coord') # For now, we indulge in the luxury of shared memory.
# Index number of current CP, shared across CTA
mem.shared.u32('s_cp_idx')
# Number of samples that have been generated so far in this CTA
# If this number is negative, we're still fusing points, so this
# behaves slightly differently (see ``fuse_loop_start``)
mem.shared.u32('s_num_samples')
op.st.shared.u32(addr(s_num_samples), -(features.num_fuse_samples+1))
# TODO: temporary, for testing # TODO: temporary, for testing
reg.u32('num_rounds num_writes') reg.u32('num_rounds num_writes')
op.mov.u32(num_rounds, 0) op.mov.u32(num_rounds, 0)
op.mov.u32(num_writes, 0) op.mov.u32(num_writes, 0)
# TODO: MWC float output types reg.f32('x_coord y_coord color_coord')
mwc.next_f32_01(x_coord) mwc.next_f32_11(x_coord)
mwc.next_f32_01(y_coord) mwc.next_f32_11(y_coord)
mwc.next_f32_01(color_coord) mwc.next_f32_01(color_coord)
mwc.next_f32_01(alpha_coord)
# Registers are hard to come by. To avoid having to track both the count comment("Ensure all init is done")
# of samples processed and the number of samples to generate,
# 'num_samples' counts *down* from the CP's desired sample count.
# When it hits 0, we move on to the next CP.
#
# FUSE complicates things. To track it, we store the *negative* number
# of points we have left to fuse before we start to store the results.
# When it hits -1, we're done fusing, and can move on to the real
# thread. The execution flow between 'cp_loop', 'fuse_start', and
# 'iter_loop_start' is therefore tricky, and bears close inspection.
#
# In summary:
# num_samples == 0: Load next CP, set num_samples from that
# num_samples > 0: Iterate, store the result, decrement num_samples
# num_samples < -1: Iterate, don't store, increment num_samples
# num_samples == -1: Done fusing, enter normal flow
# TODO: move this to qlocal storage
reg.s32('num_samples')
op.mov.s32(num_samples, -(features.num_fuse_samples+1))
# TODO: Move cp_num to qlocal storage (or spill it, rarely accessed)
reg.u32('cp_idx cpA')
op.mov.u32(cp_idx, 0)
label('cp_loop_start')
op.bar.sync(0) op.bar.sync(0)
with block('Check to see if this is the last CP'): label('cp_loop_start')
reg.u32('cp_idx cpA')
with block("Claim a CP"):
std.set_is_first_thread(reg.pred('p_is_first'))
op.atom.inc.u32(cp_idx, addr(g_num_cps_started), 1, ifp=p_is_first)
op.st.shared.u32(addr(s_cp_idx), cp_idx, ifp=p_is_first)
comment("Load the CP index in all threads")
op.bar.sync(0)
op.ld.shared.u32(cp_idx, addr(s_cp_idx))
with block("Check to see if this CP is valid (if not, we're done"):
reg.u32('num_cps') reg.u32('num_cps')
reg.pred('p_last_cp') reg.pred('p_last_cp')
op.ldu.u32(num_cps, addr(g_num_cps)) op.ldu.u32(num_cps, addr(g_num_cps))
op.setp.ge.u32(p_last_cp, cp_idx, num_cps) op.setp.ge.u32(p_last_cp, cp_idx, 1)
op.bra.uni('all_cps_done', ifp=p_last_cp) op.bra.uni('all_cps_done', ifp=p_last_cp)
with block('Load CP address'): with block('Load CP address'):
op.mov.u32(cpA, g_cp_array) op.mov.u32(cpA, g_cp_array)
op.mad.lo.u32(cpA, cp_idx, cp.stream_size, cpA) op.mad.lo.u32(cpA, cp_idx, cp.stream_size, cpA)
with block('Increment CP index, load num_samples (unless in fuse)'):
reg.pred('p_not_in_fuse')
op.setp.ge.s32(p_not_in_fuse, num_samples, 0)
op.add.u32(cp_idx, cp_idx, 1, ifp=p_not_in_fuse)
cp.get(cpA, num_samples, 'samples_per_thread',
ifp=p_not_in_fuse)
label('fuse_loop_start') label('fuse_loop_start')
with block('FUSE-specific stuff'): # When fusing, num_samples holds the (negative) number of iterations
reg.pred('p_fuse') # left across the CP, rather than the number of samples in total.
comment('If num_samples == -1, set it to 0 and jump back up') with block("If still fusing, increment count unconditionally"):
comment('This will start the normal CP loading machinery') std.set_is_first_thread(reg.pred('p_is_first'))
op.setp.eq.s32(p_fuse, num_samples, -1) op.red.shared.add.s32(addr(s_num_samples), 1, ifp=p_is_first)
op.mov.s32(num_samples, 0, ifp=p_fuse) op.bar.sync(0)
op.bra.uni(cp_loop_start, ifp=p_fuse)
comment('If num_samples < -1, still fusing, so increment')
op.setp.lt.s32(p_fuse, num_samples, -1)
op.add.s32(num_samples, num_samples, 1, ifp=p_fuse)
label('iter_loop_start') label('iter_loop_start')
@ -110,17 +94,33 @@ class IterThread(PTXTest):
op.add.u32(num_rounds, num_rounds, 1) op.add.u32(num_rounds, num_rounds, 1)
with block("Test if we're still in FUSE"): with block("Test if we're still in FUSE"):
reg.s32('num_samples')
reg.pred('p_in_fuse') reg.pred('p_in_fuse')
op.ld.shared.u32(num_samples, addr(s_num_samples))
op.setp.lt.s32(p_in_fuse, num_samples, 0) op.setp.lt.s32(p_in_fuse, num_samples, 0)
op.bra.uni(fuse_loop_start, ifp=p_in_fuse) op.bra.uni(fuse_loop_start, ifp=p_in_fuse)
with block("Ordinarily, we'd write the result here"): with block("Ordinarily, we'd write the result here"):
op.add.u32(num_writes, num_writes, 1) op.add.u32(num_writes, num_writes, 1)
# For testing, declare and clear p_badval
reg.pred('p_goodval')
op.setp.eq.u32(p_goodval, 1, 1)
with block("Increment number of samples by number of good values"):
reg.b32('good_samples')
op.vote.ballot.b32(good_samples, p_goodval)
op.popc.b32(good_samples, good_samples)
std.set_is_first_thread(reg.pred('p_is_first'))
op.red.shared.add.s32(addr(s_num_samples), good_samples,
ifp=p_is_first)
with block("Check to see if we're done with this CP"): with block("Check to see if we're done with this CP"):
reg.pred('p_cp_done') reg.pred('p_cp_done')
op.add.s32(num_samples, num_samples, -1) reg.s32('num_samples num_samples_needed')
op.setp.eq.s32(p_cp_done, num_samples, 0) op.ld.shared.s32(num_samples, addr(s_num_samples))
cp.get(cpA, num_samples_needed, 'cp.nsamples')
op.setp.ge.s32(p_cp_done, num_samples, num_samples_needed)
op.bra.uni(cp_loop_start, ifp=p_cp_done) op.bra.uni(cp_loop_start, ifp=p_cp_done)
op.bra.uni(iter_loop_start) op.bra.uni(iter_loop_start)
@ -134,13 +134,17 @@ class IterThread(PTXTest):
cp_array_dp, cp_array_l = ctx.mod.get_global('g_cp_array') cp_array_dp, cp_array_l = ctx.mod.get_global('g_cp_array')
assert len(cp_stream) <= cp_array_l, "Stream too big!" assert len(cp_stream) <= cp_array_l, "Stream too big!"
cuda.memcpy_htod_async(cp_array_dp, cp_stream) cuda.memcpy_htod_async(cp_array_dp, cp_stream)
num_cps_dp, num_cps_l = ctx.mod.get_global('g_num_cps') num_cps_dp, num_cps_l = ctx.mod.get_global('g_num_cps')
cuda.memcpy_htod_async(num_cps_dp, struct.pack('i', num_cps)) cuda.memset_d32(num_cps_dp, num_cps, 1)
self.cps_uploaded = True self.cps_uploaded = True
def call(self, ctx): def call(self, ctx):
if not self.cps_uploaded: if not self.cps_uploaded:
raise Error("Cannot call IterThread before uploading CPs") raise Error("Cannot call IterThread before uploading CPs")
num_cps_st_dp, num_cps_st_l = ctx.mod.get_global('g_num_cps_started')
cuda.memset_d32(num_cps_st_dp, 0, 1)
func = ctx.mod.get_function('iter_thread') func = ctx.mod.get_function('iter_thread')
dtime = func(block=ctx.block, grid=ctx.grid, time_kernel=True) dtime = func(block=ctx.block, grid=ctx.grid, time_kernel=True)
@ -219,7 +223,7 @@ class MWCRNG(PTXFragment):
with block('Load random float [-1,1) into ' + dst_reg.name): with block('Load random float [-1,1) into ' + dst_reg.name):
self._next() self._next()
op.cvt.rn.f32.s32(dst_reg, mwc_st) op.cvt.rn.f32.s32(dst_reg, mwc_st)
op.mul.lo.f32(dst_reg, dst_reg, '0f00000030') # 1./(1<<31) op.mul.f32(dst_reg, dst_reg, '0f00000030') # 1./(1<<31)
def device_init(self, ctx): def device_init(self, ctx):
if self.threads_ready >= ctx.threads: if self.threads_ready >= ctx.threads:

View File

@ -650,6 +650,16 @@ class _PTXStdLib(PTXFragment):
op.mad.lo.u32(spt_base, spt_offset, 4, spt_base) op.mad.lo.u32(spt_base, spt_offset, 4, spt_base)
op.st.b32(addr(spt_base), val) op.st.b32(addr(spt_base), val)
@ptx_func
def set_is_first_thread(self, p_dst):
with block("Set %s if this is thread 0 in the CTA" % p_dst.name):
reg.u32('tid')
op.mov.u32(tid, '%tid.x')
op.setp.eq.u32(p_dst, tid, 0)
def not_(self, pred):
return ['!', pred]
def to_inject(self): def to_inject(self):
# Set up the initial namespace # Set up the initial namespace
return dict( return dict(

View File

@ -30,6 +30,9 @@ class Frame(pyflam3.Frame):
rw = cp.spatial_oversample * cp.width + 2 * self.filt.gutter rw = cp.spatial_oversample * cp.width + 2 * self.filt.gutter
rh = cp.spatial_oversample * cp.height + 2 * self.filt.gutter rh = cp.spatial_oversample * cp.height + 2 * self.filt.gutter
if cp.nbatches * cp.ntemporal_samples < ctx.ctas:
raise NotImplementedError(
"Distribution of a CP across multiple CTAs not yet done")
# Interpolate each time step, calculate per-step variables, and pack # Interpolate each time step, calculate per-step variables, and pack
# into the stream # into the stream
cp_streamer = ctx.ptx.instances[CPDataStream] cp_streamer = ctx.ptx.instances[CPDataStream]
@ -44,16 +47,14 @@ class Frame(pyflam3.Frame):
self.interpolate(time, tcp) self.interpolate(time, tcp)
tcp.camera = Camera(self, tcp, self.filt) tcp.camera = Camera(self, tcp, self.filt)
# TODO: figure out which object to pack this into tcp.nsamples = (tcp.camera.sample_density *
nsamples = ((tcp.camera.sample_density * cp.width * cp.height) / cp.width * cp.height) / (
(cp.nbatches * cp.ntemporal_samples)) cp.nbatches * cp.ntemporal_samples)
samples_per_thread = nsamples / ctx.threads + 15
cp_streamer.pack_into(stream, cp_streamer.pack_into(stream,
frame=self, frame=self,
cp=tcp, cp=tcp,
cp_idx=idx, cp_idx=idx)
samples_per_thread=samples_per_thread)
stream.seek(0) stream.seek(0)
return (stream.read(), cp.nbatches * cp.ntemporal_samples) return (stream.read(), cp.nbatches * cp.ntemporal_samples)