mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 11:40:04 -05:00
Fix MWCRNGTest.
This commit is contained in:
parent
36f1c1c056
commit
c3d12d07c2
@ -520,37 +520,42 @@ class MWCRNGTest(PTXTest):
|
||||
def call_setup(self, ctx):
|
||||
# Get current multipliers and seeds from the device
|
||||
multdp, multl = ctx.mod.get_global('mwc_rng_mults')
|
||||
mults = cuda.from_device(multdp, ctx.threads, np.uint32)
|
||||
self.mults = cuda.from_device(multdp, ctx.threads, np.uint32)
|
||||
statedp, statel = ctx.mod.get_global('mwc_rng_state')
|
||||
fullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
|
||||
sums = np.zeros(ctx.threads, np.uint64)
|
||||
self.fullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
|
||||
self.sums = np.zeros(ctx.threads, np.uint64)
|
||||
|
||||
print "Running %d states forward %d rounds" % (len(mults), self.rounds)
|
||||
print "Running %d states forward %d rounds" % \
|
||||
(len(self.mults), self.rounds)
|
||||
ctime = time.time()
|
||||
for i in range(self.rounds):
|
||||
states = fullstates & 0xffffffff
|
||||
carries = fullstates >> 32
|
||||
fullstates = mults * states + carries
|
||||
sums = sums + (fullstates & 0xffffffff)
|
||||
states = self.fullstates & 0xffffffff
|
||||
carries = self.fullstates >> 32
|
||||
self.fullstates = self.mults * states + carries
|
||||
self.sums += self.fullstates & 0xffffffff
|
||||
ctime = time.time() - ctime
|
||||
print "Done on host, took %g seconds" % ctime
|
||||
|
||||
def call_teardown(self, ctx):
|
||||
multdp, multl = ctx.mod.get_global('mwc_rng_mults')
|
||||
statedp, statel = ctx.mod.get_global('mwc_rng_state')
|
||||
|
||||
dfullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
|
||||
if not (dfullstates == fullstates).all():
|
||||
if not (dfullstates == self.fullstates).all():
|
||||
print "State discrepancy"
|
||||
print dfullstates
|
||||
print fullstates
|
||||
print self.fullstates
|
||||
raise PTXTestFailure("MWC RNG state discrepancy")
|
||||
|
||||
sumdp, suml = ctx.mod.get_global('mwc_rng_test_sums')
|
||||
dsums = cuda.from_device(sumdp, ctx.threads, np.uint64)
|
||||
if not (dsums == sums).all():
|
||||
if not (dsums == self.sums).all():
|
||||
print "Sum discrepancy"
|
||||
print dsums
|
||||
print sums
|
||||
print self.sums
|
||||
raise PTXTestFailure("MWC RNG sum discrepancy")
|
||||
|
||||
|
||||
class CPDataStream(DataStream):
|
||||
"""DataStream which stores the control points."""
|
||||
shortname = 'cp'
|
||||
|
Loading…
Reference in New Issue
Block a user