mirror of
				https://github.com/stevenrobertson/cuburn.git
				synced 2025-11-03 18:00:55 -05:00 
			
		
		
		
	Fix MWCRNGTest.
This commit is contained in:
		@ -520,37 +520,42 @@ class MWCRNGTest(PTXTest):
 | 
				
			|||||||
    def call_setup(self, ctx):
 | 
					    def call_setup(self, ctx):
 | 
				
			||||||
        # Get current multipliers and seeds from the device
 | 
					        # Get current multipliers and seeds from the device
 | 
				
			||||||
        multdp, multl = ctx.mod.get_global('mwc_rng_mults')
 | 
					        multdp, multl = ctx.mod.get_global('mwc_rng_mults')
 | 
				
			||||||
        mults = cuda.from_device(multdp, ctx.threads, np.uint32)
 | 
					        self.mults = cuda.from_device(multdp, ctx.threads, np.uint32)
 | 
				
			||||||
        statedp, statel = ctx.mod.get_global('mwc_rng_state')
 | 
					        statedp, statel = ctx.mod.get_global('mwc_rng_state')
 | 
				
			||||||
        fullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
 | 
					        self.fullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
 | 
				
			||||||
        sums = np.zeros(ctx.threads, np.uint64)
 | 
					        self.sums = np.zeros(ctx.threads, np.uint64)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        print "Running %d states forward %d rounds" % (len(mults), self.rounds)
 | 
					        print "Running %d states forward %d rounds" % \
 | 
				
			||||||
 | 
					              (len(self.mults), self.rounds)
 | 
				
			||||||
        ctime = time.time()
 | 
					        ctime = time.time()
 | 
				
			||||||
        for i in range(self.rounds):
 | 
					        for i in range(self.rounds):
 | 
				
			||||||
            states = fullstates & 0xffffffff
 | 
					            states = self.fullstates & 0xffffffff
 | 
				
			||||||
            carries = fullstates >> 32
 | 
					            carries = self.fullstates >> 32
 | 
				
			||||||
            fullstates = mults * states + carries
 | 
					            self.fullstates = self.mults * states + carries
 | 
				
			||||||
            sums = sums + (fullstates & 0xffffffff)
 | 
					            self.sums += self.fullstates & 0xffffffff
 | 
				
			||||||
        ctime = time.time() - ctime
 | 
					        ctime = time.time() - ctime
 | 
				
			||||||
        print "Done on host, took %g seconds" % ctime
 | 
					        print "Done on host, took %g seconds" % ctime
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def call_teardown(self, ctx):
 | 
					    def call_teardown(self, ctx):
 | 
				
			||||||
 | 
					        multdp, multl = ctx.mod.get_global('mwc_rng_mults')
 | 
				
			||||||
 | 
					        statedp, statel = ctx.mod.get_global('mwc_rng_state')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        dfullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
 | 
					        dfullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
 | 
				
			||||||
        if not (dfullstates == fullstates).all():
 | 
					        if not (dfullstates == self.fullstates).all():
 | 
				
			||||||
            print "State discrepancy"
 | 
					            print "State discrepancy"
 | 
				
			||||||
            print dfullstates
 | 
					            print dfullstates
 | 
				
			||||||
            print fullstates
 | 
					            print self.fullstates
 | 
				
			||||||
            raise PTXTestFailure("MWC RNG state discrepancy")
 | 
					            raise PTXTestFailure("MWC RNG state discrepancy")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        sumdp, suml = ctx.mod.get_global('mwc_rng_test_sums')
 | 
					        sumdp, suml = ctx.mod.get_global('mwc_rng_test_sums')
 | 
				
			||||||
        dsums = cuda.from_device(sumdp, ctx.threads, np.uint64)
 | 
					        dsums = cuda.from_device(sumdp, ctx.threads, np.uint64)
 | 
				
			||||||
        if not (dsums == sums).all():
 | 
					        if not (dsums == self.sums).all():
 | 
				
			||||||
            print "Sum discrepancy"
 | 
					            print "Sum discrepancy"
 | 
				
			||||||
            print dsums
 | 
					            print dsums
 | 
				
			||||||
            print sums
 | 
					            print self.sums
 | 
				
			||||||
            raise PTXTestFailure("MWC RNG sum discrepancy")
 | 
					            raise PTXTestFailure("MWC RNG sum discrepancy")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class CPDataStream(DataStream):
 | 
					class CPDataStream(DataStream):
 | 
				
			||||||
    """DataStream which stores the control points."""
 | 
					    """DataStream which stores the control points."""
 | 
				
			||||||
    shortname = 'cp'
 | 
					    shortname = 'cp'
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user