Fix MWCRNGTest.

This commit is contained in:
Steven Robertson 2010-09-10 18:01:50 -04:00
parent 36f1c1c056
commit c3d12d07c2

View File

@ -520,37 +520,42 @@ class MWCRNGTest(PTXTest):
def call_setup(self, ctx): def call_setup(self, ctx):
# Get current multipliers and seeds from the device # Get current multipliers and seeds from the device
multdp, multl = ctx.mod.get_global('mwc_rng_mults') multdp, multl = ctx.mod.get_global('mwc_rng_mults')
mults = cuda.from_device(multdp, ctx.threads, np.uint32) self.mults = cuda.from_device(multdp, ctx.threads, np.uint32)
statedp, statel = ctx.mod.get_global('mwc_rng_state') statedp, statel = ctx.mod.get_global('mwc_rng_state')
fullstates = cuda.from_device(statedp, ctx.threads, np.uint64) self.fullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
sums = np.zeros(ctx.threads, np.uint64) self.sums = np.zeros(ctx.threads, np.uint64)
print "Running %d states forward %d rounds" % (len(mults), self.rounds) print "Running %d states forward %d rounds" % \
(len(self.mults), self.rounds)
ctime = time.time() ctime = time.time()
for i in range(self.rounds): for i in range(self.rounds):
states = fullstates & 0xffffffff states = self.fullstates & 0xffffffff
carries = fullstates >> 32 carries = self.fullstates >> 32
fullstates = mults * states + carries self.fullstates = self.mults * states + carries
sums = sums + (fullstates & 0xffffffff) self.sums += self.fullstates & 0xffffffff
ctime = time.time() - ctime ctime = time.time() - ctime
print "Done on host, took %g seconds" % ctime print "Done on host, took %g seconds" % ctime
def call_teardown(self, ctx): def call_teardown(self, ctx):
multdp, multl = ctx.mod.get_global('mwc_rng_mults')
statedp, statel = ctx.mod.get_global('mwc_rng_state')
dfullstates = cuda.from_device(statedp, ctx.threads, np.uint64) dfullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
if not (dfullstates == fullstates).all(): if not (dfullstates == self.fullstates).all():
print "State discrepancy" print "State discrepancy"
print dfullstates print dfullstates
print fullstates print self.fullstates
raise PTXTestFailure("MWC RNG state discrepancy") raise PTXTestFailure("MWC RNG state discrepancy")
sumdp, suml = ctx.mod.get_global('mwc_rng_test_sums') sumdp, suml = ctx.mod.get_global('mwc_rng_test_sums')
dsums = cuda.from_device(sumdp, ctx.threads, np.uint64) dsums = cuda.from_device(sumdp, ctx.threads, np.uint64)
if not (dsums == sums).all(): if not (dsums == self.sums).all():
print "Sum discrepancy" print "Sum discrepancy"
print dsums print dsums
print sums print self.sums
raise PTXTestFailure("MWC RNG sum discrepancy") raise PTXTestFailure("MWC RNG sum discrepancy")
class CPDataStream(DataStream): class CPDataStream(DataStream):
"""DataStream which stores the control points.""" """DataStream which stores the control points."""
shortname = 'cp' shortname = 'cp'