mirror of https://github.com/stevenrobertson/cuburn.git
synced 2025-11-03 18:00:55 -05:00
	Use shared memory for iter_count and have each CP processed by only one CTA.
Slower, but the code is a bit simpler conceptually, and the difference will be more than accounted for by better scheduling towards the end of the process.
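The scheduling scheme the message describes — each CTA claims whole CPs and tracks its own sample count in shared memory — might look roughly like the following in plain CUDA C. This is only a hedged sketch: the globals mirror those in the diff below, but samples_for_cp() and iterate_once() are hypothetical placeholders for the generated PTX, and the warp-level counting used in the diff is simplified to a per-thread atomic here.

    // Hedged sketch of one-CP-per-CTA scheduling with a shared-memory counter.
    // samples_for_cp() and iterate_once() are hypothetical placeholders.
    __device__ unsigned int g_num_cps;          // total CPs uploaded by the host
    __device__ unsigned int g_num_cps_started;  // claim counter, zeroed before launch

    __device__ int  samples_for_cp(unsigned int cp_idx) { return 1024; }  // placeholder
    __device__ bool iterate_once(unsigned int cp_idx)   { return true; }  // placeholder

    __global__ void iter_thread_sketch()
    {
        __shared__ unsigned int s_cp_idx;   // CP currently owned by this CTA
        __shared__ int s_num_samples;       // samples this CTA has produced for it

        for (;;) {
            if (threadIdx.x == 0) {
                // One thread claims the next CP on behalf of the whole CTA.
                s_cp_idx = atomicAdd(&g_num_cps_started, 1u);
                s_num_samples = 0;
            }
            __syncthreads();
            if (s_cp_idx >= g_num_cps)
                return;                     // no CPs left; this CTA is done

            int needed = samples_for_cp(s_cp_idx);
            int done = 0;
            while (done < needed) {
                if (iterate_once(s_cp_idx))
                    atomicAdd(&s_num_samples, 1);  // the diff batches this per warp
                __syncthreads();
                done = s_num_samples;       // every thread reads the same total
                __syncthreads();            // keep the read apart from the next adds
            }
        }
    }

Because a CTA keeps pulling CPs until the claim counter passes g_num_cps, CTAs that finish early simply grab more work, which is the "better scheduling towards the end of the process" the message refers to.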
@@ -46,6 +46,10 @@ class LaunchContext(object):
    def ctas(self):
        return self.grid[0] * self.grid[1]

    @property
    def threads_per_cta(self):
        return self.block[0] * self.block[1] * self.block[2]

    def compile(self, verbose=False, **kwargs):
        kwargs['ctx'] = self
        self.ptx = PTXModule(self.entry_types, kwargs, self.build_tests)
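These two properties are simple products of the launch dimensions; for reference, a hedged CUDA host-side equivalent:

    #include <cuda_runtime.h>

    // Hedged sketch of LaunchContext.ctas and LaunchContext.threads_per_cta.
    static inline unsigned int ctas(dim3 grid)             { return grid.x * grid.y; }
    static inline unsigned int threads_per_cta(dim3 block) { return block.x * block.y * block.z; }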
@@ -26,82 +26,66 @@ class IterThread(PTXTest):
        mem.global_.u32('g_cp_array',
                        cp.stream_size*features.max_ntemporal_samples)
        mem.global_.u32('g_num_cps')
        mem.global_.u32('g_num_cps_started')
        # TODO move into debug statement
        mem.global_.u32('g_num_rounds', ctx.threads)
        mem.global_.u32('g_num_writes', ctx.threads)

    @ptx_func
    def entry(self):
        reg.f32('x_coord y_coord color_coord alpha_coord')
        # For now, we indulge in the luxury of shared memory.

        # Index number of current CP, shared across CTA
        mem.shared.u32('s_cp_idx')

        # Number of samples that have been generated so far in this CTA
        # If this number is negative, we're still fusing points, so this
        # behaves slightly differently (see ``fuse_loop_start``)
        mem.shared.u32('s_num_samples')
        op.st.shared.u32(addr(s_num_samples), -(features.num_fuse_samples+1))

        # TODO: temporary, for testing
        reg.u32('num_rounds num_writes')
        op.mov.u32(num_rounds, 0)
        op.mov.u32(num_writes, 0)

        # TODO: MWC float output types
        mwc.next_f32_01(x_coord)
        mwc.next_f32_01(y_coord)
        reg.f32('x_coord y_coord color_coord')
        mwc.next_f32_11(x_coord)
        mwc.next_f32_11(y_coord)
        mwc.next_f32_01(color_coord)
        mwc.next_f32_01(alpha_coord)

        # Registers are hard to come by. To avoid having to track both the count
        # of samples processed and the number of samples to generate,
        # 'num_samples' counts *down* from the CP's desired sample count.
        # When it hits 0, we move on to the next CP.
        #
        # FUSE complicates things. To track it, we store the *negative* number
        # of points we have left to fuse before we start to store the results.
        # When it hits -1, we're done fusing, and can move on to the real
        # thread. The execution flow between 'cp_loop', 'fuse_start', and
        # 'iter_loop_start' is therefore tricky, and bears close inspection.
        #
        # In summary:
        #   num_samples == 0: Load next CP, set num_samples from that
        #   num_samples >  0: Iterate, store the result, decrement num_samples
        #   num_samples < -1: Iterate, don't store, increment num_samples
        #   num_samples == -1: Done fusing, enter normal flow
        # TODO: move this to qlocal storage
        reg.s32('num_samples')
        op.mov.s32(num_samples, -(features.num_fuse_samples+1))

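The countdown convention spelled out in the comment above amounts to a four-way decision each round. A hedged restatement as a tiny, self-contained CUDA/C helper (the enum and function names are made up for illustration):

    // Hedged restatement of the num_samples state machine described above.
    enum IterAction { LOAD_NEXT_CP, ITERATE_AND_STORE, ITERATE_NO_STORE, LEAVE_FUSE };

    __host__ __device__ inline IterAction next_action(int num_samples)
    {
        if (num_samples == 0)  return LOAD_NEXT_CP;       // this CP is exhausted
        if (num_samples >  0)  return ITERATE_AND_STORE;  // normal flow; count down
        if (num_samples < -1)  return ITERATE_NO_STORE;   // still fusing; count up
        return LEAVE_FUSE;                                // num_samples == -1
    }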
        # TODO: Move cp_num to qlocal storage (or spill it, rarely accessed)
        reg.u32('cp_idx cpA')
        op.mov.u32(cp_idx, 0)

        label('cp_loop_start')
        comment("Ensure all init is done")
        op.bar.sync(0)

        with block('Check to see if this is the last CP'):
        label('cp_loop_start')
        reg.u32('cp_idx cpA')
        with block("Claim a CP"):
            std.set_is_first_thread(reg.pred('p_is_first'))
            op.atom.inc.u32(cp_idx, addr(g_num_cps_started), 1, ifp=p_is_first)
            op.st.shared.u32(addr(s_cp_idx), cp_idx, ifp=p_is_first)

        comment("Load the CP index in all threads")
        op.bar.sync(0)
        op.ld.shared.u32(cp_idx, addr(s_cp_idx))

        with block("Check to see if this CP is valid (if not, we're done"):
            reg.u32('num_cps')
            reg.pred('p_last_cp')
            op.ldu.u32(num_cps, addr(g_num_cps))
            op.setp.ge.u32(p_last_cp, cp_idx, num_cps)
            op.setp.ge.u32(p_last_cp, cp_idx, 1)
            op.bra.uni('all_cps_done', ifp=p_last_cp)

        with block('Load CP address'):
            op.mov.u32(cpA, g_cp_array)
            op.mad.lo.u32(cpA, cp_idx, cp.stream_size, cpA)

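The 'Load CP address' block above is just base-plus-offset arithmetic into the packed CP stream: cpA = g_cp_array + cp_idx * cp.stream_size. A hedged CUDA equivalent, with the backing array and stream_size parameter standing in for the real packed stream:

    // Hedged sketch of indexing into the packed control-point stream.
    __device__ char g_cp_array_sketch[64 * 1024];  // placeholder backing store

    __device__ const char* cp_base(unsigned int cp_idx, unsigned int stream_size)
    {
        // Same role as op.mad.lo.u32(cpA, cp_idx, cp.stream_size, cpA) above.
        return g_cp_array_sketch + (size_t)cp_idx * stream_size;
    }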
        with block('Increment CP index, load num_samples (unless in fuse)'):
            reg.pred('p_not_in_fuse')
            op.setp.ge.s32(p_not_in_fuse, num_samples, 0)
            op.add.u32(cp_idx, cp_idx, 1, ifp=p_not_in_fuse)
            cp.get(cpA, num_samples, 'samples_per_thread',
                          ifp=p_not_in_fuse)

        label('fuse_loop_start')
        with block('FUSE-specific stuff'):
            reg.pred('p_fuse')
            comment('If num_samples == -1, set it to 0 and jump back up')
            comment('This will start the normal CP loading machinery')
            op.setp.eq.s32(p_fuse, num_samples, -1)
            op.mov.s32(num_samples, 0, ifp=p_fuse)
            op.bra.uni(cp_loop_start, ifp=p_fuse)

            comment('If num_samples < -1, still fusing, so increment')
            op.setp.lt.s32(p_fuse, num_samples, -1)
            op.add.s32(num_samples, num_samples, 1, ifp=p_fuse)
        # When fusing, num_samples holds the (negative) number of iterations
        # left across the CP, rather than the number of samples in total.
        with block("If still fusing, increment count unconditionally"):
            std.set_is_first_thread(reg.pred('p_is_first'))
            op.red.shared.add.s32(addr(s_num_samples), 1, ifp=p_is_first)
            op.bar.sync(0)

        label('iter_loop_start')

@@ -110,17 +94,33 @@ class IterThread(PTXTest):
        op.add.u32(num_rounds, num_rounds, 1)

        with block("Test if we're still in FUSE"):
            reg.s32('num_samples')
            reg.pred('p_in_fuse')
            op.ld.shared.u32(num_samples, addr(s_num_samples))
            op.setp.lt.s32(p_in_fuse, num_samples, 0)
            op.bra.uni(fuse_loop_start, ifp=p_in_fuse)

        with block("Ordinarily, we'd write the result here"):
            op.add.u32(num_writes, num_writes, 1)

        # For testing, declare and clear p_badval
        reg.pred('p_goodval')
        op.setp.eq.u32(p_goodval, 1, 1)

        with block("Increment number of samples by number of good values"):
            reg.b32('good_samples')
            op.vote.ballot.b32(good_samples, p_goodval)
            op.popc.b32(good_samples, good_samples)
            std.set_is_first_thread(reg.pred('p_is_first'))
            op.red.shared.add.s32(addr(s_num_samples), good_samples,
                                  ifp=p_is_first)

        with block("Check to see if we're done with this CP"):
            reg.pred('p_cp_done')
            op.add.s32(num_samples, num_samples, -1)
            op.setp.eq.s32(p_cp_done, num_samples, 0)
            reg.s32('num_samples num_samples_needed')
            op.ld.shared.s32(num_samples, addr(s_num_samples))
            cp.get(cpA, num_samples_needed, 'cp.nsamples')
            op.setp.ge.s32(p_cp_done, num_samples, num_samples_needed)
            op.bra.uni(cp_loop_start, ifp=p_cp_done)

        op.bra.uni(iter_loop_start)
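The "good values" block above counts, within a warp, how many threads produced a usable sample (vote.ballot followed by popc) and folds that count into the shared counter with one reduction instead of one per thread. A hedged CUDA C equivalent — note the diff issues the add from the CTA's first thread, while this sketch adds once per warp, which is a common variant:

    // Hedged sketch of warp-aggregated counting of good samples into shared memory.
    __device__ void count_good_samples(bool p_goodval, int* s_num_samples)
    {
        // One bit per lane whose predicate is set (vote.ballot.b32)...
        unsigned int ballot = __ballot_sync(0xffffffffu, p_goodval);
        // ...then count the set bits (popc.b32) to get this warp's good samples.
        int good_samples = __popc(ballot);
        // A single lane folds the count into the CTA-wide counter.
        if ((threadIdx.x & 31) == 0)
            atomicAdd(s_num_samples, good_samples);
    }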
@@ -134,13 +134,17 @@ class IterThread(PTXTest):
        cp_array_dp, cp_array_l = ctx.mod.get_global('g_cp_array')
        assert len(cp_stream) <= cp_array_l, "Stream too big!"
        cuda.memcpy_htod_async(cp_array_dp, cp_stream)

        num_cps_dp, num_cps_l = ctx.mod.get_global('g_num_cps')
        cuda.memcpy_htod_async(num_cps_dp, struct.pack('i', num_cps))
        cuda.memset_d32(num_cps_dp, num_cps, 1)
        self.cps_uploaded = True

    def call(self, ctx):
        if not self.cps_uploaded:
            raise Error("Cannot call IterThread before uploading CPs")
        num_cps_st_dp, num_cps_st_l = ctx.mod.get_global('g_num_cps_started')
        cuda.memset_d32(num_cps_st_dp, 0, 1)

        func = ctx.mod.get_function('iter_thread')
        dtime = func(block=ctx.block, grid=ctx.grid, time_kernel=True)
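The upload and reset steps above go through PyCUDA's driver bindings; a hedged sketch of the same sequence via the raw CUDA driver API (module handle, error checking, and any stream argument are assumed or omitted):

    #include <cuda.h>

    // Hedged sketch of upload_cp_stream() plus the reset in call(), using the
    // CUDA driver API directly instead of PyCUDA.
    void upload_and_reset(CUmodule mod, const void* cp_stream, size_t stream_len,
                          unsigned int num_cps)
    {
        CUdeviceptr dptr;
        size_t bytes;

        // g_cp_array: copy the packed CP stream into the module-level global.
        cuModuleGetGlobal(&dptr, &bytes, mod, "g_cp_array");
        // stream_len must not exceed bytes ("Stream too big!" in the diff).
        cuMemcpyHtoD(dptr, cp_stream, stream_len);

        // g_num_cps: a single 32-bit word holding the CP count.
        cuModuleGetGlobal(&dptr, &bytes, mod, "g_num_cps");
        cuMemsetD32(dptr, num_cps, 1);

        // g_num_cps_started: the claim counter must start at zero each launch.
        cuModuleGetGlobal(&dptr, &bytes, mod, "g_num_cps_started");
        cuMemsetD32(dptr, 0, 1);
    }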
@@ -219,7 +223,7 @@ class MWCRNG(PTXFragment):
        with block('Load random float [-1,1) into ' + dst_reg.name):
            self._next()
            op.cvt.rn.f32.s32(dst_reg, mwc_st)
            op.mul.lo.f32(dst_reg, dst_reg, '0f00000030') # 1./(1<<31)
            op.mul.f32(dst_reg, dst_reg, '0f00000030') # 1./(1<<31)

    def device_init(self, ctx):
        if self.threads_ready >= ctx.threads:
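Per its comment, the intent of the conversion above is to map the signed 32-bit MWC state into a float in [-1, 1): convert with round-to-nearest, then scale by 1/2^31. A hedged CUDA equivalent of that intent:

    // Hedged sketch of next_f32_11: signed 32-bit state mapped into [-1, 1).
    __device__ inline float mwc_to_f32_11(unsigned int state)
    {
        // cvt.rn.f32.s32 followed by a multiply by 1/2^31.
        return (float)(int)state * (1.0f / 2147483648.0f);
    }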
@@ -650,6 +650,16 @@ class _PTXStdLib(PTXFragment):
            op.mad.lo.u32(spt_base, spt_offset, 4, spt_base)
            op.st.b32(addr(spt_base), val)

    @ptx_func
    def set_is_first_thread(self, p_dst):
        with block("Set %s if this is thread 0 in the CTA" % p_dst.name):
            reg.u32('tid')
            op.mov.u32(tid, '%tid.x')
            op.setp.eq.u32(p_dst, tid, 0)

    def not_(self, pred):
        return ['!', pred]

    def to_inject(self):
        # Set up the initial namespace
        return dict(
@@ -30,6 +30,9 @@ class Frame(pyflam3.Frame):
        rw = cp.spatial_oversample * cp.width  + 2 * self.filt.gutter
        rh = cp.spatial_oversample * cp.height + 2 * self.filt.gutter

        if cp.nbatches * cp.ntemporal_samples < ctx.ctas:
            raise NotImplementedError(
                "Distribution of a CP across multiple CTAs not yet done")
        # Interpolate each time step, calculate per-step variables, and pack
        # into the stream
        cp_streamer = ctx.ptx.instances[CPDataStream]
@@ -44,16 +47,14 @@ class Frame(pyflam3.Frame):
                self.interpolate(time, tcp)
                tcp.camera = Camera(self, tcp, self.filt)

                # TODO: figure out which object to pack this into
                nsamples = ((tcp.camera.sample_density * cp.width * cp.height) /
                            (cp.nbatches * cp.ntemporal_samples))
                samples_per_thread = nsamples / ctx.threads + 15
                tcp.nsamples = (tcp.camera.sample_density *
                                cp.width * cp.height) / (
                                cp.nbatches * cp.ntemporal_samples)

                cp_streamer.pack_into(stream,
                        frame=self,
                        cp=tcp,
                        cp_idx=idx,
                        samples_per_thread=samples_per_thread)
                        cp_idx=idx)
        stream.seek(0)
        return (stream.read(), cp.nbatches * cp.ntemporal_samples)
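For scale, tcp.nsamples above spreads the samples implied by the camera's sample density over every batch and temporal sample. A hedged worked example with made-up numbers:

    // Hedged example of the tcp.nsamples computation, hypothetical values only.
    long long nsamples_example()
    {
        long long sample_density = 1000, width = 640, height = 480;
        long long nbatches = 1, ntemporal_samples = 64;
        // 1000 * 640 * 480 / (1 * 64) = 4,800,000 samples for this temporal CP
        return (sample_density * width * height) / (nbatches * ntemporal_samples);
    }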