# cuburn/cuburnlib/device_code.py
"""
Contains the PTX fragments which will drive the device.
"""
import os
import time
import pycuda.driver as cuda
import numpy as np
from cuburnlib.ptx import *
"""
Here's the current draft of the full algorithm implementation.
declare xform jump table
load random state
clear x_coord, y_coord, z_coord, w_coord;
store -(FUSE+1) to shared (per-warp) num_samples_sh
clear badvals [1]
load param (global_cp_idx_addr)
index table start (global_cp_idx) [2]
load count of indexes from global cp index =>
store to qlocal current_cp_num [3]
outermost loop start:
load current_cp_num
if current_cp_num <= 0:
exit
load param global_cp_idx_addr
calculate offset into address with current_cp_num, global_cp_idx_addr
load cp_base_address
stream_start (cp_base, cp_base_addr) [4]
FUSE_START:
num_samples += 1
if num_samples >= 0:
# Okay, we're done FUSEing, prepare to enter normal loop
load num_samples => store to shared (per-warp) num_samples
ITER_LOOP_START:
reg xform_addr, xform_stream_addr, xform_select
mwc_next_u32 to xform_select
# Performance test: roll/unroll this loop?
stream_load xform_prob (cp_stream)
if xform_select <= xform_prob:
bra.uni XFORM_1_LBL
...
stream_load xform_prob (cp_stream)
if xform_select <= xform_prob:
bra.uni XFORM_N_LBL
XFORM_1_LBL:
stream_load xform_1_ (cp_stream)
...
bra.uni XFORM_POST
XFORM_POST:
[if final_xform:]
[do final_xform]
if num_samples < 0:
# FUSE still in progress
bra.uni FUSE_START
FRAGMENT_WRITEBACK:
# Unknown at this time.
SHUFFLE:
# Unknown at this time.
load num_samples from num_samples_sh
num_samples -= 1
if num_samples > 0:
bra.uni ITER_LOOP_START

[1] Tracking 'badvals' can impose a sizable performance hit, particularly
    for images that sample only a small portion of the grid, so it may be
    disabled when rendering for performance. On the other hand, it might
    actually help tune the algorithm later, so it will definitely remain
    an option.

[2] Control points for each temporal sample will be preloaded to the
    device in the compact DataStream format (more on this later). Their
    locations are represented in an index table, which starts with a
    single `.u32 length`, followed by `length` pointers. To avoid having
    to keep reloading `length`, or worse, tying up a register to hold it,
    we instead count *down* to zero. This is a very common idiom.

[3] 'qlocal' is quasi-local storage. It could easily be actual local
    storage, depending on how local storage is implemented, but the extra
    128-byte loads for such values might make a performance difference.
    qlocal variables may be identical across a warp or even a CTA, so
    variables noted as "qlocal" here might end up in shared memory, or in
    a small per-warp or per-CTA buffer in global memory created
    specifically for this purpose, once benchmarking is done.

[4] DataStreams are "opaque" data serialization structures defined below.
    The structure of a stream is actually created while parsing the DSL
    by the load statements themselves. Some benchmarks need to be done
    before DataStreams stop being "opaque" and become simply "dynamic".
"""
class MWCRNG(PTXFragment):
    def __init__(self):
        self.threads_ready = 0
        if not os.path.isfile('primes.bin'):
            raise EnvironmentError('primes.bin not found')

    @ptx_func
    def module_setup(self):
        mem.global_.u32('mwc_rng_mults', ctx.threads)
        mem.global_.u64('mwc_rng_state', ctx.threads)

    @ptx_func
    def entry_setup(self):
        reg.u32('mwc_st mwc_mult mwc_car')
        with block('Load MWC multipliers and states'):
            reg.u32('mwc_off mwc_addr')
            get_gtid(mwc_off)
            op.mov.u32(mwc_addr, mwc_rng_mults)
            op.mad.lo.u32(mwc_addr, mwc_off, 4, mwc_addr)
            op.ld.global_.u32(mwc_mult, addr(mwc_addr))
            op.mov.u32(mwc_addr, mwc_rng_state)
            op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr)
            op.ld.global_.v2.u32(vec(mwc_st, mwc_car), addr(mwc_addr))

    @ptx_func
    def entry_teardown(self):
        with block('Save MWC states'):
            reg.u32('mwc_off mwc_addr')
            get_gtid(mwc_off)
            op.mov.u32(mwc_addr, mwc_rng_state)
            op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr)
            op.st.global_.v2.u32(addr(mwc_addr), vec(mwc_st, mwc_car))

    @ptx_func
    def next_b32(self, dst_reg):
        with block('Load next random into ' + dst_reg.name):
            reg.u64('mwc_out')
            op.cvt.u64.u32(mwc_out, mwc_car)
            op.mad.wide.u32(mwc_out, mwc_st, mwc_mult, mwc_out)
            op.mov.b64(vec(mwc_st, mwc_car), mwc_out)
            op.mov.u32(dst_reg, mwc_st)

    def to_inject(self):
        return dict(mwc_next_b32=self.next_b32)
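
    @staticmethod
    def _host_next_b32(state, carry, mult):
        """Host-side reference model of next_b32 (an illustrative addition;
        the kernel does not call this). One multiply-with-carry step: the
        64-bit product mult * state + carry is split into a new 32-bit
        state (the low word, which is also what the PTX hands back in
        dst_reg) and a new carry (the high word)."""
        t = mult * state + carry
        return t & 0xffffffff, t >> 32
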
    def device_init(self, ctx):
        if self.threads_ready >= ctx.threads:
            # Already set up enough random states, don't push again
            return
        # Load raw big-endian u32 multipliers from primes.bin
        with open('primes.bin', 'rb') as primefp:
            dt = np.dtype(np.uint32).newbyteorder('B')
            mults = np.frombuffer(primefp.read(), dtype=dt)
        stream = cuda.Stream()
        # Randomness in choosing multipliers is good, but larger multipliers
        # have longer periods, which is also good. This is a compromise.
        mults = np.array(mults[:ctx.threads*4])
        ctx.rand.shuffle(mults)
        # Copy multipliers and seeds to the device
        multdp, multl = ctx.mod.get_global('mwc_rng_mults')
        cuda.memcpy_htod_async(multdp, mults.tostring()[:multl], stream)
        # Each 64-bit state record holds a (state, carry) pair of u32s.
        # Seeds intentionally exclude both 0 and (2^32-1), as they can lead
        # to degenerate sequences of period 0.
        states = np.array(ctx.rand.randint(1, 0xffffffff, size=2*ctx.threads),
                          dtype=np.uint32)
        statedp, statel = ctx.mod.get_global('mwc_rng_state')
        cuda.memcpy_htod_async(statedp, states.tostring(), stream)
        self.threads_ready = ctx.threads

    def tests(self):
        return [MWCRNGTest]

class MWCRNGTest(PTXTest):
    name = "MWC RNG sum-of-threads"
    rounds = 5000
    entry_name = 'MWC_RNG_test'
    entry_params = ''

    def deps(self):
        return [MWCRNG]

    @ptx_func
    def module_setup(self):
        mem.global_.u64('mwc_rng_test_sums', ctx.threads)

    @ptx_func
    def entry(self):
        reg.u64('sum addl')
        reg.u32('addend')
        op.mov.u64(sum, 0)
        with block('Sum next %d random numbers' % self.rounds):
            reg.u32('loopct')
            reg.pred('p')
            op.mov.u32(loopct, self.rounds)
            label('loopstart')
            mwc_next_b32(addend)
            op.cvt.u64.u32(addl, addend)
            op.add.u64(sum, sum, addl)
            op.sub.u32(loopct, loopct, 1)
            op.setp.gt.u32(p, loopct, 0)
            op.bra.uni(loopstart, ifp=p)
        with block('Store sum and state'):
            reg.u32('adr offset')
            get_gtid(offset)
            op.mov.u32(adr, mwc_rng_test_sums)
            op.mad.lo.u32(adr, offset, 8, adr)
            op.st.global_.u64(addr(adr), sum)

    def call(self, ctx):
        # Get current multipliers and seeds from the device
        multdp, multl = ctx.mod.get_global('mwc_rng_mults')
        mults = cuda.from_device(multdp, ctx.threads, np.uint32)
        statedp, statel = ctx.mod.get_global('mwc_rng_state')
        fullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
        sums = np.zeros(ctx.threads, np.uint64)

        print "Running %d states forward %d rounds" % (len(mults), self.rounds)
        ctime = time.time()
        for i in range(self.rounds):
            states = fullstates & 0xffffffff
            carries = fullstates >> 32
            fullstates = mults * states + carries
            sums = sums + (fullstates & 0xffffffff)
        ctime = time.time() - ctime
        print "Done on host, took %g seconds" % ctime

        func = ctx.mod.get_function('MWC_RNG_test')
        dtime = func(block=ctx.block, grid=ctx.grid, time_kernel=True)
        print "Done on device, took %g seconds (%gx)" % (dtime, ctime/dtime)

        dfullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
        if not (dfullstates == fullstates).all():
            print "State discrepancy"
            print dfullstates
            print fullstates
            return False

        sumdp, suml = ctx.mod.get_global('mwc_rng_test_sums')
        dsums = cuda.from_device(sumdp, ctx.threads, np.uint64)
        if not (dsums == sums).all():
            print "Sum discrepancy"
            print dsums
            print sums
            return False
        return True
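
# Hypothetical usage sketch for the test above, assuming a `ctx` prepared
# elsewhere in cuburnlib with .mod, .block, .grid, .threads and .rand set up:
#
#     MWCRNG().device_init(ctx)
#     assert MWCRNGTest().call(ctx), "device RNG diverged from host model"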

class CameraCoordTransform(PTXFragment):
    # TODO: finish
    pass

class CPDataStream(PTXFragment):
    """
    DataStream which stores