diff --git a/cuburn/device_code.py b/cuburn/device_code.py index 3f2e95f..b851582 100644 --- a/cuburn/device_code.py +++ b/cuburn/device_code.py @@ -520,37 +520,42 @@ class MWCRNGTest(PTXTest): def call_setup(self, ctx): # Get current multipliers and seeds from the device multdp, multl = ctx.mod.get_global('mwc_rng_mults') - mults = cuda.from_device(multdp, ctx.threads, np.uint32) + self.mults = cuda.from_device(multdp, ctx.threads, np.uint32) statedp, statel = ctx.mod.get_global('mwc_rng_state') - fullstates = cuda.from_device(statedp, ctx.threads, np.uint64) - sums = np.zeros(ctx.threads, np.uint64) + self.fullstates = cuda.from_device(statedp, ctx.threads, np.uint64) + self.sums = np.zeros(ctx.threads, np.uint64) - print "Running %d states forward %d rounds" % (len(mults), self.rounds) + print "Running %d states forward %d rounds" % \ + (len(self.mults), self.rounds) ctime = time.time() for i in range(self.rounds): - states = fullstates & 0xffffffff - carries = fullstates >> 32 - fullstates = mults * states + carries - sums = sums + (fullstates & 0xffffffff) + states = self.fullstates & 0xffffffff + carries = self.fullstates >> 32 + self.fullstates = self.mults * states + carries + self.sums += self.fullstates & 0xffffffff ctime = time.time() - ctime print "Done on host, took %g seconds" % ctime def call_teardown(self, ctx): + multdp, multl = ctx.mod.get_global('mwc_rng_mults') + statedp, statel = ctx.mod.get_global('mwc_rng_state') + dfullstates = cuda.from_device(statedp, ctx.threads, np.uint64) - if not (dfullstates == fullstates).all(): + if not (dfullstates == self.fullstates).all(): print "State discrepancy" print dfullstates - print fullstates + print self.fullstates raise PTXTestFailure("MWC RNG state discrepancy") sumdp, suml = ctx.mod.get_global('mwc_rng_test_sums') dsums = cuda.from_device(sumdp, ctx.threads, np.uint64) - if not (dsums == sums).all(): + if not (dsums == self.sums).all(): print "Sum discrepancy" print dsums - print sums + print self.sums raise PTXTestFailure("MWC RNG sum discrepancy") + class CPDataStream(DataStream): """DataStream which stores the control points.""" shortname = 'cp'