Use variations. This works, but is still fragile.

Steven Robertson 2010-09-11 13:15:36 -04:00
parent 860d7b2fad
commit a5d7c2cc1a
3 changed files with 150 additions and 31 deletions


@@ -38,6 +38,7 @@ class LaunchContext(object):
 self.entry_types = entries
 self.block, self.grid, self.build_tests = block, grid, tests
 self.setup_done = False
+self.stream = cuda.Stream()
 @property
 def threads(self):
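The hunk above gives the launch context its own CUDA stream. For orientation, here is a minimal PyCUDA sketch (not code from this repository) of how such a stream is typically used for asynchronous uploads; note that the commit only creates the stream for now, and the control-point upload further down actually switches to the synchronous memcpy_htod:

    import numpy as np
    import pycuda.autoinit
    import pycuda.driver as cuda

    # Illustrative sketch, not code from this repository: queue an upload on a
    # dedicated stream and wait for it. The host buffer must be page-locked
    # for the copy to be truly asynchronous.
    stream = cuda.Stream()
    host_buf = cuda.pagelocked_empty(1024, np.float32)
    dev_buf = cuda.mem_alloc(host_buf.nbytes)
    cuda.memcpy_htod_async(dev_buf, host_buf, stream)
    stream.synchronize()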


@@ -10,16 +10,18 @@ import pycuda.driver as cuda
 import numpy as np
 from cuburn.ptx import *
+from cuburn.variations import Variations
 class IterThread(PTXEntryPoint):
 entry_name = 'iter_thread'
 entry_params = []
+maxnreg = 16
 def __init__(self):
 self.cps_uploaded = False
 def deps(self):
-return [MWCRNG, CPDataStream, HistScatter]
+return [MWCRNG, CPDataStream, HistScatter, Variations]
 @ptx_func
 def module_setup(self):
@@ -30,24 +32,32 @@ class IterThread(PTXEntryPoint):
 # TODO move into debug statement
 mem.global_.u32('g_num_rounds', ctx.threads)
 mem.global_.u32('g_num_writes', ctx.threads)
+mem.global_.b32('g_whatever', ctx.threads)
 @ptx_func
 def entry(self):
 # For now, we indulge in the luxury of shared memory.
 # Index number of current CP, shared across CTA
 mem.shared.u32('s_cp_idx')
 # Number of samples that have been generated so far in this CTA
 # If this number is negative, we're still fusing points, so this
 # behaves slightly differently (see ``fuse_loop_start``)
-mem.shared.u32('s_num_samples')
-op.st.shared.u32(addr(s_num_samples), -(features.num_fuse_samples+1))
+mem.shared.s32('s_num_samples')
+op.st.shared.s32(addr(s_num_samples), -(features.num_fuse_samples+1))
+mem.shared.f32('s_xf_sel', ctx.warps_per_cta)
+std.store_per_thread(g_whatever, 1234)
 # TODO: temporary, for testing
-reg.u32('num_rounds num_writes')
-op.mov.u32(num_rounds, 0)
-op.mov.u32(num_writes, 0)
+mem.local.u32('l_num_rounds')
+mem.local.u32('l_num_writes')
+op.st.local.u32(addr(l_num_rounds), 0)
+op.st.local.u32(addr(l_num_writes), 0)
+mem.local.f32('l_consec')
+op.st.local.f32(addr(l_consec), 0.)
 reg.f32('x_coord y_coord color_coord')
 mwc.next_f32_11(x_coord)
@@ -57,16 +67,24 @@ class IterThread(PTXEntryPoint):
 comment("Ensure all init is done")
 op.bar.sync(0)
 label('cp_loop_start')
 reg.u32('cp_idx cpA')
 with block("Claim a CP"):
 std.set_is_first_thread(reg.pred('p_is_first'))
 op.atom.add.u32(cp_idx, addr(g_num_cps_started), 1, ifp=p_is_first)
 op.st.shared.u32(addr(s_cp_idx), cp_idx, ifp=p_is_first)
-op.st.shared.u32(addr(s_num_samples), 0, ifp=p_is_first)
+with block("If done fusing, reset the sample count now"):
+reg.pred("p_done_fusing")
+reg.s32('num_samples')
+op.ld.shared.s32(num_samples, addr(s_num_samples))
+op.setp.gt.s32(p_done_fusing, num_samples, 0)
+op.st.shared.s32(addr(s_num_samples), 0, ifp=p_done_fusing)
 comment("Load the CP index in all threads")
-op.bar.sync(1)
+op.bar.sync(0)
 op.ld.shared.u32(cp_idx, addr(s_cp_idx))
 with block("Check to see if this CP is valid (if not, we're done)"):
@@ -80,24 +98,68 @@ class IterThread(PTXEntryPoint):
 op.mov.u32(cpA, g_cp_array)
 op.mad.lo.u32(cpA, cp_idx, cp.stream_size, cpA)
 label('fuse_loop_start')
 # When fusing, num_samples holds the (negative) number of iterations
 # left across the CP, rather than the number of samples in total.
 with block("If still fusing, increment count unconditionally"):
 std.set_is_first_thread(reg.pred('p_is_first'))
 op.red.shared.add.s32(addr(s_num_samples), 1, ifp=p_is_first)
+op.bar.sync(2)
+label('iter_loop_choose_xform')
+with block("Choose the xform for each warp"):
+comment("On subsequent runs, only warp 0 will hit this code")
+reg.u32('x_addr x_offset')
+reg.f32('xf_sel')
+op.mov.u32(x_addr, s_xf_sel)
+op.mov.u32(x_offset, '%tid.x')
+op.and_.b32(x_offset, x_offset, ctx.warps_per_cta-1)
+op.mad.lo.u32(x_addr, x_offset, 4, x_addr)
+mwc.next_f32_01(xf_sel)
+op.st.volatile.shared.f32(addr(x_addr), xf_sel)
 label('iter_loop_start')
-comment('Do... well, most of everything')
-mwc.next_f32_11(x_coord)
-mwc.next_f32_11(y_coord)
-mwc.next_f32_01(color_coord)
+comment("I really didn't want to have to sync each loop, but it seems")
+comment("like the highest-performance strategy right now")
+#op.bar.sync(1)
+with block():
+reg.u32('num_rounds')
+reg.pred('overload')
+op.ld.local.u32(num_rounds, addr(l_num_rounds))
 op.add.u32(num_rounds, num_rounds, 1)
+op.st.local.u32(addr(l_num_rounds), num_rounds)
+with block("Select an xform"):
+reg.f32('xf_sel')
+reg.u32('warp_offset xf_sel_addr')
+op.mov.u32(warp_offset, '%tid.x')
+op.mov.u32(xf_sel_addr, s_xf_sel)
+op.shr.u32(warp_offset, warp_offset, 5)
+op.mad.lo.u32(xf_sel_addr, warp_offset, 4, xf_sel_addr)
+op.ld.volatile.shared.f32(xf_sel, addr(xf_sel_addr))
+reg.f32('xf_density')
+reg.pred('xf_jump')
+for xf in features.xforms:
+cp.get(cpA, xf_density, 'cp.xforms[%d].cweight' % xf.id)
+op.setp.le.f32(xf_jump, xf_sel, xf_density)
+op.bra('XFORM_%d' % xf.id, ifp=xf_jump)
+std.asrt("Reached end of xforms without choosing one")
+for xf in features.xforms:
+label('XFORM_%d' % xf.id)
+variations.apply_xform(x_coord, y_coord, color_coord,
+x_coord, y_coord, color_coord, xf.id)
+op.bra.uni("xform_done")
+label("xform_done")
 with block("Test if we're still in FUSE"):
 reg.s32('num_samples')
 reg.pred('p_in_fuse')
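Taken together, the two new blocks choose one xform per warp rather than per thread: warp 0 writes a uniform draw for each warp into s_xf_sel, and every thread later reads its warp's draw and scans the cumulative weights cp.xforms[i].cweight until one bounds it. A small numpy sketch of that selection rule (the function name and shapes are assumptions, not repository code):

    import numpy as np

    # Illustrative model of per-warp xform selection, not repository code.
    def choose_xform_per_warp(cweights, warps_per_cta, rng=np.random):
        """cweights: cumulative xform weights in [0, 1], one entry per xform."""
        xf_sel = rng.random_sample(warps_per_cta)   # what warp 0 writes to s_xf_sel
        # Each warp takes the first xform whose cumulative weight is >= its draw,
        # which is what the setp.le.f32 / bra chain in the PTX does.
        return np.searchsorted(cweights, xf_sel)

    # Example: three xforms with relative weights 0.5, 0.3, 0.2.
    print(choose_xform_per_warp(np.array([0.5, 0.8, 1.0]), warps_per_cta=8))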
@@ -108,7 +170,26 @@ class IterThread(PTXEntryPoint):
 reg.pred('p_point_is_valid')
 with block("Write the result"):
 hist.scatter(x_coord, y_coord, color_coord, 0, p_point_is_valid)
+with block():
+reg.u32('num_writes')
+op.ld.local.u32(num_writes, addr(l_num_writes))
 op.add.u32(num_writes, num_writes, 1, ifp=p_point_is_valid)
+op.st.local.u32(addr(l_num_writes), num_writes)
+with block("If the result was invalid, handle badvals"):
+reg.f32('consec')
+reg.pred('need_new_point')
+op.ld.local.f32(consec, addr(l_consec))
+op.mov.f32(consec, 0., ifp=p_point_is_valid)
+op.add.f32(consec, consec, 1., ifnotp=p_point_is_valid)
+op.setp.ge.f32(need_new_point, consec, 5.)
+op.bra('badval_done', ifnotp=need_new_point)
+mwc.next_f32_11(x_coord)
+mwc.next_f32_11(y_coord)
+mwc.next_f32_01(color_coord)
+op.mov.f32(consec, 0.)
+label('badval_done')
+op.st.local.f32(addr(l_consec), consec)
 with block("Increment number of samples by number of good values"):
 reg.b32('good_samples laneid')
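The badval block keeps one piece of per-thread local state, l_consec: it is cleared on any valid sample, incremented on an invalid one, and once five invalid samples occur in a row the trajectory is re-seeded from fresh random coordinates. In plain Python the per-thread policy looks roughly like this (illustrative helper, not repository code):

    import random

    # Illustrative per-thread badval policy, not repository code.
    BADVAL_LIMIT = 5.0          # threshold used by op.setp.ge.f32 above

    def handle_badval(point, consec, point_is_valid):
        """Return the (possibly re-seeded) point and the updated run length."""
        if point_is_valid:
            return point, 0.0                       # valid sample: clear the run
        consec += 1.0                               # invalid sample: extend the run
        if consec >= BADVAL_LIMIT:
            # Too many bad values in a row: restart from a fresh random point,
            # mirroring the mwc.next_f32_* calls in the PTX.
            point = (random.uniform(-1, 1), random.uniform(-1, 1), random.random())
            consec = 0.0
        return point, consec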
@@ -125,13 +206,27 @@ class IterThread(PTXEntryPoint):
 reg.s32('num_samples num_samples_needed')
 op.ld.shared.s32(num_samples, addr(s_num_samples))
 cp.get(cpA, num_samples_needed, 'cp.nsamples')
+std.store_per_thread(g_whatever, num_samples_needed)
 op.setp.ge.s32(p_cp_done, num_samples, num_samples_needed)
 op.bra.uni(cp_loop_start, ifp=p_cp_done)
+with block("If first warp, pick new thread offset"):
+reg.u32('warpid')
+reg.pred('first_warp')
+op.mov.u32(warpid, '%tid.x')
+op.shr.b32(warpid, warpid, 5)
+op.setp.eq.u32(first_warp, warpid, 0)
+#std.asrt("Looks like we're not the first warp", notp=first_warp,
+#ret=True)
+op.bra.uni(iter_loop_choose_xform, ifp=first_warp)
 op.bra.uni(iter_loop_start)
 label('all_cps_done')
 # TODO this is for testing, move it to a debug statement
+with block():
+reg.u32('num_rounds num_writes')
+op.ld.local.u32(num_rounds, addr(l_num_rounds))
+op.ld.local.u32(num_writes, addr(l_num_writes))
 std.store_per_thread(g_num_rounds, num_rounds,
 g_num_writes, num_writes)
@@ -139,7 +234,7 @@ class IterThread(PTXEntryPoint):
 def upload_cp_stream(self, ctx, cp_stream, num_cps):
 cp_array_dp, cp_array_l = ctx.mod.get_global('g_cp_array')
 assert len(cp_stream) <= cp_array_l, "Stream too big!"
-cuda.memcpy_htod_async(cp_array_dp, cp_stream)
+cuda.memcpy_htod(cp_array_dp, cp_stream)
 num_cps_dp, num_cps_l = ctx.mod.get_global('g_num_cps')
 cuda.memset_d32(num_cps_dp, num_cps, 1)
@@ -162,14 +257,29 @@ class IterThread(PTXEntryPoint):
 def call_teardown(self, ctx):
 shape = (ctx.grid[0], ctx.block[0]/32, 32)
+def print_thing(s, a):
+print '%s:' % s
+for i, r in enumerate(a):
+for j in range(0,len(r),8):
+print '%2d\t%s' % (i,
+'\t'.join(['%g '%np.mean(r[k]) for k in range(j,j+8)]))
 num_rounds_dp, num_rounds_l = ctx.mod.get_global('g_num_rounds')
 num_writes_dp, num_writes_l = ctx.mod.get_global('g_num_writes')
+whatever_dp, whatever_l = ctx.mod.get_global('g_whatever')
 rounds = cuda.from_device(num_rounds_dp, shape, np.int32)
 writes = cuda.from_device(num_writes_dp, shape, np.int32)
-print "Rounds:", sum(rounds)
-print "Writes:", sum(writes)
-print rounds
-print writes
+whatever = cuda.from_device(whatever_dp, shape, np.int32)
+print_thing("Rounds", rounds)
+print_thing("Writes", writes)
+print_thing("Whatever", whatever)
+print np.sum(rounds)
+dp, l = ctx.mod.get_global('g_num_cps_started')
+cps_started = cuda.from_device(dp, 1, np.uint32)
+print "CPs started:", cps_started
 class CameraTransform(PTXFragment):
 shortname = 'camera'
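print_thing walks the (grid, warps, lanes)-shaped counter arrays and prints the per-warp mean, eight warps per row. For readers outside the Python 2-era codebase, here is a Python 3 rendering of the same helper (illustrative, assumes the same array shape):

    import numpy as np

    # Illustrative Python 3 equivalent of the debug helper above, not repository code.
    def print_thing(name, a):
        """a has shape (num_ctas, warps_per_cta, 32); print per-warp means."""
        print('%s:' % name)
        for i, cta in enumerate(a):
            warp_means = cta.mean(axis=1)            # one mean per warp
            for j in range(0, len(warp_means), 8):   # eight warps per printed row
                row = '\t'.join('%g' % m for m in warp_means[j:j + 8])
                print('%2d\t%s' % (i, row))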
@@ -363,10 +473,13 @@ class HistScatter(PTXFragment):
 cp.get(cpA, norm_time, 'cp.norm_time')
 palette.look_up(r, g, b, a, color, norm_time)
 # TODO: look up, scale by xform visibility
-op.red.add.f32(addr(hist_bin_addr), r)
-op.red.add.f32(addr(hist_bin_addr,4), g)
-op.red.add.f32(addr(hist_bin_addr,8), b)
-op.red.add.f32(addr(hist_bin_addr,12), a)
+# TODO: Make this more performant
+reg.f32('gval')
+for i, val in enumerate([r, g, b, a]):
+#op.red.add.f32(addr(hist_bin_addr,4*i), val)
+op.ld.f32(gval,addr(hist_bin_addr,4*i))
+op.add.f32(gval, gval, val)
+op.st.f32(addr(hist_bin_addr,4*i),gval)
 def call_setup(self, ctx):
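Note that the scatter now replaces the atomic op.red.add.f32 with a plain load/add/store, which can drop contributions when two threads hit the same bin at the same time. On the host side the two strategies correspond roughly to np.add.at versus fancy-indexed +=, as in this sketch (illustrative data, not repository code):

    import numpy as np

    # Illustrative host-side analogue of the two accumulation strategies.
    hist = np.zeros((8, 4), dtype=np.float32)        # 8 bins x RGBA
    bins = np.array([3, 3, 5])                       # two samples hit bin 3
    rgba = np.ones((3, 4), dtype=np.float32)

    # Like op.red.add.f32: every contribution lands, even on duplicate bins.
    np.add.at(hist, bins, rgba)

    # Like the ld/add/st sequence: one of the duplicate contributions to bin 3
    # is lost, since both updates are computed from the same read.
    racy = np.zeros((8, 4), dtype=np.float32)
    racy[bins] += rgba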
@@ -380,6 +493,8 @@ class HistScatter(PTXFragment):
 (features.hist_height, features.hist_stride, 4),
 dtype=np.float32)
 class MWCRNG(PTXFragment):
 shortname = "mwc"

main.py

@@ -11,6 +11,7 @@
 import os
 import sys
+from pprint import pprint
 from ctypes import *
 import numpy as np
@@ -39,10 +40,12 @@ def main(args):
 anim = Animation(genomes)
 anim.compile()
 bins = anim.render_frame()
-#dump_3d(bins)
-bins /= ((np.mean(bins)+1e-9)/128.)
-bins.astype(np.uint8)
+#bins = np.log2(bins + 1)
+bins *= (512./(np.mean([bins[y][x][3]
+for x in range(anim.features.hist_width)
+for y in range(anim.features.hist_height)])+1e-9))
+bins = np.minimum(bins, 255)
+bins = bins.astype(np.uint8)
 if '-g' not in args:
 return
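The new scaling keys brightness to the mean of the alpha (density) channel rather than the mean over all channels, then clamps to the 8-bit range, but it computes that mean with a per-pixel Python list comprehension. A vectorized numpy equivalent, assuming bins is a (hist_height, hist_width, 4) float32 array (illustrative sketch, not part of the commit):

    import numpy as np

    # Illustrative vectorized form of the normalization above, not repository code.
    def normalize_bins(bins, target=512.0):
        mean_density = np.mean(bins[:, :, 3]) + 1e-9     # mean of the alpha channel
        scaled = bins * (target / mean_density)
        return np.minimum(scaled, 255).astype(np.uint8)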