Random floats (I think)

This commit is contained in:
Steven Robertson
2010-09-06 14:19:06 -04:00
parent f3298e0bed
commit ada0fe20c7
2 changed files with 69 additions and 33 deletions

View File

@ -40,10 +40,10 @@ class IterThread(PTXTest):
op.mov.u32(num_writes, 0)
# TODO: MWC float output types
#mwc_next_f32_01(x_coord)
#mwc_next_f32_01(y_coord)
#mwc_next_f32_01(color_coord)
#mwc_next_f32_01(alpha_coord)
mwc_next_f32_01(x_coord)
mwc_next_f32_01(y_coord)
mwc_next_f32_01(color_coord)
mwc_next_f32_01(alpha_coord)
# Registers are hard to come by. To avoid having to track both the count
# of samples processed and the number of samples to generate,
@ -189,17 +189,40 @@ class MWCRNG(PTXFragment):
op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr)
op.st.global_.v2.u32(addr(mwc_addr), vec(mwc_st, mwc_car))
@ptx_func
def _next(self):
# Call from inside a block!
reg.u64('mwc_out')
op.cvt.u64.u32(mwc_out, mwc_car)
op.mad.wide.u32(mwc_out, mwc_st, mwc_mult, mwc_out)
op.mov.b64(vec(mwc_st, mwc_car), mwc_out)
@ptx_func
def next_b32(self, dst_reg):
with block('Load next random into ' + dst_reg.name):
reg.u64('mwc_out')
op.cvt.u64.u32(mwc_out, mwc_car)
op.mad.wide.u32(mwc_out, mwc_st, mwc_mult, mwc_out)
op.mov.b64(vec(mwc_st, mwc_car), mwc_out)
with block('Load next random u32 into ' + dst_reg.name):
self._next()
op.mov.u32(dst_reg, mwc_st)
@ptx_func
def next_f32_01(self, dst_reg):
# TODO: verify that this is the fastest-performance method
# TODO: verify that this actually does what I think it does
with block('Load random float [0,1] into ' + dst_reg.name):
self._next()
op.cvt.rn.f32.u32(dst_reg, mwc_st)
op.mul.f32(dst_reg, dst_reg, '0f0000802F') # 1./(1<<32)
@ptx_func
def next_f32_11(self, dst_reg):
with block('Load random float [-1,1) into ' + dst_reg.name):
self._next()
op.cvt.rn.f32.s32(dst_reg, mwc_st)
op.mul.lo.f32(dst_reg, dst_reg, '0f00000030') # 1./(1<<31)
def to_inject(self):
return dict(mwc_next_b32=self.next_b32)
return dict(mwc_next_b32=self.next_b32,
mwc_next_f32_01=self.next_f32_01,
mwc_next_f32_11=self.next_f32_11)
def device_init(self, ctx):
if self.threads_ready >= ctx.threads: