# cuburn/cuburn/device_code.py
"""
Contains the PTX fragments which will drive the device.
"""
import os
import time
import struct
import pycuda.driver as cuda
import numpy as np
from pyptx import ptx, run
from cuburn.variations import Variations
class IterThread(object):
entry_name = 'iter_thread'
entry_params = []
def __init__(self):
self.cps_uploaded = False
def deps(self):
return [MWCRNG, CPDataStream, HistScatter, Variations, ShufflePoints,
Timeouter]
def module_setup(self):
mem.global_.u32('g_cp_array',
cp.stream_size*features.max_ntemporal_samples)
mem.global_.u32('g_num_cps')
mem.global_.u32('g_num_cps_started')
# TODO move into debug statement
mem.global_.u32('g_num_rounds', ctx.nthreads)
mem.global_.u32('g_num_writes', ctx.nthreads)
mem.global_.b32('g_whatever', ctx.nthreads)
def entry(self):
# Index number of current CP, shared across CTA
mem.shared.u32('s_cp_idx')
        # Number of samples that have been generated so far in this CTA.
        # Fusing is tracked per-point by ``consec_bad``, which starts negative
        # and marks the point invalid until it climbs back to zero (see the
        # badval handling below).
        # TODO: replace (or at least simplify) this logic
mem.shared.s32('s_num_samples')
mem.shared.f32('s_xf_sel', ctx.warps_per_cta)
# TODO: temporary, for testing
mem.local.u32('l_num_rounds')
mem.local.u32('l_num_writes')
op.st.local.u32(addr(l_num_rounds), 0)
op.st.local.u32(addr(l_num_writes), 0)
reg.f32('x y color consec_bad')
mwc.next_f32_11(x)
mwc.next_f32_11(y)
mwc.next_f32_01(color)
op.mov.f32(consec_bad, float(-features.fuse))
comment("Ensure all init is done")
op.bar.sync(0)
label('cp_loop_start')
reg.u32('cp_idx cpA')
with block("Claim a CP"):
std.set_is_first_thread(reg.pred('p_is_first'))
op.atom.add.u32(cp_idx, addr(g_num_cps_started), 1, ifp=p_is_first)
op.st.volatile.shared.u32(addr(s_cp_idx), cp_idx, ifp=p_is_first)
op.st.volatile.shared.s32(addr(s_num_samples), 0)
comment("Load the CP index in all threads")
op.bar.sync(0)
op.ld.volatile.shared.u32(cp_idx, addr(s_cp_idx))
with block("Check to see if this CP is valid (if not, we're done)"):
reg.u32('num_cps')
reg.pred('p_last_cp')
op.ldu.u32(num_cps, addr(g_num_cps))
op.setp.ge.u32(p_last_cp, cp_idx, num_cps)
op.bra('all_cps_done', ifp=p_last_cp)
with block('Load CP address'):
op.mov.u32(cpA, g_cp_array)
op.mad.lo.u32(cpA, cp_idx, cp.stream_size, cpA)
label('iter_loop_choose_xform')
with block("Choose the xform for each warp"):
timeout.check_time(5)
comment("On subsequent runs, only warp 0 will hit this code")
reg.u32('x_addr x_offset')
reg.f32('xf_sel')
op.mov.u32(x_addr, s_xf_sel)
op.mov.u32(x_offset, '%tid.x')
op.and_.b32(x_offset, x_offset, ctx.warps_per_cta-1)
op.mad.lo.u32(x_addr, x_offset, 4, x_addr)
mwc.next_f32_01(xf_sel)
op.st.volatile.shared.f32(addr(x_addr), xf_sel)
label('iter_loop_start')
#timeout.check_time(10)
with block():
reg.u32('num_rounds')
reg.pred('overload')
op.ld.local.u32(num_rounds, addr(l_num_rounds))
op.add.u32(num_rounds, num_rounds, 1)
op.st.local.u32(addr(l_num_rounds), num_rounds)
with block("Select an xform"):
reg.f32('xf_sel')
reg.u32('warp_offset xf_sel_addr')
op.mov.u32(warp_offset, '%tid.x')
op.mov.u32(xf_sel_addr, s_xf_sel)
op.shr.u32(warp_offset, warp_offset, 5)
op.mad.lo.u32(xf_sel_addr, warp_offset, 4, xf_sel_addr)
op.ld.volatile.shared.f32(xf_sel, addr(xf_sel_addr))
reg.f32('xf_density')
reg.pred('xf_jump')
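            # cp.xforms[i].cweight holds the running cumulative weight, so the
            # loop below selects the first xform whose cweight >= xf_sel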
for xf in features.xforms:
cp.get(cpA, xf_density, 'cp.xforms[%d].cweight' % xf.id)
op.setp.le.f32(xf_jump, xf_sel, xf_density)
op.bra('XFORM_%d' % xf.id, ifp=xf_jump)
std.asrt("Reached end of xforms without choosing one")
for xf in features.xforms:
label('XFORM_%d' % xf.id)
variations.apply_xform(x, y, color, x, y, color, xf.id)
op.bra("xform_done")
label("xform_done")
reg.pred('p_valid_pt')
with block("Write the result"):
reg.u32('hist_index')
camera.get_index(hist_index, x, y, p_valid_pt)
comment('if consec_bad < 0, point is fusing; treat as invalid')
op.setp.and_.ge.f32(p_valid_pt, consec_bad, 0., p_valid_pt)
# TODO: save and pass correct xform value here
hist.scatter(hist_index, color, 0, p_valid_pt, 'ldst')
with block():
reg.u32('num_writes')
op.ld.local.u32(num_writes, addr(l_num_writes))
op.add.u32(num_writes, num_writes, 1, ifp=p_valid_pt)
op.st.local.u32(addr(l_num_writes), num_writes)
with block("If the result was invalid, handle badvals"):
reg.pred('need_new_point')
op.add.f32(consec_bad, consec_bad, 1., ifnotp=p_valid_pt)
op.setp.ge.f32(need_new_point, consec_bad, float(features.max_bad))
op.bra('badval_done', ifnotp=need_new_point)
            comment('If consec_bad exceeds features.max_bad, pick a new random point')
mwc.next_f32_11(x)
mwc.next_f32_11(y)
mwc.next_f32_01(color)
op.mov.f32(consec_bad, float(-features.fuse))
label('badval_done')
with block("Increment number of samples by number of good values"):
reg.b32('good_samples laneid')
reg.pred('p_is_first')
op.vote.ballot.b32(good_samples, p_valid_pt)
op.popc.b32(good_samples, good_samples)
op.mov.u32(laneid, '%laneid')
op.setp.eq.u32(p_is_first, laneid, 0)
op.red.shared.add.s32(addr(s_num_samples), good_samples,
ifp=p_is_first)
with block("Check to see if we're done with this CP"):
reg.pred('p_cp_done')
reg.s32('num_samples num_samples_needed')
comment('Sync before making decision to prevent divergence')
op.bar.sync(3)
op.ld.volatile.shared.s32(num_samples, addr(s_num_samples))
cp.get(cpA, num_samples_needed, 'cp.nsamples')
op.setp.ge.s32(p_cp_done, num_samples, num_samples_needed)
op.bra.uni(cp_loop_start, ifp=p_cp_done)
comment('Shuffle points between threads')
shuf.shuffle(x, y, color, consec_bad)
with block("If in first warp, pick new offset"):
reg.u32('tid')
reg.pred('first_warp')
op.mov.u32(tid, '%tid.x')
assert ctx.warps_per_cta <= 32, \
"Special-case for CTAs with >1024 threads not implemented"
op.setp.lo.u32(first_warp, tid, 32)
op.bra(iter_loop_choose_xform, ifp=first_warp)
op.bra(iter_loop_start)
label('all_cps_done')
# TODO this is for testing, move it to a debug statement
with block():
reg.u32('num_rounds num_writes')
op.ld.local.u32(num_rounds, addr(l_num_rounds))
op.ld.local.u32(num_writes, addr(l_num_writes))
std.store_per_thread(g_num_rounds, num_rounds,
g_num_writes, num_writes)
def upload_cp_stream(self, ctx, cp_stream, num_cps):
cp_array_dp, cp_array_l = ctx.mod.get_global('g_cp_array')
assert len(cp_stream) <= cp_array_l, "Stream too big!"
cuda.memcpy_htod(cp_array_dp, cp_stream)
num_cps_dp, num_cps_l = ctx.mod.get_global('g_num_cps')
cuda.memset_d32(num_cps_dp, num_cps, 1)
# TODO: "if debug >= 3"
print "Uploaded stream to card:"
CPDataStream.print_record(ctx, cp_stream, 5)
self.cps_uploaded = True
def call_setup(self, ctx):
if not self.cps_uploaded:
            raise RuntimeError("Cannot call IterThread before uploading CPs")
num_cps_st_dp, num_cps_st_l = ctx.mod.get_global('g_num_cps_started')
cuda.memset_d32(num_cps_st_dp, 0, 1)
def _call(self, ctx, func):
# Get texture reference from the Palette
# TODO: more elegant method than reaching into ctx.ptx?
tr = ctx.ptx.instances[PaletteLookup].texref
super(IterThread, self)._call(ctx, func, texrefs=[tr])
def call_teardown(self, ctx):
def print_thing(s, a):
print '%s:' % s
for i, r in enumerate(a):
for j in range(0,len(r),ctx.warps_per_cta):
print '%2d' % i,
for k in range(j,j+ctx.warps_per_cta,8):
print '\t' + ' '.join(
['%8g'%np.mean(r[l]) for l in range(k,k+8)])
rounds = ctx.get_per_thread('g_num_rounds', np.int32, shaped=True)
writes = ctx.get_per_thread('g_num_writes', np.int32, shaped=True)
print_thing("Rounds", rounds)
print_thing("Writes", writes)
print "Total number of rounds:", np.sum(rounds)
dp, l = ctx.mod.get_global('g_num_cps_started')
cps_started = cuda.from_device(dp, 1, np.uint32)
print "CPs started:", cps_started
class CameraTransform(object):
shortname = 'camera'
def deps(self):
return [CPDataStream]
def rotate(self, rotated_x, rotated_y, x, y):
"""
Rotate an IFS-space coordinate as defined by the camera.
"""
if features.camera_rotation:
assert rotated_x.name != x.name and rotated_y.name != y.name
with block("Rotate %s, %s to camera alignment" % (x, y)):
reg.f32('rot_center_x rot_center_y')
cp.get_v2(cpA, rot_center_x, 'cp.rot_center[0]',
rot_center_y, 'cp.rot_center[1]')
op.sub.f32(x, x, rot_center_x)
op.sub.f32(y, y, rot_center_y)
reg.f32('rot_sin_t rot_cos_t rot_old_x rot_old_y')
cp.get_v2(cpA, rot_cos_t, 'cos(cp.rotate * 2 * pi / 360.)',
rot_sin_t, '-sin(cp.rotate * 2 * pi / 360.)')
comment('rotated_x = x * cos(t) - y * sin(t) + rot_center_x')
op.fma.rn.f32(rotated_x, x, rot_cos_t, rot_center_x)
op.fma.rn.f32(rotated_x, y, rot_sin_t, rotated_x)
op.neg.f32(rot_sin_t, rot_sin_t)
comment('rotated_y = x * sin(t) + y * cos(t) + rot_center_y')
op.fma.rn.f32(rotated_y, x, rot_sin_t, rot_center_y)
op.fma.rn.f32(rotated_y, y, rot_cos_t, rotated_y)
# TODO: if this is a register-critical section, reloading
# rot_center_[xy] here should save two regs. OTOH, if this is
# *not* reg-crit, moving the subtraction above to new variables
# may save a few clocks
op.add.f32(x, x, rot_center_x)
op.add.f32(y, y, rot_center_y)
else:
comment("No camera rotation in this kernel")
op.mov.f32(rotated_x, x)
op.mov.f32(rotated_y, y)
def get_norm(self, norm_x, norm_y, x, y):
"""
Find the [0,1]-normalized floating-point histogram coordinates
``norm_x, norm_y`` from the given IFS-space coordinates ``x, y``.
"""
self.rotate(norm_x, norm_y, x, y)
with block("Scale rotated points to [0,1]-normalized coordinates"):
reg.f32('cam_scale cam_offset')
cp.get_v2(cpA, cam_scale, 'cp.camera.norm_scale[0]',
cam_offset, 'cp.camera.norm_offset[0]')
op.fma.f32(norm_x, norm_x, cam_scale, cam_offset)
cp.get_v2(cpA, cam_scale, 'cp.camera.norm_scale[1]',
cam_offset, 'cp.camera.norm_offset[1]')
op.fma.f32(norm_y, norm_y, cam_scale, cam_offset)
def get_index(self, index, x, y, pred=None):
"""
Find the histogram index (as a u32) from the IFS spatial coordinate in
``x, y``.
If the coordinates are out of bounds, 0xffffffff will be stored to
``index``. If ``pred`` is given, it will be set if the point is valid,
and cleared if not.
"""
# A few instructions could probably be shaved off of this one
with block("Find histogram index"):
reg.f32('norm_x norm_y')
self.rotate(norm_x, norm_y, x, y)
comment('Scale and offset from IFS to index coordinates')
reg.f32('cam_scale cam_offset')
cp.get_v2(cpA, cam_scale, 'cp.camera.idx_scale[0]',
cam_offset, 'cp.camera.idx_offset[0]')
op.fma.rn.f32(norm_x, norm_x, cam_scale, cam_offset)
cp.get_v2(cpA, cam_scale, 'cp.camera.idx_scale[1]',
cam_offset, 'cp.camera.idx_offset[1]')
op.fma.rn.f32(norm_y, norm_y, cam_scale, cam_offset)
comment('Check for bad value')
reg.u32('index_x index_y')
if not pred:
pred = reg.pred('p_valid')
op.cvt.rzi.s32.f32(index_x, norm_x)
op.setp.ge.s32(pred, index_x, 0)
op.setp.lt.and_.s32(pred, index_x, features.hist_width, pred)
op.cvt.rzi.s32.f32(index_y, norm_y)
op.setp.ge.and_.s32(pred, index_y, 0, pred)
op.setp.lt.and_.s32(pred, index_y, features.hist_height, pred)
op.mad.lo.u32(index, index_y, features.hist_stride, index_x)
op.mov.u32(index, 0xffffffff, ifnotp=pred)
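
# Host-side reference for the index math in CameraTransform.get_index, handy
# when sanity-checking kernel output. A sketch only (it never runs on the
# device): camera rotation is omitted, and scale, offset, width, height and
# stride stand in for the cp.camera.idx_* values and features.hist_*
# constants loaded above.
def _host_hist_index(x, y, scale, offset, width, height, stride):
    ix = int(x * scale[0] + offset[0])  # cvt.rzi truncates toward zero
    iy = int(y * scale[1] + offset[1])
    if 0 <= ix < width and 0 <= iy < height:
        return iy * stride + ix
    return 0xffffffff  # out-of-bounds marker, as stored by get_index
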
class PaletteLookup(object):
shortname = "palette"
# Resolution of texture on device. Bigger = more palette rez, maybe slower
texheight = 16
def __init__(self):
self.texref = None
def deps(self):
return [CPDataStream]
def module_setup(self):
mem.global_.texref('t_palette')
def look_up(self, r, g, b, a, color, norm_time, ifp):
"""
Look up the values of ``r, g, b, a`` corresponding to ``color_coord``
at the CP indexed in ``timestamp_idx``. Note that both ``color_coord``
and ``timestamp_idx`` should be [0,1]-normalized floats.
"""
op.tex._2d.v4.f32.f32(vec(r, g, b, a),
addr([t_palette, ', ', vec(norm_time, color)]), ifp=ifp)
if features.non_box_temporal_filter:
raise NotImplementedError("Non-box temporal filters not supported")
def upload_palette(self, ctx, frame, cp_list):
"""
Extract the palette from the given list of interpolated CPs, and upload
it to the device as a texture.
"""
# TODO: figure out if storing the full list is an actual drag on
# performance/memory
if frame.center_cp.temporal_filter_type != 0:
# TODO: make texture sample based on time, not on CP index
raise NotImplementedError("Use box temporal filters for now")
pal = np.ndarray((self.texheight, 256, 4), dtype=np.float32)
inv = float(len(cp_list) - 1) / (self.texheight - 1)
for y in range(self.texheight):
for x in range(256):
for c in range(4):
# TODO: interpolate here?
cy = int(round(y * inv))
pal[y][x][c] = cp_list[cy].palette.entries[x].color[c]
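        # A possible linear-interpolation variant for the TODO above (a
        # sketch only, not wired in): blend the two CPs bracketing each
        # row's time instead of rounding to the nearest one, e.g.
        #   yf = y * inv
        #   lo = int(np.floor(yf)); hi = min(lo + 1, len(cp_list) - 1)
        #   frac = yf - lo
        #   entry = (1 - frac) * palette_lo[x][c] + frac * palette_hi[x][c]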
dev_array = cuda.make_multichannel_2d_array(pal, "C")
self.texref = ctx.mod.get_texref('t_palette')
# TODO: float16? or can we still use interp with int storage?
self.texref.set_format(cuda.array_format.FLOAT, 4)
self.texref.set_flags(cuda.TRSF_NORMALIZED_COORDINATES)
self.texref.set_filter_mode(cuda.filter_mode.LINEAR)
self.texref.set_address_mode(0, cuda.address_mode.CLAMP)
self.texref.set_address_mode(1, cuda.address_mode.CLAMP)
self.texref.set_array(dev_array)
self.pal = pal
def call_setup(self, ctx):
assert self.texref, "Must upload palette texture before launch!"
class HistScatter(object):
shortname = "hist"
def deps(self):
return [CPDataStream, CameraTransform, PaletteLookup]
def module_setup(self):
mem.global_.f32('g_hist_bins',
features.hist_height * features.hist_stride * 4)
comment("Target to ensure fake local values get written")
mem.global_.f32('g_hist_dummy')
def entry_setup(self):
comment("Fake bins for fake scatter")
mem.local.f32('l_scatter_fake_adr')
mem.local.f32('l_scatter_fake_alpha')
def entry_teardown(self):
with block("Store fake histogram bins to dummy global"):
reg.b32('hist_dummy')
op.ld.local.b32(hist_dummy, addr(l_scatter_fake_adr))
op.st.volatile.b32(addr(g_hist_dummy), hist_dummy)
op.ld.local.b32(hist_dummy, addr(l_scatter_fake_alpha))
op.st.volatile.b32(addr(g_hist_dummy), hist_dummy)
def scatter(self, hist_index, color, xf_idx, p_valid, type='ldst'):
"""
Scatter the given point directly to the histogram bins. I think this
technique has the worst performance of all of 'em. Accesses ``cpA``
directly.
"""
with block("Scatter directly to buffer"):
reg.u32('hist_bin_addr')
op.mov.u32(hist_bin_addr, g_hist_bins)
op.mad.lo.u32(hist_bin_addr, hist_index, 16, hist_bin_addr)
if type == 'fake_notex':
op.st.local.u32(addr(l_scatter_fake_adr), hist_bin_addr)
op.st.local.f32(addr(l_scatter_fake_alpha), color)
return
reg.f32('r g b a norm_time')
cp.get(cpA, norm_time, 'cp.norm_time')
palette.look_up(r, g, b, a, color, norm_time, ifp=p_valid)
# TODO: look up, scale by xform visibility
# TODO: Make this more performant
if type == 'ldst':
reg.f32('gr gg gb ga')
op.ld.v4.f32(vec(gr, gg, gb, ga), addr(hist_bin_addr),
ifp=p_valid)
op.add.f32(gr, gr, r)
op.add.f32(gg, gg, g)
op.add.f32(gb, gb, b)
op.add.f32(ga, ga, a)
op.st.v4.f32(addr(hist_bin_addr), vec(gr, gg, gb, ga),
ifp=p_valid)
elif type == 'red':
for i, val in enumerate([r, g, b, a]):
op.red.add.f32(addr(hist_bin_addr,4*i), val, ifp=p_valid)
elif type == 'fake':
op.st.local.u32(addr(l_scatter_fake_adr), hist_bin_addr)
op.st.local.f32(addr(l_scatter_fake_alpha), a)
def call_setup(self, ctx):
hist_bins_dp, hist_bins_l = ctx.mod.get_global('g_hist_bins')
cuda.memset_d32(hist_bins_dp, 0, hist_bins_l/4)
def get_bins(self, ctx, features):
hist_bins_dp, hist_bins_l = ctx.mod.get_global('g_hist_bins')
return cuda.from_device(hist_bins_dp,
(features.hist_height, features.hist_stride, 4),
dtype=np.float32)
class ShufflePoints(object):
"""
Shuffle points in shared memory. See helpers/shuf.py for details.
"""
shortname = "shuf"
def module_setup(self):
# TODO: if needed, merge this shared memory block with others
mem.shared.f32('s_shuf_data', ctx.threads_per_cta)
def shuffle(self, *args, **kwargs):
"""
Shuffle the data from each register in args across threads. Keyword
argument ``bar`` specifies which barrier to use (default is 2).
"""
bar = kwargs.pop('bar', 2)
with block("Shuffle across threads"):
reg.u32('shuf_read shuf_write')
with block("Calculate read and write offsets"):
reg.u32('shuf_off shuf_laneid')
op.mov.u32(shuf_off, '%tid.x')
op.mov.u32(shuf_write, s_shuf_data)
op.mad.lo.u32(shuf_write, shuf_off, 4, shuf_write)
op.mov.u32(shuf_laneid, '%laneid')
op.mad.lo.u32(shuf_off, shuf_laneid, 32, shuf_off)
op.and_.b32(shuf_off, shuf_off, ctx.threads_per_cta - 1)
op.mov.u32(shuf_read, s_shuf_data)
op.mad.lo.u32(shuf_read, shuf_off, 4, shuf_read)
for var in args:
op.bar.sync(bar)
op.st.volatile.shared.b32(addr(shuf_write), var)
op.bar.sync(bar)
op.ld.volatile.shared.b32(var, addr(shuf_read))
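
# Host-side model of the shuffle above: thread t writes its value at slot t
# and reads back from slot (laneid*32 + t) mod N, where laneid = t % 32 and
# N = threads_per_cta (a power of two, as the masking above requires). A
# sketch for intuition only; it is not used by the kernel.
def _host_shuffle(vals):
    n = len(vals)
    return [vals[((t % 32) * 32 + t) & (n - 1)] for t in range(n)]
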
class MWCRNG(object):
"""
Marsaglia multiply-with-carry random number generator. Produces very long
periods with sufficient statistical properties using only three 32-bit
state registers. Since each thread uses a separate multiplier, no two
threads will ever be on the same sequence, but beyond this the independence
of each thread's sequence was not explicitly tested.
The RNG must be seeded at least once per entry point using the ``seed``
method.
"""
def __init__(self, entry):
# TODO: install this in data directory or something
if not os.path.isfile('primes.bin'):
raise EnvironmentError('primes.bin not found')
self.nthreads_ready = 0
self.mults, self.state = None, None
entry.add_ptr_param('mwc_mults', 'u32')
entry.add_ptr_param('mwc_states', 'u32')
with entry.head():
self.entry_head(entry)
entry.tail_callback(self.entry_tail, entry)
def entry_head(self, entry):
e, r, o, m, p, s = entry.locals
gtid = s.ctaid_x * s.ntid_x + s.tid_x
r.mwc_mult, r.mwc_state, r.mwc_carry = r.u32(), r.u32(), r.u32()
r.mwc_mult = o.ld(p.mwc_mults[gtid])
r.mwc_state, r.mwc_carry = o.ld.v2(p.mwc_states[2*gtid])
def entry_tail(self, entry):
e, r, o, m, p, s = entry.locals
gtid = s.ctaid_x * s.ntid_x + s.tid_x
o.st.v2.u32(p.mwc_states[2*gtid], r.mwc_state, r.mwc_carry)
def next_b32(self, entry):
e, r, o, m, p, s = entry.locals
carry = o.cvt.u64(r.mwc_carry)
mwc_out = o.mad.wide(r.mwc_mult, r.mwc_state, carry)
r.mwc_state, r.mwc_carry = o.split.v2(mwc_out)
return r.mwc_state
    def next_f32_01(self, entry):
        e, r, o, m, p, s = entry.locals
        # Convert as unsigned to [0, 2**32), then scale down into [0, 1)
        mwc_float = o.cvt.rn.f32.u32(self.next_b32(entry))
        return o.mul.f32(mwc_float, 1./(1<<32))
    def next_f32_11(self, entry):
        e, r, o, m, p, s = entry.locals
        # Convert as signed to [-2**31, 2**31), then scale down into [-1, 1)
        mwc_float = o.cvt.rn.f32.s32(self.next_b32(entry))
        return o.mul.f32(mwc_float, 1./(1<<31))
def seed(self, ctx, seed=None, force=False):
"""
Seed the random number generators with values taken from a
``np.random`` instance.
"""
if force or self.nthreads_ready < ctx.nthreads:
            if seed is not None:
rand = np.random.RandomState(seed)
else:
rand = np.random
# Load raw big-endian u32 multipliers from primes.bin.
with open('primes.bin') as primefp:
dt = np.dtype(np.uint32).newbyteorder('B')
mults = np.frombuffer(primefp.read(), dtype=dt)
# Randomness in choosing multipliers is good, but larger multipliers
# have longer periods, which is also good. This is a compromise.
mults = np.array(mults[:ctx.nthreads*4])
rand.shuffle(mults)
#locked_mults = ctx.hostpool.allocate(ctx.nthreads, np.uint32)
#locked_mults[:] = mults[ctx.nthreads]
#self.mults = ctx.pool.allocate(4*ctx.nthreads)
#cuda.memcpy_htod_async(self.mults, locked_mults.base, ctx.stream)
self.mults = cuda.mem_alloc(4*ctx.nthreads)
cuda.memcpy_htod(self.mults, mults[:ctx.nthreads].tostring())
# Intentionally excludes both 0 and (2^32-1), as they can lead to
# degenerate sequences of period 0
states = np.array(rand.randint(1, 0xffffffff, size=2*ctx.nthreads),
dtype=np.uint32)
#locked_states = ctx.hostpool.allocate(2*ctx.nthreads, np.uint32)
#locked_states[:] = states
#self.states = ctx.pool.allocate(8*ctx.nthreads)
#cuda.memcpy_htod_async(self.states, locked_states, ctx.stream)
self.states = cuda.mem_alloc(8*ctx.nthreads)
cuda.memcpy_htod(self.states, states.tostring())
self.nthreads_ready = ctx.nthreads
ctx.set_param('mwc_mults', self.mults)
ctx.set_param('mwc_states', self.states)
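
# Pure-Python reference for one multiply-with-carry step, mirroring next_b32
# above: the 64-bit product mult*state + carry splits into a new 32-bit state
# (low half) and carry (high half). A documentation aid only; next_b32 is the
# authoritative device implementation, and MWCRNGTest.run_test performs the
# same update vectorized across all threads.
def _mwc_reference_step(mult, state, carry):
    t = mult * state + carry
    return t & 0xffffffff, t >> 32  # (new state, new carry)
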
class MWCRNGTest(object):
"""
Test the ``MWCRNG`` class. This is not a test of the generator's
statistical properties, but merely a test that the generator is implemented
correctly on the GPU.
"""
rounds = 5000
def __init__(self, entry):
self.mwc = MWCRNG(entry)
entry.add_ptr_param('mwc_test_sums', 'u64')
with entry.body():
self.entry_body(entry)
def entry_body(self, entry):
e, r, o, m, p, s = entry.locals
r.sum = r.u64(0)
r.count = r.f32(self.rounds)
start = e.label()
r.sum = r.sum + o.cvt.u64.u32(self.mwc.next_b32(e))
r.count = r.count - 1
with r.count > 0:
o.bra.uni(start)
e.comment('yay')
gtid = s.ctaid_x * s.ntid_x + s.tid_x
o.st(p.mwc_test_sums[gtid], r.sum)
def run_test(self, ctx):
self.mwc.seed(ctx)
mults = cuda.from_device(self.mwc.mults, ctx.nthreads, np.uint32)
states = cuda.from_device(self.mwc.states, ctx.nthreads, np.uint64)
for trial in range(2):
print "Trial %d, on CPU: " % trial,
sums = np.zeros_like(states)
ctime = time.time()
for i in range(self.rounds):
vals = states & 0xffffffff
carries = states >> 32
states = mults * vals + carries
sums += states & 0xffffffff
ctime = time.time() - ctime
print "Took %g seconds." % ctime
print "Trial %d, on device: " % trial,
dsums = cuda.mem_alloc(8*ctx.nthreads)
ctx.set_param('mwc_test_sums', dsums)
print "Took %g seconds." % ctx.call_timed()
print ctx.nthreads
dsums = cuda.from_device(dsums, ctx.nthreads, np.uint64)
if not np.all(np.equal(sums, dsums)):
print "Sum discrepancy!"
print sums
print dsums
class MWCRNGFloatsTest(object):
"""
Note this only tests that the distributions are in the correct range, *not*
that they have good random properties. MWC is a suitable algorithm, but
    implementation bugs may still lead to poor statistical quality.
"""
rounds = 1024
entry_name = 'MWC_RNG_floats_test'
def deps(self):
return [MWCRNG]
def module_setup(self):
mem.global_.f32('mwc_rng_float_01_test_sums', ctx.nthreads)
mem.global_.f32('mwc_rng_float_01_test_mins', ctx.nthreads)
mem.global_.f32('mwc_rng_float_01_test_maxs', ctx.nthreads)
mem.global_.f32('mwc_rng_float_11_test_sums', ctx.nthreads)
mem.global_.f32('mwc_rng_float_11_test_mins', ctx.nthreads)
mem.global_.f32('mwc_rng_float_11_test_maxs', ctx.nthreads)
def loop(self, kind):
with block('Sum %d floats in %s' % (self.rounds, kind)):
reg.f32('loopct val rsum rmin rmax')
reg.pred('p_done')
op.mov.f32(loopct, 0.)
op.mov.f32(rsum, 0.)
op.mov.f32(rmin, 2.)
op.mov.f32(rmax, -2.)
label('loopstart' + kind)
getattr(mwc, 'next_f32_' + kind)(val)
op.add.f32(rsum, rsum, val)
op.min.f32(rmin, rmin, val)
op.max.f32(rmax, rmax, val)
op.add.f32(loopct, loopct, 1.)
op.setp.ge.f32(p_done, loopct, float(self.rounds))
op.bra('loopstart' + kind, ifnotp=p_done)
op.mul.f32(rsum, rsum, 1./self.rounds)
std.store_per_thread('mwc_rng_float_%s_test_sums' % kind, rsum,
'mwc_rng_float_%s_test_mins' % kind, rmin,
'mwc_rng_float_%s_test_maxs' % kind, rmax)
def entry(self):
self.loop('01')
self.loop('11')
def call_teardown(self, ctx):
# Tolerance of all-threads averages
tol = 0.05
# float distribution kind, test kind, expected value, limit func
tests = [
('01', 'sums', 0.5, None),
('01', 'mins', 0.0, np.min),
('01', 'maxs', 1.0, np.max),
('11', 'sums', 0.0, None),
('11', 'mins', -1.0, np.min),
('11', 'maxs', 1.0, np.max)
]
for fkind, rkind, exp, lim in tests:
name = 'mwc_rng_float_%s_test_%s' % (fkind, rkind)
vals = ctx.get_per_thread(name, np.float32)
avg = np.mean(vals)
if np.abs(avg - exp) > tol:
raise PTXTestFailure("%s %s %g too far from %g" %
(fkind, rkind, avg, exp))
if lim is None: continue
if lim([lim(vals), exp]) != exp:
raise PTXTestFailure("%s %s %g violates hard limit %g" %
(fkind, rkind, lim(vals), exp))
class CPDataStream(object):
"""DataStream which stores the control points."""
shortname = 'cp'
class Timeouter(object):
"""Time-out infinite loops so that data can still be retrieved."""
shortname = 'timeout'
def entry_setup(self):
mem.shared.u64('s_timeouter_start_time')
with block("Load start time for this block"):
reg.u64('now')
op.mov.u64(now, '%clock64')
op.st.shared.u64(addr(s_timeouter_start_time), now)
def check_time(self, secs):
"""
Drop this into your mainloop somewhere.
"""
# TODO: if debug.device_timeout_loops or whatever
with block("Check current time for this loop"):
d = cuda.Context.get_device()
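            # clock_rate is reported in kHz, so secs * clock_rate * 1000
            # gives the timeout threshold in elapsed device clocks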
clks = int(secs * d.clock_rate * 1000)
reg.u64('now then')
op.mov.u64(now, '%clock64')
op.ld.shared.u64(then, addr(s_timeouter_start_time))
op.sub.u64(now, now, then)
std.asrt("Loop timed out", 'lt.u64', now, clks)