mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 19:50:04 -05:00
Fix RNG test
This commit is contained in:
parent
a6141f492d
commit
70ca6d7729
@ -688,7 +688,7 @@ class MWCRNGTest(PTXTest):
|
|||||||
def call_setup(self, ctx):
|
def call_setup(self, ctx):
|
||||||
# Get current multipliers and seeds from the device
|
# Get current multipliers and seeds from the device
|
||||||
self.mults = ctx.get_per_thread('mwc_rng_mults', np.uint32)
|
self.mults = ctx.get_per_thread('mwc_rng_mults', np.uint32)
|
||||||
self.fullstates = ctx.get_per_thread('mwc_rng_states', np.uint64)
|
self.fullstates = ctx.get_per_thread('mwc_rng_state', np.uint64)
|
||||||
self.sums = np.zeros(ctx.nthreads, np.uint64)
|
self.sums = np.zeros(ctx.nthreads, np.uint64)
|
||||||
|
|
||||||
print "Running %d states forward %d rounds" % \
|
print "Running %d states forward %d rounds" % \
|
||||||
@ -703,7 +703,7 @@ class MWCRNGTest(PTXTest):
|
|||||||
print "Done on host, took %g seconds" % ctime
|
print "Done on host, took %g seconds" % ctime
|
||||||
|
|
||||||
def call_teardown(self, ctx):
|
def call_teardown(self, ctx):
|
||||||
dfullstates = ctx.get_per_thread('mwc_rng_states', np.uint64)
|
dfullstates = ctx.get_per_thread('mwc_rng_state', np.uint64)
|
||||||
if not (dfullstates == self.fullstates).all():
|
if not (dfullstates == self.fullstates).all():
|
||||||
print "State discrepancy"
|
print "State discrepancy"
|
||||||
print dfullstates
|
print dfullstates
|
||||||
|
Loading…
Reference in New Issue
Block a user