mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 11:40:04 -05:00
Test to make sure floating point numbers were in the right range.
This commit is contained in:
parent
e71a8422e5
commit
3932412539
2
bench.py
2
bench.py
@ -116,7 +116,7 @@ class L2WriteCombining(PTXTest):
|
||||
print "Bytes for coa was %g ± %g" % pm(self.times_bytes[1])
|
||||
print "Clks for uncoa was %g ± %g" % pm(self.times_bytes[2])
|
||||
print "Bytes for uncoa was %g ± %g" % pm(self.times_bytes[3])
|
||||
print ''
|
||||
print
|
||||
|
||||
def printover(a, r, s=1):
|
||||
for i in range(0, len(a), r*s):
|
||||
|
@ -104,7 +104,7 @@ class LaunchContext(object):
|
||||
for test in self.ptx.tests:
|
||||
cuda.Context.synchronize()
|
||||
if test.call(self):
|
||||
print "Test %s passed." % test.entry_name
|
||||
print "Test %s passed.\n" % test.entry_name
|
||||
else:
|
||||
all_okay = False
|
||||
return all_okay
|
||||
|
@ -478,7 +478,7 @@ class MWCRNG(PTXFragment):
|
||||
self.seed(ctx)
|
||||
|
||||
def tests(self):
|
||||
return [MWCRNGTest]
|
||||
return [MWCRNGTest, MWCRNGFloatsTest]
|
||||
|
||||
class MWCRNGTest(PTXTest):
|
||||
name = "MWC RNG sum-of-threads"
|
||||
@ -555,6 +555,79 @@ class MWCRNGTest(PTXTest):
|
||||
print self.sums
|
||||
raise PTXTestFailure("MWC RNG sum discrepancy")
|
||||
|
||||
class MWCRNGFloatsTest(PTXTest):
|
||||
"""
|
||||
Note this only tests that the distributions are in the correct range, *not*
|
||||
that they have good random properties. MWC is a suitable algorithm, but
|
||||
implementation bugs may still lead to poor performance.
|
||||
"""
|
||||
rounds = 1024
|
||||
entry_name = 'MWC_RNG_floats_test'
|
||||
|
||||
def deps(self):
|
||||
return [MWCRNG]
|
||||
|
||||
@ptx_func
|
||||
def module_setup(self):
|
||||
mem.global_.f32('mwc_rng_float_01_test_sums', ctx.threads)
|
||||
mem.global_.f32('mwc_rng_float_01_test_mins', ctx.threads)
|
||||
mem.global_.f32('mwc_rng_float_01_test_maxs', ctx.threads)
|
||||
mem.global_.f32('mwc_rng_float_11_test_sums', ctx.threads)
|
||||
mem.global_.f32('mwc_rng_float_11_test_mins', ctx.threads)
|
||||
mem.global_.f32('mwc_rng_float_11_test_maxs', ctx.threads)
|
||||
|
||||
@ptx_func
|
||||
def loop(self, kind):
|
||||
with block('Sum %d floats in %s' % (self.rounds, kind)):
|
||||
reg.f32('loopct val sum rmin rmax')
|
||||
reg.pred('p_done')
|
||||
op.mov.f32(loopct, 0.)
|
||||
op.mov.f32(sum, 0.)
|
||||
op.mov.f32(rmin, 2.)
|
||||
op.mov.f32(rmax, -2.)
|
||||
label('loopstart' + kind)
|
||||
getattr(mwc, 'next_f32_' + kind)(val)
|
||||
op.add.f32(sum, sum, val)
|
||||
op.min.f32(rmin, rmin, val)
|
||||
op.max.f32(rmax, rmax, val)
|
||||
op.add.f32(loopct, loopct, 1.)
|
||||
op.setp.ge.f32(p_done, loopct, float(self.rounds))
|
||||
op.bra('loopstart' + kind, ifnotp=p_done)
|
||||
op.mul.f32(sum, sum, 1./self.rounds)
|
||||
std.store_per_thread('mwc_rng_float_%s_test_sums' % kind, sum,
|
||||
'mwc_rng_float_%s_test_mins' % kind, rmin,
|
||||
'mwc_rng_float_%s_test_maxs' % kind, rmax)
|
||||
|
||||
@ptx_func
|
||||
def entry(self):
|
||||
self.loop('01')
|
||||
self.loop('11')
|
||||
|
||||
def call_teardown(self, ctx):
|
||||
# Tolerance of all-threads averages
|
||||
tol = 0.05
|
||||
# float distribution kind, test kind, expected value, limit func
|
||||
tests = [
|
||||
('01', 'sums', 0.5, None),
|
||||
('01', 'mins', 0.0, np.min),
|
||||
('01', 'maxs', 1.0, np.max),
|
||||
('11', 'sums', 0.0, None),
|
||||
('11', 'mins', -1.0, np.min),
|
||||
('11', 'maxs', 1.0, np.max)
|
||||
]
|
||||
|
||||
for fkind, rkind, exp, lim in tests:
|
||||
dp, l = ctx.mod.get_global(
|
||||
'mwc_rng_float_%s_test_%s' % (fkind, rkind))
|
||||
vals = cuda.from_device(dp, ctx.threads, np.float32)
|
||||
avg = np.mean(vals)
|
||||
if np.abs(avg - exp) > tol:
|
||||
raise PTXTestFailure("%s %s %g too far from %g" %
|
||||
(fkind, rkind, avg, exp))
|
||||
if lim is None: continue
|
||||
if lim([lim(vals), exp]) != exp:
|
||||
raise PTXTestFailure("%s %s %g violates hard limit %g" %
|
||||
(fkind, rkind, lim(vals), exp))
|
||||
|
||||
class CPDataStream(DataStream):
|
||||
"""DataStream which stores the control points."""
|
||||
|
Loading…
Reference in New Issue
Block a user