diff --git a/cuburn/device_code.py b/cuburn/device_code.py index 1215752..3028deb 100644 --- a/cuburn/device_code.py +++ b/cuburn/device_code.py @@ -44,8 +44,6 @@ class IterThread(PTXEntryPoint): # behaves slightly differently (see ``fuse_loop_start``) # TODO: replace (or at least simplify) this logic mem.shared.s32('s_num_samples') - op.st.shared.s32(addr(s_num_samples), -(features.num_fuse_samples+1)) - mem.shared.f32('s_xf_sel', ctx.warps_per_cta) # TODO: temporary, for testing @@ -54,11 +52,11 @@ class IterThread(PTXEntryPoint): op.st.local.u32(addr(l_num_rounds), 0) op.st.local.u32(addr(l_num_writes), 0) - reg.f32('xi xo yi yo colori coloro consec_bad') - mwc.next_f32_11(xi) - mwc.next_f32_11(yi) - mwc.next_f32_01(colori) - op.mov.f32(consec_bad, 0.) + reg.f32('x y color consec_bad') + mwc.next_f32_11(x) + mwc.next_f32_11(y) + mwc.next_f32_01(color) + op.mov.f32(consec_bad, float(-features.fuse)) comment("Ensure all init is done") op.bar.sync(0) @@ -71,13 +69,7 @@ class IterThread(PTXEntryPoint): std.set_is_first_thread(reg.pred('p_is_first')) op.atom.add.u32(cp_idx, addr(g_num_cps_started), 1, ifp=p_is_first) op.st.shared.u32(addr(s_cp_idx), cp_idx, ifp=p_is_first) - - with block("If done fusing, reset the sample count now"): - reg.pred("p_done_fusing") - reg.s32('num_samples') - op.ld.shared.s32(num_samples, addr(s_num_samples)) - op.setp.gt.s32(p_done_fusing, num_samples, 0) - op.st.shared.s32(addr(s_num_samples), 0, ifp=p_done_fusing) + op.st.shared.s32(addr(s_num_samples), 0) comment("Load the CP index in all threads") op.bar.sync(0) @@ -96,14 +88,6 @@ class IterThread(PTXEntryPoint): - label('fuse_loop_start') - # When fusing, num_samples holds the (negative) number of iterations - # left across the CP, rather than the number of samples in total. - with block("If still fusing, increment count unconditionally"): - op.bar.sync(0) - std.set_is_first_thread(reg.pred('p_is_first')) - op.red.shared.add.s32(addr(s_num_samples), 1, ifp=p_is_first) - label('iter_loop_choose_xform') with block("Choose the xform for each warp"): timeout.check_time(5) @@ -147,51 +131,42 @@ class IterThread(PTXEntryPoint): for xf in features.xforms: label('XFORM_%d' % xf.id) - variations.apply_xform(xo, yo, coloro, xi, yi, colori, xf.id) + variations.apply_xform(x, y, color, x, y, color, xf.id) op.bra("xform_done") label("xform_done") - with block("Test if we're still in FUSE"): - reg.s32('num_samples') - reg.pred('p_in_fuse') - op.ld.shared.s32(num_samples, addr(s_num_samples)) - op.setp.lt.s32(p_in_fuse, num_samples, 0) - op.bra.uni(fuse_loop_start, ifp=p_in_fuse) - reg.pred('p_point_is_valid') + reg.pred('p_valid_pt') with block("Write the result"): - hist.scatter(xo, yo, coloro, 0, p_point_is_valid, 'ldst') + reg.u32('hist_index') + camera.get_index(hist_index, x, y, p_valid_pt) + comment('if consec_bad < 0, point is fusing; treat as invalid') + op.setp.and_.ge.f32(p_valid_pt, consec_bad, 0., p_valid_pt) + # TODO: save and pass correct color value here + hist.scatter(hist_index, color, 0, p_valid_pt, 'ldst') with block(): reg.u32('num_writes') op.ld.local.u32(num_writes, addr(l_num_writes)) - op.add.u32(num_writes, num_writes, 1, ifp=p_point_is_valid) + op.add.u32(num_writes, num_writes, 1, ifp=p_valid_pt) op.st.local.u32(addr(l_num_writes), num_writes) with block("If the result was invalid, handle badvals"): - reg.f32('consec') reg.pred('need_new_point') - comment('If point is good, move new coords and reset consec_bad') - op.mov.f32(xi, xo, ifp=p_point_is_valid) - op.mov.f32(yi, yo, ifp=p_point_is_valid) - op.mov.f32(colori, coloro, ifp=p_point_is_valid) - op.mov.f32(consec_bad, 0., ifp=p_point_is_valid) - - comment('Otherwise, add 1 to consec_bad') - op.add.f32(consec, consec, 1., ifnotp=p_point_is_valid) - op.setp.ge.f32(need_new_point, consec, 5.) + op.add.f32(consec_bad, consec_bad, 1., ifnotp=p_valid_pt) + op.setp.ge.f32(need_new_point, consec_bad, float(features.max_bad)) op.bra('badval_done', ifnotp=need_new_point) comment('If consec_bad > 5, pick a new random point') - mwc.next_f32_11(xi) - mwc.next_f32_11(yi) - mwc.next_f32_01(colori) - op.mov.f32(consec, 0.) + mwc.next_f32_11(x) + mwc.next_f32_11(y) + mwc.next_f32_01(color) + op.mov.f32(consec_bad, float(-features.fuse)) label('badval_done') with block("Increment number of samples by number of good values"): reg.b32('good_samples laneid') reg.pred('p_is_first') - op.vote.ballot.b32(good_samples, p_point_is_valid) + op.vote.ballot.b32(good_samples, p_valid_pt) op.popc.b32(good_samples, good_samples) op.mov.u32(laneid, '%laneid') op.setp.eq.u32(p_is_first, laneid, 0) @@ -209,7 +184,7 @@ class IterThread(PTXEntryPoint): op.bra.uni(cp_loop_start, ifp=p_cp_done) comment('Shuffle points between threads') - shuf.shuffle(xi, yi, colori, consec_bad) + shuf.shuffle(x, y, color, consec_bad) with block("If in first warp, pick new offset"): reg.u32('tid') @@ -390,14 +365,14 @@ class PaletteLookup(PTXFragment): mem.global_.texref('t_palette') @ptx_func - def look_up(self, r, g, b, a, color, norm_time): + def look_up(self, r, g, b, a, color, norm_time, ifp): """ Look up the values of ``r, g, b, a`` corresponding to ``color_coord`` at the CP indexed in ``timestamp_idx``. Note that both ``color_coord`` and ``timestamp_idx`` should be [0,1]-normalized floats. """ op.tex._2d.v4.f32.f32(vec(r, g, b, a), - addr([t_palette, ', ', vec(norm_time, color)])) + addr([t_palette, ', ', vec(norm_time, color)]), ifp=ifp) if features.non_box_temporal_filter: raise NotImplementedError("Non-box temporal filters not supported") @@ -462,17 +437,13 @@ class HistScatter(PTXFragment): op.st.volatile.b32(addr(g_hist_dummy), hist_dummy) @ptx_func - def scatter(self, x, y, color, xf_idx, p_valid=None, type='ldst'): + def scatter(self, hist_index, color, xf_idx, p_valid, type='ldst'): """ Scatter the given point directly to the histogram bins. I think this technique has the worst performance of all of 'em. Accesses ``cpA`` directly. """ with block("Scatter directly to buffer"): - if p_valid is None: - p_valid = reg.pred('p_valid') - reg.u32('hist_index') - camera.get_index(hist_index, x, y, p_valid) reg.u32('hist_bin_addr') op.mov.u32(hist_bin_addr, g_hist_bins) op.mad.lo.u32(hist_bin_addr, hist_index, 16, hist_bin_addr) @@ -484,20 +455,22 @@ class HistScatter(PTXFragment): reg.f32('r g b a norm_time') cp.get(cpA, norm_time, 'cp.norm_time') - palette.look_up(r, g, b, a, color, norm_time) + palette.look_up(r, g, b, a, color, norm_time, ifp=p_valid) # TODO: look up, scale by xform visibility # TODO: Make this more performant if type == 'ldst': reg.f32('gr gg gb ga') - op.ld.v4.f32(vec(gr, gg, gb, ga), addr(hist_bin_addr)) + op.ld.v4.f32(vec(gr, gg, gb, ga), addr(hist_bin_addr), + ifp=p_valid) op.add.f32(gr, gr, r) op.add.f32(gg, gg, g) op.add.f32(gb, gb, b) op.add.f32(ga, ga, a) - op.st.v4.f32(addr(hist_bin_addr), vec(gr, gg, gb, ga)) + op.st.v4.f32(addr(hist_bin_addr), vec(gr, gg, gb, ga), + ifp=p_valid) elif type == 'red': for i, val in enumerate([r, g, b, a]): - op.red.add.f32(addr(hist_bin_addr,4*i), val) + op.red.add.f32(addr(hist_bin_addr,4*i), val, ifp=p_valid) elif type == 'fake': op.st.local.u32(addr(l_scatter_fake_adr), hist_bin_addr) op.st.local.f32(addr(l_scatter_fake_alpha), a) diff --git a/cuburn/render.py b/cuburn/render.py index 943b348..27cd00d 100644 --- a/cuburn/render.py +++ b/cuburn/render.py @@ -213,8 +213,11 @@ class Features(object): Determine features and constants required to render a particular set of genomes. The values of this class are fixed before compilation begins. """ - # Constant; number of rounds spent fusing points on first CP of a frame - num_fuse_samples = 25 + # Constant parameters which control handling of out-of-frame samples: + # Number of iterations to iterate without write after new point + fuse = 2 + # Maximum consecutive out-of-frame points before picking new point + max_bad = 3 def __init__(self, genomes, flt): any = lambda l: bool(filter(None, map(l, genomes))) diff --git a/cuburn/variations.py b/cuburn/variations.py index 483c772..6aad6bf 100644 --- a/cuburn/variations.py +++ b/cuburn/variations.py @@ -54,8 +54,8 @@ class Variations(PTXFragment): """ Apply a transform. - This function makes a copy of the input variables, so it's safe to use - the same registers for input and output. + This function necessarily makes a copy of the input variables, so it's + safe to use the same registers for input and output. """ with block("Apply xform %d" % xform_idx): self.xform_idx = xform_idx @@ -123,7 +123,7 @@ class Variations(PTXFragment): @ptx_func def spherical(self, xo, yo, xi, yi, wgt): reg.f32('r2') - op.fma.rn.ftz.f32(r2, xi, xi, '1e-9') + op.fma.rn.ftz.f32(r2, xi, xi, '1e-30') op.fma.rn.ftz.f32(r2, yi, yi, r2) op.rcp.approx.f32(r2, r2) op.mul.rn.ftz.f32(r2, r2, wgt) diff --git a/main.py b/main.py index aadf638..9c55a00 100644 --- a/main.py +++ b/main.py @@ -64,7 +64,7 @@ def main(args): anim.features.hist_height, 'RGBA', bins.tostring(), - anim.features.hist_stride*4) + -anim.features.hist_stride*4) tex = image.texture pal = (anim.ctx.ptx.instances[PaletteLookup].pal * 255.).astype(np.uint8)