diff --git a/cuburn/code/mwc.py b/cuburn/code/mwc.py index dc12c32..c24726e 100644 --- a/cuburn/code/mwc.py +++ b/cuburn/code/mwc.py @@ -102,14 +102,14 @@ def test_mwc(rounds=5000, nblocks=64, blockwidth=512): print "Trial %d, on CPU: " % trial, sums = np.zeros(nthreads, dtype=np.uint64) ctime = time.time() - mults = seeds[0].astype(np.uint64) - states = seeds[1] - carries = seeds[2] + mults = seeds[:,0].astype(np.uint64) + states = seeds[:,1] + carries = seeds[:,2] for i in range(rounds): step = np.frombuffer((mults * states + carries).data, - dtype=np.uint32).reshape((2, nthreads), order='F') - states[:] = step[0] - carries[:] = step[1] + dtype=np.uint32).reshape((nthreads, 2)) + states[:] = step[:,0] + carries[:] = step[:,1] sums += states ctime = time.time() - ctime