Fix MWCRNGTest.

This commit is contained in:
Steven Robertson 2010-09-10 18:01:50 -04:00
parent 36f1c1c056
commit c3d12d07c2

View File

@ -520,37 +520,42 @@ class MWCRNGTest(PTXTest):
def call_setup(self, ctx):
# Get current multipliers and seeds from the device
multdp, multl = ctx.mod.get_global('mwc_rng_mults')
mults = cuda.from_device(multdp, ctx.threads, np.uint32)
self.mults = cuda.from_device(multdp, ctx.threads, np.uint32)
statedp, statel = ctx.mod.get_global('mwc_rng_state')
fullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
sums = np.zeros(ctx.threads, np.uint64)
self.fullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
self.sums = np.zeros(ctx.threads, np.uint64)
print "Running %d states forward %d rounds" % (len(mults), self.rounds)
print "Running %d states forward %d rounds" % \
(len(self.mults), self.rounds)
ctime = time.time()
for i in range(self.rounds):
states = fullstates & 0xffffffff
carries = fullstates >> 32
fullstates = mults * states + carries
sums = sums + (fullstates & 0xffffffff)
states = self.fullstates & 0xffffffff
carries = self.fullstates >> 32
self.fullstates = self.mults * states + carries
self.sums += self.fullstates & 0xffffffff
ctime = time.time() - ctime
print "Done on host, took %g seconds" % ctime
def call_teardown(self, ctx):
multdp, multl = ctx.mod.get_global('mwc_rng_mults')
statedp, statel = ctx.mod.get_global('mwc_rng_state')
dfullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
if not (dfullstates == fullstates).all():
if not (dfullstates == self.fullstates).all():
print "State discrepancy"
print dfullstates
print fullstates
print self.fullstates
raise PTXTestFailure("MWC RNG state discrepancy")
sumdp, suml = ctx.mod.get_global('mwc_rng_test_sums')
dsums = cuda.from_device(sumdp, ctx.threads, np.uint64)
if not (dsums == sums).all():
if not (dsums == self.sums).all():
print "Sum discrepancy"
print dsums
print sums
print self.sums
raise PTXTestFailure("MWC RNG sum discrepancy")
class CPDataStream(DataStream):
"""DataStream which stores the control points."""
shortname = 'cp'