A new (somewhat experimental) approach to fusing

This commit is contained in:
Steven Robertson 2010-09-12 23:45:38 -04:00
parent 5a5fcf5bb9
commit e0b218feba
4 changed files with 41 additions and 65 deletions

View File

@ -44,8 +44,6 @@ class IterThread(PTXEntryPoint):
# behaves slightly differently (see ``fuse_loop_start``)
# TODO: replace (or at least simplify) this logic
mem.shared.s32('s_num_samples')
op.st.shared.s32(addr(s_num_samples), -(features.num_fuse_samples+1))
mem.shared.f32('s_xf_sel', ctx.warps_per_cta)
# TODO: temporary, for testing
@ -54,11 +52,11 @@ class IterThread(PTXEntryPoint):
op.st.local.u32(addr(l_num_rounds), 0)
op.st.local.u32(addr(l_num_writes), 0)
reg.f32('xi xo yi yo colori coloro consec_bad')
mwc.next_f32_11(xi)
mwc.next_f32_11(yi)
mwc.next_f32_01(colori)
op.mov.f32(consec_bad, 0.)
reg.f32('x y color consec_bad')
mwc.next_f32_11(x)
mwc.next_f32_11(y)
mwc.next_f32_01(color)
op.mov.f32(consec_bad, float(-features.fuse))
comment("Ensure all init is done")
op.bar.sync(0)
@ -71,13 +69,7 @@ class IterThread(PTXEntryPoint):
std.set_is_first_thread(reg.pred('p_is_first'))
op.atom.add.u32(cp_idx, addr(g_num_cps_started), 1, ifp=p_is_first)
op.st.shared.u32(addr(s_cp_idx), cp_idx, ifp=p_is_first)
with block("If done fusing, reset the sample count now"):
reg.pred("p_done_fusing")
reg.s32('num_samples')
op.ld.shared.s32(num_samples, addr(s_num_samples))
op.setp.gt.s32(p_done_fusing, num_samples, 0)
op.st.shared.s32(addr(s_num_samples), 0, ifp=p_done_fusing)
op.st.shared.s32(addr(s_num_samples), 0)
comment("Load the CP index in all threads")
op.bar.sync(0)
@ -96,14 +88,6 @@ class IterThread(PTXEntryPoint):
label('fuse_loop_start')
# When fusing, num_samples holds the (negative) number of iterations
# left across the CP, rather than the number of samples in total.
with block("If still fusing, increment count unconditionally"):
op.bar.sync(0)
std.set_is_first_thread(reg.pred('p_is_first'))
op.red.shared.add.s32(addr(s_num_samples), 1, ifp=p_is_first)
label('iter_loop_choose_xform')
with block("Choose the xform for each warp"):
timeout.check_time(5)
@ -147,51 +131,42 @@ class IterThread(PTXEntryPoint):
for xf in features.xforms:
label('XFORM_%d' % xf.id)
variations.apply_xform(xo, yo, coloro, xi, yi, colori, xf.id)
variations.apply_xform(x, y, color, x, y, color, xf.id)
op.bra("xform_done")
label("xform_done")
with block("Test if we're still in FUSE"):
reg.s32('num_samples')
reg.pred('p_in_fuse')
op.ld.shared.s32(num_samples, addr(s_num_samples))
op.setp.lt.s32(p_in_fuse, num_samples, 0)
op.bra.uni(fuse_loop_start, ifp=p_in_fuse)
reg.pred('p_point_is_valid')
reg.pred('p_valid_pt')
with block("Write the result"):
hist.scatter(xo, yo, coloro, 0, p_point_is_valid, 'ldst')
reg.u32('hist_index')
camera.get_index(hist_index, x, y, p_valid_pt)
comment('if consec_bad < 0, point is fusing; treat as invalid')
op.setp.and_.ge.f32(p_valid_pt, consec_bad, 0., p_valid_pt)
# TODO: save and pass correct color value here
hist.scatter(hist_index, color, 0, p_valid_pt, 'ldst')
with block():
reg.u32('num_writes')
op.ld.local.u32(num_writes, addr(l_num_writes))
op.add.u32(num_writes, num_writes, 1, ifp=p_point_is_valid)
op.add.u32(num_writes, num_writes, 1, ifp=p_valid_pt)
op.st.local.u32(addr(l_num_writes), num_writes)
with block("If the result was invalid, handle badvals"):
reg.f32('consec')
reg.pred('need_new_point')
comment('If point is good, move new coords and reset consec_bad')
op.mov.f32(xi, xo, ifp=p_point_is_valid)
op.mov.f32(yi, yo, ifp=p_point_is_valid)
op.mov.f32(colori, coloro, ifp=p_point_is_valid)
op.mov.f32(consec_bad, 0., ifp=p_point_is_valid)
comment('Otherwise, add 1 to consec_bad')
op.add.f32(consec, consec, 1., ifnotp=p_point_is_valid)
op.setp.ge.f32(need_new_point, consec, 5.)
op.add.f32(consec_bad, consec_bad, 1., ifnotp=p_valid_pt)
op.setp.ge.f32(need_new_point, consec_bad, float(features.max_bad))
op.bra('badval_done', ifnotp=need_new_point)
comment('If consec_bad > 5, pick a new random point')
mwc.next_f32_11(xi)
mwc.next_f32_11(yi)
mwc.next_f32_01(colori)
op.mov.f32(consec, 0.)
mwc.next_f32_11(x)
mwc.next_f32_11(y)
mwc.next_f32_01(color)
op.mov.f32(consec_bad, float(-features.fuse))
label('badval_done')
with block("Increment number of samples by number of good values"):
reg.b32('good_samples laneid')
reg.pred('p_is_first')
op.vote.ballot.b32(good_samples, p_point_is_valid)
op.vote.ballot.b32(good_samples, p_valid_pt)
op.popc.b32(good_samples, good_samples)
op.mov.u32(laneid, '%laneid')
op.setp.eq.u32(p_is_first, laneid, 0)
@ -209,7 +184,7 @@ class IterThread(PTXEntryPoint):
op.bra.uni(cp_loop_start, ifp=p_cp_done)
comment('Shuffle points between threads')
shuf.shuffle(xi, yi, colori, consec_bad)
shuf.shuffle(x, y, color, consec_bad)
with block("If in first warp, pick new offset"):
reg.u32('tid')
@ -390,14 +365,14 @@ class PaletteLookup(PTXFragment):
mem.global_.texref('t_palette')
@ptx_func
def look_up(self, r, g, b, a, color, norm_time):
def look_up(self, r, g, b, a, color, norm_time, ifp):
"""
Look up the values of ``r, g, b, a`` corresponding to ``color`` at the
normalized CP time given in ``norm_time``. Note that both ``color`` and
``norm_time`` should be [0,1]-normalized floats.
"""
op.tex._2d.v4.f32.f32(vec(r, g, b, a),
addr([t_palette, ', ', vec(norm_time, color)]))
addr([t_palette, ', ', vec(norm_time, color)]), ifp=ifp)
if features.non_box_temporal_filter:
raise NotImplementedError("Non-box temporal filters not supported")
@ -462,17 +437,13 @@ class HistScatter(PTXFragment):
op.st.volatile.b32(addr(g_hist_dummy), hist_dummy)
@ptx_func
def scatter(self, x, y, color, xf_idx, p_valid=None, type='ldst'):
def scatter(self, hist_index, color, xf_idx, p_valid, type='ldst'):
"""
Scatter the given point directly to the histogram bins. I think this
technique has the worst performance of all of 'em. Accesses ``cpA``
directly.
"""
with block("Scatter directly to buffer"):
if p_valid is None:
p_valid = reg.pred('p_valid')
reg.u32('hist_index')
camera.get_index(hist_index, x, y, p_valid)
reg.u32('hist_bin_addr')
op.mov.u32(hist_bin_addr, g_hist_bins)
op.mad.lo.u32(hist_bin_addr, hist_index, 16, hist_bin_addr)
@ -484,20 +455,22 @@ class HistScatter(PTXFragment):
reg.f32('r g b a norm_time')
cp.get(cpA, norm_time, 'cp.norm_time')
palette.look_up(r, g, b, a, color, norm_time)
palette.look_up(r, g, b, a, color, norm_time, ifp=p_valid)
# TODO: look up, scale by xform visibility
# TODO: Make this more performant
if type == 'ldst':
reg.f32('gr gg gb ga')
op.ld.v4.f32(vec(gr, gg, gb, ga), addr(hist_bin_addr))
op.ld.v4.f32(vec(gr, gg, gb, ga), addr(hist_bin_addr),
ifp=p_valid)
op.add.f32(gr, gr, r)
op.add.f32(gg, gg, g)
op.add.f32(gb, gb, b)
op.add.f32(ga, ga, a)
op.st.v4.f32(addr(hist_bin_addr), vec(gr, gg, gb, ga))
op.st.v4.f32(addr(hist_bin_addr), vec(gr, gg, gb, ga),
ifp=p_valid)
elif type == 'red':
for i, val in enumerate([r, g, b, a]):
op.red.add.f32(addr(hist_bin_addr,4*i), val)
op.red.add.f32(addr(hist_bin_addr,4*i), val, ifp=p_valid)
elif type == 'fake':
op.st.local.u32(addr(l_scatter_fake_adr), hist_bin_addr)
op.st.local.f32(addr(l_scatter_fake_alpha), a)

View File

@ -213,8 +213,11 @@ class Features(object):
Determine features and constants required to render a particular set of
genomes. The values of this class are fixed before compilation begins.
"""
# Constant; number of rounds spent fusing points on first CP of a frame
num_fuse_samples = 25
# Constant parameters which control handling of out-of-frame samples:
# Number of iterations to run without writing after picking a new point
fuse = 2
# Maximum consecutive out-of-frame points before picking new point
max_bad = 3
def __init__(self, genomes, flt):
any = lambda l: bool(filter(None, map(l, genomes)))

View File

@ -54,8 +54,8 @@ class Variations(PTXFragment):
"""
Apply a transform.
This function makes a copy of the input variables, so it's safe to use
the same registers for input and output.
This function necessarily makes a copy of the input variables, so it's
safe to use the same registers for input and output.
"""
with block("Apply xform %d" % xform_idx):
self.xform_idx = xform_idx
@ -123,7 +123,7 @@ class Variations(PTXFragment):
@ptx_func
def spherical(self, xo, yo, xi, yi, wgt):
reg.f32('r2')
op.fma.rn.ftz.f32(r2, xi, xi, '1e-9')
op.fma.rn.ftz.f32(r2, xi, xi, '1e-30')
op.fma.rn.ftz.f32(r2, yi, yi, r2)
op.rcp.approx.f32(r2, r2)
op.mul.rn.ftz.f32(r2, r2, wgt)

View File

@ -64,7 +64,7 @@ def main(args):
anim.features.hist_height,
'RGBA',
bins.tostring(),
anim.features.hist_stride*4)
-anim.features.hist_stride*4)
tex = image.texture
pal = (anim.ctx.ptx.instances[PaletteLookup].pal * 255.).astype(np.uint8)