mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 11:40:04 -05:00
Use shared memory for iter_count and have each CP processed by only one CTA.
Slower, but the code is a bit simpler conceptually, and the difference will be more than accounted for by better scheduling towards the end of the process.
This commit is contained in:
parent
aa065dc25d
commit
094890c324
@ -46,6 +46,10 @@ class LaunchContext(object):
|
||||
def ctas(self):
|
||||
return self.grid[0] * self.grid[1]
|
||||
|
||||
@property
|
||||
def threads_per_cta(self):
|
||||
return self.block[0] * self.block[1] * self.block[2]
|
||||
|
||||
def compile(self, verbose=False, **kwargs):
|
||||
kwargs['ctx'] = self
|
||||
self.ptx = PTXModule(self.entry_types, kwargs, self.build_tests)
|
||||
|
@ -26,82 +26,66 @@ class IterThread(PTXTest):
|
||||
mem.global_.u32('g_cp_array',
|
||||
cp.stream_size*features.max_ntemporal_samples)
|
||||
mem.global_.u32('g_num_cps')
|
||||
mem.global_.u32('g_num_cps_started')
|
||||
# TODO move into debug statement
|
||||
mem.global_.u32('g_num_rounds', ctx.threads)
|
||||
mem.global_.u32('g_num_writes', ctx.threads)
|
||||
|
||||
@ptx_func
|
||||
def entry(self):
|
||||
reg.f32('x_coord y_coord color_coord alpha_coord')
|
||||
# For now, we indulge in the luxury of shared memory.
|
||||
|
||||
# Index number of current CP, shared across CTA
|
||||
mem.shared.u32('s_cp_idx')
|
||||
|
||||
# Number of samples that have been generated so far in this CTA
|
||||
# If this number is negative, we're still fusing points, so this
|
||||
# behaves slightly differently (see ``fuse_loop_start``)
|
||||
mem.shared.u32('s_num_samples')
|
||||
op.st.shared.u32(addr(s_num_samples), -(features.num_fuse_samples+1))
|
||||
|
||||
# TODO: temporary, for testing
|
||||
reg.u32('num_rounds num_writes')
|
||||
op.mov.u32(num_rounds, 0)
|
||||
op.mov.u32(num_writes, 0)
|
||||
|
||||
# TODO: MWC float output types
|
||||
mwc.next_f32_01(x_coord)
|
||||
mwc.next_f32_01(y_coord)
|
||||
reg.f32('x_coord y_coord color_coord')
|
||||
mwc.next_f32_11(x_coord)
|
||||
mwc.next_f32_11(y_coord)
|
||||
mwc.next_f32_01(color_coord)
|
||||
mwc.next_f32_01(alpha_coord)
|
||||
|
||||
# Registers are hard to come by. To avoid having to track both the count
|
||||
# of samples processed and the number of samples to generate,
|
||||
# 'num_samples' counts *down* from the CP's desired sample count.
|
||||
# When it hits 0, we move on to the next CP.
|
||||
#
|
||||
# FUSE complicates things. To track it, we store the *negative* number
|
||||
# of points we have left to fuse before we start to store the results.
|
||||
# When it hits -1, we're done fusing, and can move on to the real
|
||||
# thread. The execution flow between 'cp_loop', 'fuse_start', and
|
||||
# 'iter_loop_start' is therefore tricky, and bears close inspection.
|
||||
#
|
||||
# In summary:
|
||||
# num_samples == 0: Load next CP, set num_samples from that
|
||||
# num_samples > 0: Iterate, store the result, decrement num_samples
|
||||
# num_samples < -1: Iterate, don't store, increment num_samples
|
||||
# num_samples == -1: Done fusing, enter normal flow
|
||||
# TODO: move this to qlocal storage
|
||||
reg.s32('num_samples')
|
||||
op.mov.s32(num_samples, -(features.num_fuse_samples+1))
|
||||
|
||||
# TODO: Move cp_num to qlocal storage (or spill it, rarely accessed)
|
||||
reg.u32('cp_idx cpA')
|
||||
op.mov.u32(cp_idx, 0)
|
||||
|
||||
label('cp_loop_start')
|
||||
comment("Ensure all init is done")
|
||||
op.bar.sync(0)
|
||||
|
||||
with block('Check to see if this is the last CP'):
|
||||
label('cp_loop_start')
|
||||
reg.u32('cp_idx cpA')
|
||||
with block("Claim a CP"):
|
||||
std.set_is_first_thread(reg.pred('p_is_first'))
|
||||
op.atom.inc.u32(cp_idx, addr(g_num_cps_started), 1, ifp=p_is_first)
|
||||
op.st.shared.u32(addr(s_cp_idx), cp_idx, ifp=p_is_first)
|
||||
|
||||
comment("Load the CP index in all threads")
|
||||
op.bar.sync(0)
|
||||
op.ld.shared.u32(cp_idx, addr(s_cp_idx))
|
||||
|
||||
with block("Check to see if this CP is valid (if not, we're done"):
|
||||
reg.u32('num_cps')
|
||||
reg.pred('p_last_cp')
|
||||
op.ldu.u32(num_cps, addr(g_num_cps))
|
||||
op.setp.ge.u32(p_last_cp, cp_idx, num_cps)
|
||||
op.setp.ge.u32(p_last_cp, cp_idx, 1)
|
||||
op.bra.uni('all_cps_done', ifp=p_last_cp)
|
||||
|
||||
with block('Load CP address'):
|
||||
op.mov.u32(cpA, g_cp_array)
|
||||
op.mad.lo.u32(cpA, cp_idx, cp.stream_size, cpA)
|
||||
|
||||
with block('Increment CP index, load num_samples (unless in fuse)'):
|
||||
reg.pred('p_not_in_fuse')
|
||||
op.setp.ge.s32(p_not_in_fuse, num_samples, 0)
|
||||
op.add.u32(cp_idx, cp_idx, 1, ifp=p_not_in_fuse)
|
||||
cp.get(cpA, num_samples, 'samples_per_thread',
|
||||
ifp=p_not_in_fuse)
|
||||
|
||||
label('fuse_loop_start')
|
||||
with block('FUSE-specific stuff'):
|
||||
reg.pred('p_fuse')
|
||||
comment('If num_samples == -1, set it to 0 and jump back up')
|
||||
comment('This will start the normal CP loading machinery')
|
||||
op.setp.eq.s32(p_fuse, num_samples, -1)
|
||||
op.mov.s32(num_samples, 0, ifp=p_fuse)
|
||||
op.bra.uni(cp_loop_start, ifp=p_fuse)
|
||||
|
||||
comment('If num_samples < -1, still fusing, so increment')
|
||||
op.setp.lt.s32(p_fuse, num_samples, -1)
|
||||
op.add.s32(num_samples, num_samples, 1, ifp=p_fuse)
|
||||
# When fusing, num_samples holds the (negative) number of iterations
|
||||
# left across the CP, rather than the number of samples in total.
|
||||
with block("If still fusing, increment count unconditionally"):
|
||||
std.set_is_first_thread(reg.pred('p_is_first'))
|
||||
op.red.shared.add.s32(addr(s_num_samples), 1, ifp=p_is_first)
|
||||
op.bar.sync(0)
|
||||
|
||||
label('iter_loop_start')
|
||||
|
||||
@ -110,17 +94,33 @@ class IterThread(PTXTest):
|
||||
op.add.u32(num_rounds, num_rounds, 1)
|
||||
|
||||
with block("Test if we're still in FUSE"):
|
||||
reg.s32('num_samples')
|
||||
reg.pred('p_in_fuse')
|
||||
op.ld.shared.u32(num_samples, addr(s_num_samples))
|
||||
op.setp.lt.s32(p_in_fuse, num_samples, 0)
|
||||
op.bra.uni(fuse_loop_start, ifp=p_in_fuse)
|
||||
|
||||
with block("Ordinarily, we'd write the result here"):
|
||||
op.add.u32(num_writes, num_writes, 1)
|
||||
|
||||
# For testing, declare and clear p_badval
|
||||
reg.pred('p_goodval')
|
||||
op.setp.eq.u32(p_goodval, 1, 1)
|
||||
|
||||
with block("Increment number of samples by number of good values"):
|
||||
reg.b32('good_samples')
|
||||
op.vote.ballot.b32(good_samples, p_goodval)
|
||||
op.popc.b32(good_samples, good_samples)
|
||||
std.set_is_first_thread(reg.pred('p_is_first'))
|
||||
op.red.shared.add.s32(addr(s_num_samples), good_samples,
|
||||
ifp=p_is_first)
|
||||
|
||||
with block("Check to see if we're done with this CP"):
|
||||
reg.pred('p_cp_done')
|
||||
op.add.s32(num_samples, num_samples, -1)
|
||||
op.setp.eq.s32(p_cp_done, num_samples, 0)
|
||||
reg.s32('num_samples num_samples_needed')
|
||||
op.ld.shared.s32(num_samples, addr(s_num_samples))
|
||||
cp.get(cpA, num_samples_needed, 'cp.nsamples')
|
||||
op.setp.ge.s32(p_cp_done, num_samples, num_samples_needed)
|
||||
op.bra.uni(cp_loop_start, ifp=p_cp_done)
|
||||
|
||||
op.bra.uni(iter_loop_start)
|
||||
@ -134,13 +134,17 @@ class IterThread(PTXTest):
|
||||
cp_array_dp, cp_array_l = ctx.mod.get_global('g_cp_array')
|
||||
assert len(cp_stream) <= cp_array_l, "Stream too big!"
|
||||
cuda.memcpy_htod_async(cp_array_dp, cp_stream)
|
||||
|
||||
num_cps_dp, num_cps_l = ctx.mod.get_global('g_num_cps')
|
||||
cuda.memcpy_htod_async(num_cps_dp, struct.pack('i', num_cps))
|
||||
cuda.memset_d32(num_cps_dp, num_cps, 1)
|
||||
self.cps_uploaded = True
|
||||
|
||||
def call(self, ctx):
|
||||
if not self.cps_uploaded:
|
||||
raise Error("Cannot call IterThread before uploading CPs")
|
||||
num_cps_st_dp, num_cps_st_l = ctx.mod.get_global('g_num_cps_started')
|
||||
cuda.memset_d32(num_cps_st_dp, 0, 1)
|
||||
|
||||
func = ctx.mod.get_function('iter_thread')
|
||||
dtime = func(block=ctx.block, grid=ctx.grid, time_kernel=True)
|
||||
|
||||
@ -219,7 +223,7 @@ class MWCRNG(PTXFragment):
|
||||
with block('Load random float [-1,1) into ' + dst_reg.name):
|
||||
self._next()
|
||||
op.cvt.rn.f32.s32(dst_reg, mwc_st)
|
||||
op.mul.lo.f32(dst_reg, dst_reg, '0f00000030') # 1./(1<<31)
|
||||
op.mul.f32(dst_reg, dst_reg, '0f00000030') # 1./(1<<31)
|
||||
|
||||
def device_init(self, ctx):
|
||||
if self.threads_ready >= ctx.threads:
|
||||
|
@ -650,6 +650,16 @@ class _PTXStdLib(PTXFragment):
|
||||
op.mad.lo.u32(spt_base, spt_offset, 4, spt_base)
|
||||
op.st.b32(addr(spt_base), val)
|
||||
|
||||
@ptx_func
|
||||
def set_is_first_thread(self, p_dst):
|
||||
with block("Set %s if this is thread 0 in the CTA" % p_dst.name):
|
||||
reg.u32('tid')
|
||||
op.mov.u32(tid, '%tid.x')
|
||||
op.setp.eq.u32(p_dst, tid, 0)
|
||||
|
||||
def not_(self, pred):
|
||||
return ['!', pred]
|
||||
|
||||
def to_inject(self):
|
||||
# Set up the initial namespace
|
||||
return dict(
|
||||
|
@ -30,6 +30,9 @@ class Frame(pyflam3.Frame):
|
||||
rw = cp.spatial_oversample * cp.width + 2 * self.filt.gutter
|
||||
rh = cp.spatial_oversample * cp.height + 2 * self.filt.gutter
|
||||
|
||||
if cp.nbatches * cp.ntemporal_samples < ctx.ctas:
|
||||
raise NotImplementedError(
|
||||
"Distribution of a CP across multiple CTAs not yet done")
|
||||
# Interpolate each time step, calculate per-step variables, and pack
|
||||
# into the stream
|
||||
cp_streamer = ctx.ptx.instances[CPDataStream]
|
||||
@ -44,16 +47,14 @@ class Frame(pyflam3.Frame):
|
||||
self.interpolate(time, tcp)
|
||||
tcp.camera = Camera(self, tcp, self.filt)
|
||||
|
||||
# TODO: figure out which object to pack this into
|
||||
nsamples = ((tcp.camera.sample_density * cp.width * cp.height) /
|
||||
(cp.nbatches * cp.ntemporal_samples))
|
||||
samples_per_thread = nsamples / ctx.threads + 15
|
||||
tcp.nsamples = (tcp.camera.sample_density *
|
||||
cp.width * cp.height) / (
|
||||
cp.nbatches * cp.ntemporal_samples)
|
||||
|
||||
cp_streamer.pack_into(stream,
|
||||
frame=self,
|
||||
cp=tcp,
|
||||
cp_idx=idx,
|
||||
samples_per_thread=samples_per_thread)
|
||||
cp_idx=idx)
|
||||
stream.seek(0)
|
||||
return (stream.read(), cp.nbatches * cp.ntemporal_samples)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user