mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 11:40:04 -05:00
A new (somewhat experimental) approach to fusing
This commit is contained in:
parent
5a5fcf5bb9
commit
e0b218feba
@ -44,8 +44,6 @@ class IterThread(PTXEntryPoint):
|
||||
# behaves slightly differently (see ``fuse_loop_start``)
|
||||
# TODO: replace (or at least simplify) this logic
|
||||
mem.shared.s32('s_num_samples')
|
||||
op.st.shared.s32(addr(s_num_samples), -(features.num_fuse_samples+1))
|
||||
|
||||
mem.shared.f32('s_xf_sel', ctx.warps_per_cta)
|
||||
|
||||
# TODO: temporary, for testing
|
||||
@ -54,11 +52,11 @@ class IterThread(PTXEntryPoint):
|
||||
op.st.local.u32(addr(l_num_rounds), 0)
|
||||
op.st.local.u32(addr(l_num_writes), 0)
|
||||
|
||||
reg.f32('xi xo yi yo colori coloro consec_bad')
|
||||
mwc.next_f32_11(xi)
|
||||
mwc.next_f32_11(yi)
|
||||
mwc.next_f32_01(colori)
|
||||
op.mov.f32(consec_bad, 0.)
|
||||
reg.f32('x y color consec_bad')
|
||||
mwc.next_f32_11(x)
|
||||
mwc.next_f32_11(y)
|
||||
mwc.next_f32_01(color)
|
||||
op.mov.f32(consec_bad, float(-features.fuse))
|
||||
|
||||
comment("Ensure all init is done")
|
||||
op.bar.sync(0)
|
||||
@ -71,13 +69,7 @@ class IterThread(PTXEntryPoint):
|
||||
std.set_is_first_thread(reg.pred('p_is_first'))
|
||||
op.atom.add.u32(cp_idx, addr(g_num_cps_started), 1, ifp=p_is_first)
|
||||
op.st.shared.u32(addr(s_cp_idx), cp_idx, ifp=p_is_first)
|
||||
|
||||
with block("If done fusing, reset the sample count now"):
|
||||
reg.pred("p_done_fusing")
|
||||
reg.s32('num_samples')
|
||||
op.ld.shared.s32(num_samples, addr(s_num_samples))
|
||||
op.setp.gt.s32(p_done_fusing, num_samples, 0)
|
||||
op.st.shared.s32(addr(s_num_samples), 0, ifp=p_done_fusing)
|
||||
op.st.shared.s32(addr(s_num_samples), 0)
|
||||
|
||||
comment("Load the CP index in all threads")
|
||||
op.bar.sync(0)
|
||||
@ -96,14 +88,6 @@ class IterThread(PTXEntryPoint):
|
||||
|
||||
|
||||
|
||||
label('fuse_loop_start')
|
||||
# When fusing, num_samples holds the (negative) number of iterations
|
||||
# left across the CP, rather than the number of samples in total.
|
||||
with block("If still fusing, increment count unconditionally"):
|
||||
op.bar.sync(0)
|
||||
std.set_is_first_thread(reg.pred('p_is_first'))
|
||||
op.red.shared.add.s32(addr(s_num_samples), 1, ifp=p_is_first)
|
||||
|
||||
label('iter_loop_choose_xform')
|
||||
with block("Choose the xform for each warp"):
|
||||
timeout.check_time(5)
|
||||
@ -147,51 +131,42 @@ class IterThread(PTXEntryPoint):
|
||||
|
||||
for xf in features.xforms:
|
||||
label('XFORM_%d' % xf.id)
|
||||
variations.apply_xform(xo, yo, coloro, xi, yi, colori, xf.id)
|
||||
variations.apply_xform(x, y, color, x, y, color, xf.id)
|
||||
op.bra("xform_done")
|
||||
|
||||
label("xform_done")
|
||||
with block("Test if we're still in FUSE"):
|
||||
reg.s32('num_samples')
|
||||
reg.pred('p_in_fuse')
|
||||
op.ld.shared.s32(num_samples, addr(s_num_samples))
|
||||
op.setp.lt.s32(p_in_fuse, num_samples, 0)
|
||||
op.bra.uni(fuse_loop_start, ifp=p_in_fuse)
|
||||
|
||||
reg.pred('p_point_is_valid')
|
||||
reg.pred('p_valid_pt')
|
||||
with block("Write the result"):
|
||||
hist.scatter(xo, yo, coloro, 0, p_point_is_valid, 'ldst')
|
||||
reg.u32('hist_index')
|
||||
camera.get_index(hist_index, x, y, p_valid_pt)
|
||||
comment('if consec_bad < 0, point is fusing; treat as invalid')
|
||||
op.setp.and_.ge.f32(p_valid_pt, consec_bad, 0., p_valid_pt)
|
||||
# TODO: save and pass correct color value here
|
||||
hist.scatter(hist_index, color, 0, p_valid_pt, 'ldst')
|
||||
with block():
|
||||
reg.u32('num_writes')
|
||||
op.ld.local.u32(num_writes, addr(l_num_writes))
|
||||
op.add.u32(num_writes, num_writes, 1, ifp=p_point_is_valid)
|
||||
op.add.u32(num_writes, num_writes, 1, ifp=p_valid_pt)
|
||||
op.st.local.u32(addr(l_num_writes), num_writes)
|
||||
|
||||
with block("If the result was invalid, handle badvals"):
|
||||
reg.f32('consec')
|
||||
reg.pred('need_new_point')
|
||||
comment('If point is good, move new coords and reset consec_bad')
|
||||
op.mov.f32(xi, xo, ifp=p_point_is_valid)
|
||||
op.mov.f32(yi, yo, ifp=p_point_is_valid)
|
||||
op.mov.f32(colori, coloro, ifp=p_point_is_valid)
|
||||
op.mov.f32(consec_bad, 0., ifp=p_point_is_valid)
|
||||
|
||||
comment('Otherwise, add 1 to consec_bad')
|
||||
op.add.f32(consec, consec, 1., ifnotp=p_point_is_valid)
|
||||
op.setp.ge.f32(need_new_point, consec, 5.)
|
||||
op.add.f32(consec_bad, consec_bad, 1., ifnotp=p_valid_pt)
|
||||
op.setp.ge.f32(need_new_point, consec_bad, float(features.max_bad))
|
||||
op.bra('badval_done', ifnotp=need_new_point)
|
||||
|
||||
comment('If consec_bad > 5, pick a new random point')
|
||||
mwc.next_f32_11(xi)
|
||||
mwc.next_f32_11(yi)
|
||||
mwc.next_f32_01(colori)
|
||||
op.mov.f32(consec, 0.)
|
||||
mwc.next_f32_11(x)
|
||||
mwc.next_f32_11(y)
|
||||
mwc.next_f32_01(color)
|
||||
op.mov.f32(consec_bad, float(-features.fuse))
|
||||
label('badval_done')
|
||||
|
||||
with block("Increment number of samples by number of good values"):
|
||||
reg.b32('good_samples laneid')
|
||||
reg.pred('p_is_first')
|
||||
op.vote.ballot.b32(good_samples, p_point_is_valid)
|
||||
op.vote.ballot.b32(good_samples, p_valid_pt)
|
||||
op.popc.b32(good_samples, good_samples)
|
||||
op.mov.u32(laneid, '%laneid')
|
||||
op.setp.eq.u32(p_is_first, laneid, 0)
|
||||
@ -209,7 +184,7 @@ class IterThread(PTXEntryPoint):
|
||||
op.bra.uni(cp_loop_start, ifp=p_cp_done)
|
||||
|
||||
comment('Shuffle points between threads')
|
||||
shuf.shuffle(xi, yi, colori, consec_bad)
|
||||
shuf.shuffle(x, y, color, consec_bad)
|
||||
|
||||
with block("If in first warp, pick new offset"):
|
||||
reg.u32('tid')
|
||||
@ -390,14 +365,14 @@ class PaletteLookup(PTXFragment):
|
||||
mem.global_.texref('t_palette')
|
||||
|
||||
@ptx_func
|
||||
def look_up(self, r, g, b, a, color, norm_time):
|
||||
def look_up(self, r, g, b, a, color, norm_time, ifp):
|
||||
"""
|
||||
Look up the values of ``r, g, b, a`` corresponding to ``color``
|
||||
at the CP indexed in ``norm_time``. Note that both ``color``
|
||||
and ``norm_time`` should be [0,1]-normalized floats.
|
||||
"""
|
||||
op.tex._2d.v4.f32.f32(vec(r, g, b, a),
|
||||
addr([t_palette, ', ', vec(norm_time, color)]))
|
||||
addr([t_palette, ', ', vec(norm_time, color)]), ifp=ifp)
|
||||
if features.non_box_temporal_filter:
|
||||
raise NotImplementedError("Non-box temporal filters not supported")
|
||||
|
||||
@ -462,17 +437,13 @@ class HistScatter(PTXFragment):
|
||||
op.st.volatile.b32(addr(g_hist_dummy), hist_dummy)
|
||||
|
||||
@ptx_func
|
||||
def scatter(self, x, y, color, xf_idx, p_valid=None, type='ldst'):
|
||||
def scatter(self, hist_index, color, xf_idx, p_valid, type='ldst'):
|
||||
"""
|
||||
Scatter the given point directly to the histogram bins. I think this
|
||||
technique has the worst performance of all of 'em. Accesses ``cpA``
|
||||
directly.
|
||||
"""
|
||||
with block("Scatter directly to buffer"):
|
||||
if p_valid is None:
|
||||
p_valid = reg.pred('p_valid')
|
||||
reg.u32('hist_index')
|
||||
camera.get_index(hist_index, x, y, p_valid)
|
||||
reg.u32('hist_bin_addr')
|
||||
op.mov.u32(hist_bin_addr, g_hist_bins)
|
||||
op.mad.lo.u32(hist_bin_addr, hist_index, 16, hist_bin_addr)
|
||||
@ -484,20 +455,22 @@ class HistScatter(PTXFragment):
|
||||
|
||||
reg.f32('r g b a norm_time')
|
||||
cp.get(cpA, norm_time, 'cp.norm_time')
|
||||
palette.look_up(r, g, b, a, color, norm_time)
|
||||
palette.look_up(r, g, b, a, color, norm_time, ifp=p_valid)
|
||||
# TODO: look up, scale by xform visibility
|
||||
# TODO: Make this more performant
|
||||
if type == 'ldst':
|
||||
reg.f32('gr gg gb ga')
|
||||
op.ld.v4.f32(vec(gr, gg, gb, ga), addr(hist_bin_addr))
|
||||
op.ld.v4.f32(vec(gr, gg, gb, ga), addr(hist_bin_addr),
|
||||
ifp=p_valid)
|
||||
op.add.f32(gr, gr, r)
|
||||
op.add.f32(gg, gg, g)
|
||||
op.add.f32(gb, gb, b)
|
||||
op.add.f32(ga, ga, a)
|
||||
op.st.v4.f32(addr(hist_bin_addr), vec(gr, gg, gb, ga))
|
||||
op.st.v4.f32(addr(hist_bin_addr), vec(gr, gg, gb, ga),
|
||||
ifp=p_valid)
|
||||
elif type == 'red':
|
||||
for i, val in enumerate([r, g, b, a]):
|
||||
op.red.add.f32(addr(hist_bin_addr,4*i), val)
|
||||
op.red.add.f32(addr(hist_bin_addr,4*i), val, ifp=p_valid)
|
||||
elif type == 'fake':
|
||||
op.st.local.u32(addr(l_scatter_fake_adr), hist_bin_addr)
|
||||
op.st.local.f32(addr(l_scatter_fake_alpha), a)
|
||||
|
@ -213,8 +213,11 @@ class Features(object):
|
||||
Determine features and constants required to render a particular set of
|
||||
genomes. The values of this class are fixed before compilation begins.
|
||||
"""
|
||||
# Constant; number of rounds spent fusing points on first CP of a frame
|
||||
num_fuse_samples = 25
|
||||
# Constant parameters which control handling of out-of-frame samples:
|
||||
# Number of iterations to run without writing after picking a new point
|
||||
fuse = 2
|
||||
# Maximum consecutive out-of-frame points before picking new point
|
||||
max_bad = 3
|
||||
|
||||
def __init__(self, genomes, flt):
|
||||
any = lambda l: bool(filter(None, map(l, genomes)))
|
||||
|
@ -54,8 +54,8 @@ class Variations(PTXFragment):
|
||||
"""
|
||||
Apply a transform.
|
||||
|
||||
This function makes a copy of the input variables, so it's safe to use
|
||||
the same registers for input and output.
|
||||
This function necessarily makes a copy of the input variables, so it's
|
||||
safe to use the same registers for input and output.
|
||||
"""
|
||||
with block("Apply xform %d" % xform_idx):
|
||||
self.xform_idx = xform_idx
|
||||
@ -123,7 +123,7 @@ class Variations(PTXFragment):
|
||||
@ptx_func
|
||||
def spherical(self, xo, yo, xi, yi, wgt):
|
||||
reg.f32('r2')
|
||||
op.fma.rn.ftz.f32(r2, xi, xi, '1e-9')
|
||||
op.fma.rn.ftz.f32(r2, xi, xi, '1e-30')
|
||||
op.fma.rn.ftz.f32(r2, yi, yi, r2)
|
||||
op.rcp.approx.f32(r2, r2)
|
||||
op.mul.rn.ftz.f32(r2, r2, wgt)
|
||||
|
Loading…
Reference in New Issue
Block a user