mirror of
				https://github.com/stevenrobertson/cuburn.git
				synced 2025-11-04 02:10:45 -05:00 
			
		
		
		
	A new (somewhat experimental) approach to fusing
This commit is contained in:
		@ -44,8 +44,6 @@ class IterThread(PTXEntryPoint):
 | 
			
		||||
        # behaves slightly differently (see ``fuse_loop_start``)
 | 
			
		||||
        # TODO: replace (or at least simplify) this logic
 | 
			
		||||
        mem.shared.s32('s_num_samples')
 | 
			
		||||
        op.st.shared.s32(addr(s_num_samples), -(features.num_fuse_samples+1))
 | 
			
		||||
 | 
			
		||||
        mem.shared.f32('s_xf_sel', ctx.warps_per_cta)
 | 
			
		||||
 | 
			
		||||
        # TODO: temporary, for testing
 | 
			
		||||
@ -54,11 +52,11 @@ class IterThread(PTXEntryPoint):
 | 
			
		||||
        op.st.local.u32(addr(l_num_rounds), 0)
 | 
			
		||||
        op.st.local.u32(addr(l_num_writes), 0)
 | 
			
		||||
 | 
			
		||||
        reg.f32('xi xo yi yo colori coloro consec_bad')
 | 
			
		||||
        mwc.next_f32_11(xi)
 | 
			
		||||
        mwc.next_f32_11(yi)
 | 
			
		||||
        mwc.next_f32_01(colori)
 | 
			
		||||
        op.mov.f32(consec_bad, 0.)
 | 
			
		||||
        reg.f32('x y color consec_bad')
 | 
			
		||||
        mwc.next_f32_11(x)
 | 
			
		||||
        mwc.next_f32_11(y)
 | 
			
		||||
        mwc.next_f32_01(color)
 | 
			
		||||
        op.mov.f32(consec_bad, float(-features.fuse))
 | 
			
		||||
 | 
			
		||||
        comment("Ensure all init is done")
 | 
			
		||||
        op.bar.sync(0)
 | 
			
		||||
@ -71,13 +69,7 @@ class IterThread(PTXEntryPoint):
 | 
			
		||||
            std.set_is_first_thread(reg.pred('p_is_first'))
 | 
			
		||||
            op.atom.add.u32(cp_idx, addr(g_num_cps_started), 1, ifp=p_is_first)
 | 
			
		||||
            op.st.shared.u32(addr(s_cp_idx), cp_idx, ifp=p_is_first)
 | 
			
		||||
 | 
			
		||||
        with block("If done fusing, reset the sample count now"):
 | 
			
		||||
            reg.pred("p_done_fusing")
 | 
			
		||||
            reg.s32('num_samples')
 | 
			
		||||
            op.ld.shared.s32(num_samples, addr(s_num_samples))
 | 
			
		||||
            op.setp.gt.s32(p_done_fusing, num_samples, 0)
 | 
			
		||||
            op.st.shared.s32(addr(s_num_samples), 0, ifp=p_done_fusing)
 | 
			
		||||
            op.st.shared.s32(addr(s_num_samples), 0)
 | 
			
		||||
 | 
			
		||||
        comment("Load the CP index in all threads")
 | 
			
		||||
        op.bar.sync(0)
 | 
			
		||||
@ -96,14 +88,6 @@ class IterThread(PTXEntryPoint):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        label('fuse_loop_start')
 | 
			
		||||
        # When fusing, num_samples holds the (negative) number of iterations
 | 
			
		||||
        # left across the CP, rather than the number of samples in total.
 | 
			
		||||
        with block("If still fusing, increment count unconditionally"):
 | 
			
		||||
            op.bar.sync(0)
 | 
			
		||||
            std.set_is_first_thread(reg.pred('p_is_first'))
 | 
			
		||||
            op.red.shared.add.s32(addr(s_num_samples), 1, ifp=p_is_first)
 | 
			
		||||
 | 
			
		||||
        label('iter_loop_choose_xform')
 | 
			
		||||
        with block("Choose the xform for each warp"):
 | 
			
		||||
            timeout.check_time(5)
 | 
			
		||||
@ -147,51 +131,42 @@ class IterThread(PTXEntryPoint):
 | 
			
		||||
 | 
			
		||||
            for xf in features.xforms:
 | 
			
		||||
                label('XFORM_%d' % xf.id)
 | 
			
		||||
                variations.apply_xform(xo, yo, coloro, xi, yi, colori, xf.id)
 | 
			
		||||
                variations.apply_xform(x, y, color, x, y, color, xf.id)
 | 
			
		||||
                op.bra("xform_done")
 | 
			
		||||
 | 
			
		||||
        label("xform_done")
 | 
			
		||||
        with block("Test if we're still in FUSE"):
 | 
			
		||||
            reg.s32('num_samples')
 | 
			
		||||
            reg.pred('p_in_fuse')
 | 
			
		||||
            op.ld.shared.s32(num_samples, addr(s_num_samples))
 | 
			
		||||
            op.setp.lt.s32(p_in_fuse, num_samples, 0)
 | 
			
		||||
            op.bra.uni(fuse_loop_start, ifp=p_in_fuse)
 | 
			
		||||
 | 
			
		||||
        reg.pred('p_point_is_valid')
 | 
			
		||||
        reg.pred('p_valid_pt')
 | 
			
		||||
        with block("Write the result"):
 | 
			
		||||
            hist.scatter(xo, yo, coloro, 0, p_point_is_valid, 'ldst')
 | 
			
		||||
            reg.u32('hist_index')
 | 
			
		||||
            camera.get_index(hist_index, x, y, p_valid_pt)
 | 
			
		||||
            comment('if consec_bad < 0, point is fusing; treat as invalid')
 | 
			
		||||
            op.setp.and_.ge.f32(p_valid_pt, consec_bad, 0., p_valid_pt)
 | 
			
		||||
            # TODO: save and pass correct color value here
 | 
			
		||||
            hist.scatter(hist_index, color, 0, p_valid_pt, 'ldst')
 | 
			
		||||
            with block():
 | 
			
		||||
                reg.u32('num_writes')
 | 
			
		||||
                op.ld.local.u32(num_writes, addr(l_num_writes))
 | 
			
		||||
                op.add.u32(num_writes, num_writes, 1, ifp=p_point_is_valid)
 | 
			
		||||
                op.add.u32(num_writes, num_writes, 1, ifp=p_valid_pt)
 | 
			
		||||
                op.st.local.u32(addr(l_num_writes), num_writes)
 | 
			
		||||
 | 
			
		||||
        with block("If the result was invalid, handle badvals"):
 | 
			
		||||
            reg.f32('consec')
 | 
			
		||||
            reg.pred('need_new_point')
 | 
			
		||||
            comment('If point is good, move new coords and reset consec_bad')
 | 
			
		||||
            op.mov.f32(xi, xo, ifp=p_point_is_valid)
 | 
			
		||||
            op.mov.f32(yi, yo, ifp=p_point_is_valid)
 | 
			
		||||
            op.mov.f32(colori, coloro, ifp=p_point_is_valid)
 | 
			
		||||
            op.mov.f32(consec_bad, 0., ifp=p_point_is_valid)
 | 
			
		||||
 | 
			
		||||
            comment('Otherwise, add 1 to consec_bad')
 | 
			
		||||
            op.add.f32(consec, consec, 1., ifnotp=p_point_is_valid)
 | 
			
		||||
            op.setp.ge.f32(need_new_point, consec, 5.)
 | 
			
		||||
            op.add.f32(consec_bad, consec_bad, 1., ifnotp=p_valid_pt)
 | 
			
		||||
            op.setp.ge.f32(need_new_point, consec_bad, float(features.max_bad))
 | 
			
		||||
            op.bra('badval_done', ifnotp=need_new_point)
 | 
			
		||||
 | 
			
		||||
            comment('If consec_bad > 5, pick a new random point')
 | 
			
		||||
            mwc.next_f32_11(xi)
 | 
			
		||||
            mwc.next_f32_11(yi)
 | 
			
		||||
            mwc.next_f32_01(colori)
 | 
			
		||||
            op.mov.f32(consec, 0.)
 | 
			
		||||
            mwc.next_f32_11(x)
 | 
			
		||||
            mwc.next_f32_11(y)
 | 
			
		||||
            mwc.next_f32_01(color)
 | 
			
		||||
            op.mov.f32(consec_bad, float(-features.fuse))
 | 
			
		||||
            label('badval_done')
 | 
			
		||||
 | 
			
		||||
        with block("Increment number of samples by number of good values"):
 | 
			
		||||
            reg.b32('good_samples laneid')
 | 
			
		||||
            reg.pred('p_is_first')
 | 
			
		||||
            op.vote.ballot.b32(good_samples, p_point_is_valid)
 | 
			
		||||
            op.vote.ballot.b32(good_samples, p_valid_pt)
 | 
			
		||||
            op.popc.b32(good_samples, good_samples)
 | 
			
		||||
            op.mov.u32(laneid, '%laneid')
 | 
			
		||||
            op.setp.eq.u32(p_is_first, laneid, 0)
 | 
			
		||||
@ -209,7 +184,7 @@ class IterThread(PTXEntryPoint):
 | 
			
		||||
            op.bra.uni(cp_loop_start, ifp=p_cp_done)
 | 
			
		||||
 | 
			
		||||
        comment('Shuffle points between threads')
 | 
			
		||||
        shuf.shuffle(xi, yi, colori, consec_bad)
 | 
			
		||||
        shuf.shuffle(x, y, color, consec_bad)
 | 
			
		||||
 | 
			
		||||
        with block("If in first warp, pick new offset"):
 | 
			
		||||
            reg.u32('tid')
 | 
			
		||||
@ -390,14 +365,14 @@ class PaletteLookup(PTXFragment):
 | 
			
		||||
        mem.global_.texref('t_palette')
 | 
			
		||||
 | 
			
		||||
    @ptx_func
 | 
			
		||||
    def look_up(self, r, g, b, a, color, norm_time):
 | 
			
		||||
    def look_up(self, r, g, b, a, color, norm_time, ifp):
 | 
			
		||||
        """
 | 
			
		||||
        Look up the values of ``r, g, b, a`` corresponding to ``color_coord``
 | 
			
		||||
        at the CP indexed in ``timestamp_idx``. Note that both ``color_coord``
 | 
			
		||||
        and ``timestamp_idx`` should be [0,1]-normalized floats.
 | 
			
		||||
        """
 | 
			
		||||
        op.tex._2d.v4.f32.f32(vec(r, g, b, a),
 | 
			
		||||
                addr([t_palette, ', ',  vec(norm_time, color)]))
 | 
			
		||||
                addr([t_palette, ', ',  vec(norm_time, color)]), ifp=ifp)
 | 
			
		||||
        if features.non_box_temporal_filter:
 | 
			
		||||
            raise NotImplementedError("Non-box temporal filters not supported")
 | 
			
		||||
 | 
			
		||||
@ -462,17 +437,13 @@ class HistScatter(PTXFragment):
 | 
			
		||||
            op.st.volatile.b32(addr(g_hist_dummy), hist_dummy)
 | 
			
		||||
 | 
			
		||||
    @ptx_func
 | 
			
		||||
    def scatter(self, x, y, color, xf_idx, p_valid=None, type='ldst'):
 | 
			
		||||
    def scatter(self, hist_index, color, xf_idx, p_valid, type='ldst'):
 | 
			
		||||
        """
 | 
			
		||||
        Scatter the given point directly to the histogram bins. I think this
 | 
			
		||||
        technique has the worst performance of all of 'em. Accesses ``cpA``
 | 
			
		||||
        directly.
 | 
			
		||||
        """
 | 
			
		||||
        with block("Scatter directly to buffer"):
 | 
			
		||||
            if p_valid is None:
 | 
			
		||||
                p_valid = reg.pred('p_valid')
 | 
			
		||||
            reg.u32('hist_index')
 | 
			
		||||
            camera.get_index(hist_index, x, y, p_valid)
 | 
			
		||||
            reg.u32('hist_bin_addr')
 | 
			
		||||
            op.mov.u32(hist_bin_addr, g_hist_bins)
 | 
			
		||||
            op.mad.lo.u32(hist_bin_addr, hist_index, 16, hist_bin_addr)
 | 
			
		||||
@ -484,20 +455,22 @@ class HistScatter(PTXFragment):
 | 
			
		||||
 | 
			
		||||
            reg.f32('r g b a norm_time')
 | 
			
		||||
            cp.get(cpA, norm_time, 'cp.norm_time')
 | 
			
		||||
            palette.look_up(r, g, b, a, color, norm_time)
 | 
			
		||||
            palette.look_up(r, g, b, a, color, norm_time, ifp=p_valid)
 | 
			
		||||
            # TODO: look up, scale by xform visibility
 | 
			
		||||
            # TODO: Make this more performant
 | 
			
		||||
            if type == 'ldst':
 | 
			
		||||
                reg.f32('gr gg gb ga')
 | 
			
		||||
                op.ld.v4.f32(vec(gr, gg, gb, ga), addr(hist_bin_addr))
 | 
			
		||||
                op.ld.v4.f32(vec(gr, gg, gb, ga), addr(hist_bin_addr),
 | 
			
		||||
                             ifp=p_valid)
 | 
			
		||||
                op.add.f32(gr, gr, r)
 | 
			
		||||
                op.add.f32(gg, gg, g)
 | 
			
		||||
                op.add.f32(gb, gb, b)
 | 
			
		||||
                op.add.f32(ga, ga, a)
 | 
			
		||||
                op.st.v4.f32(addr(hist_bin_addr), vec(gr, gg, gb, ga))
 | 
			
		||||
                op.st.v4.f32(addr(hist_bin_addr), vec(gr, gg, gb, ga),
 | 
			
		||||
                             ifp=p_valid)
 | 
			
		||||
            elif type == 'red':
 | 
			
		||||
                for i, val in enumerate([r, g, b, a]):
 | 
			
		||||
                    op.red.add.f32(addr(hist_bin_addr,4*i), val)
 | 
			
		||||
                    op.red.add.f32(addr(hist_bin_addr,4*i), val, ifp=p_valid)
 | 
			
		||||
            elif type == 'fake':
 | 
			
		||||
                op.st.local.u32(addr(l_scatter_fake_adr), hist_bin_addr)
 | 
			
		||||
                op.st.local.f32(addr(l_scatter_fake_alpha), a)
 | 
			
		||||
 | 
			
		||||
@ -213,8 +213,11 @@ class Features(object):
 | 
			
		||||
    Determine features and constants required to render a particular set of
 | 
			
		||||
    genomes. The values of this class are fixed before compilation begins.
 | 
			
		||||
    """
 | 
			
		||||
    # Constant; number of rounds spent fusing points on first CP of a frame
 | 
			
		||||
    num_fuse_samples = 25
 | 
			
		||||
    # Constant parameters which control handling of out-of-frame samples:
 | 
			
		||||
    # Number of iterations to iterate without write after new point
 | 
			
		||||
    fuse = 2
 | 
			
		||||
    # Maximum consecutive out-of-frame points before picking new point
 | 
			
		||||
    max_bad = 3
 | 
			
		||||
 | 
			
		||||
    def __init__(self, genomes, flt):
 | 
			
		||||
        any = lambda l: bool(filter(None, map(l, genomes)))
 | 
			
		||||
 | 
			
		||||
@ -54,8 +54,8 @@ class Variations(PTXFragment):
 | 
			
		||||
        """
 | 
			
		||||
        Apply a transform.
 | 
			
		||||
 | 
			
		||||
        This function makes a copy of the input variables, so it's safe to use
 | 
			
		||||
        the same registers for input and output.
 | 
			
		||||
        This function necessarily makes a copy of the input variables, so it's
 | 
			
		||||
        safe to use the same registers for input and output.
 | 
			
		||||
        """
 | 
			
		||||
        with block("Apply xform %d" % xform_idx):
 | 
			
		||||
            self.xform_idx = xform_idx
 | 
			
		||||
@ -123,7 +123,7 @@ class Variations(PTXFragment):
 | 
			
		||||
    @ptx_func
 | 
			
		||||
    def spherical(self, xo, yo, xi, yi, wgt):
 | 
			
		||||
        reg.f32('r2')
 | 
			
		||||
        op.fma.rn.ftz.f32(r2, xi, xi, '1e-9')
 | 
			
		||||
        op.fma.rn.ftz.f32(r2, xi, xi, '1e-30')
 | 
			
		||||
        op.fma.rn.ftz.f32(r2, yi, yi, r2)
 | 
			
		||||
        op.rcp.approx.f32(r2, r2)
 | 
			
		||||
        op.mul.rn.ftz.f32(r2, r2, wgt)
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user