mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 11:40:04 -05:00
Add device assertions to standard library.
This commit is contained in:
parent
3932412539
commit
56404b629f
@ -579,22 +579,22 @@ class MWCRNGFloatsTest(PTXTest):
|
|||||||
@ptx_func
|
@ptx_func
|
||||||
def loop(self, kind):
|
def loop(self, kind):
|
||||||
with block('Sum %d floats in %s' % (self.rounds, kind)):
|
with block('Sum %d floats in %s' % (self.rounds, kind)):
|
||||||
reg.f32('loopct val sum rmin rmax')
|
reg.f32('loopct val rsum rmin rmax')
|
||||||
reg.pred('p_done')
|
reg.pred('p_done')
|
||||||
op.mov.f32(loopct, 0.)
|
op.mov.f32(loopct, 0.)
|
||||||
op.mov.f32(sum, 0.)
|
op.mov.f32(rsum, 0.)
|
||||||
op.mov.f32(rmin, 2.)
|
op.mov.f32(rmin, 2.)
|
||||||
op.mov.f32(rmax, -2.)
|
op.mov.f32(rmax, -2.)
|
||||||
label('loopstart' + kind)
|
label('loopstart' + kind)
|
||||||
getattr(mwc, 'next_f32_' + kind)(val)
|
getattr(mwc, 'next_f32_' + kind)(val)
|
||||||
op.add.f32(sum, sum, val)
|
op.add.f32(rsum, rsum, val)
|
||||||
op.min.f32(rmin, rmin, val)
|
op.min.f32(rmin, rmin, val)
|
||||||
op.max.f32(rmax, rmax, val)
|
op.max.f32(rmax, rmax, val)
|
||||||
op.add.f32(loopct, loopct, 1.)
|
op.add.f32(loopct, loopct, 1.)
|
||||||
op.setp.ge.f32(p_done, loopct, float(self.rounds))
|
op.setp.ge.f32(p_done, loopct, float(self.rounds))
|
||||||
op.bra('loopstart' + kind, ifnotp=p_done)
|
op.bra('loopstart' + kind, ifnotp=p_done)
|
||||||
op.mul.f32(sum, sum, 1./self.rounds)
|
op.mul.f32(rsum, rsum, 1./self.rounds)
|
||||||
std.store_per_thread('mwc_rng_float_%s_test_sums' % kind, sum,
|
std.store_per_thread('mwc_rng_float_%s_test_sums' % kind, rsum,
|
||||||
'mwc_rng_float_%s_test_mins' % kind, rmin,
|
'mwc_rng_float_%s_test_mins' % kind, rmin,
|
||||||
'mwc_rng_float_%s_test_maxs' % kind, rmax)
|
'mwc_rng_float_%s_test_maxs' % kind, rmax)
|
||||||
|
|
||||||
|
111
cuburn/ptx.py
111
cuburn/ptx.py
@ -16,6 +16,9 @@ from cStringIO import StringIO
|
|||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from math import *
|
from math import *
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pycuda.driver as cuda
|
||||||
|
|
||||||
# Okay, so here's what's going on.
|
# Okay, so here's what's going on.
|
||||||
#
|
#
|
||||||
# We're using Python to create PTX. If we just use Python to make one giant PTX
|
# We're using Python to create PTX. If we just use Python to make one giant PTX
|
||||||
@ -642,8 +645,10 @@ class PTXEntryPoint(PTXFragment):
|
|||||||
"""
|
"""
|
||||||
ctx.call_setup(self)
|
ctx.call_setup(self)
|
||||||
func = ctx.mod.get_function(self.entry_name)
|
func = ctx.mod.get_function(self.entry_name)
|
||||||
self._call(ctx, func, *args, **kwargs)
|
try:
|
||||||
return ctx.call_teardown(self)
|
self._call(ctx, func, *args, **kwargs)
|
||||||
|
finally:
|
||||||
|
return ctx.call_teardown(self)
|
||||||
|
|
||||||
class PTXTestFailure(Exception): pass
|
class PTXTestFailure(Exception): pass
|
||||||
|
|
||||||
@ -663,6 +668,7 @@ class _PTXStdLib(PTXFragment):
|
|||||||
def __init__(self, block):
|
def __init__(self, block):
|
||||||
# Only module that gets the privilege of seeing 'block' directly.
|
# Only module that gets the privilege of seeing 'block' directly.
|
||||||
self.block = block
|
self.block = block
|
||||||
|
self.asserts = ["Success"]
|
||||||
|
|
||||||
def deps(self):
|
def deps(self):
|
||||||
return []
|
return []
|
||||||
@ -673,6 +679,7 @@ class _PTXStdLib(PTXFragment):
|
|||||||
# multiple devices first, which we definitely do not yet do
|
# multiple devices first, which we definitely do not yet do
|
||||||
self.block.code(prefix='.version 2.1', semi=False)
|
self.block.code(prefix='.version 2.1', semi=False)
|
||||||
self.block.code(prefix='.target sm_21', semi=False)
|
self.block.code(prefix='.target sm_21', semi=False)
|
||||||
|
mem.global_.u32('g_std_exit_err', ctx.threads)
|
||||||
|
|
||||||
@ptx_func
|
@ptx_func
|
||||||
def get_gtid(self, dst):
|
def get_gtid(self, dst):
|
||||||
@ -716,6 +723,106 @@ class _PTXStdLib(PTXFragment):
|
|||||||
def not_(self, pred):
|
def not_(self, pred):
|
||||||
return ['!', pred]
|
return ['!', pred]
|
||||||
|
|
||||||
|
@ptx_func
|
||||||
|
def asrt(self, msg, o=None, a=None, b=None, p=None, notp=None,
|
||||||
|
ret=False, ign=False, lvl=1):
|
||||||
|
"""
|
||||||
|
Device assertion.
|
||||||
|
|
||||||
|
Without arguments, a thread will log the error code associated with
|
||||||
|
``msg`` and issue a trap instruction, which will cause the device to
|
||||||
|
terminate execution in all threads immediately. Any of the options
|
||||||
|
below modify that behavior, as described.
|
||||||
|
|
||||||
|
``o``, ``a`` and ``b``, when set together, will be used to create a
|
||||||
|
``setp`` instruction to test a condition. They're the first three
|
||||||
|
arguments, to make usage a bit more natural:
|
||||||
|
|
||||||
|
>>> std.asrt('lt.u32', val, 0)
|
||||||
|
|
||||||
|
This would generate the instruction ``setp.lt.u32 <p>, val, 0;``
|
||||||
|
(<p> is created by this function). The thread would only store the
|
||||||
|
error code and exit if the condition were *false*.
|
||||||
|
|
||||||
|
``p`` is a predicate value; the store and trap will happen if it is
|
||||||
|
*not* set (same sense as ``o`` and Python's assert). ``notp`` is the
|
||||||
|
reverse.
|
||||||
|
|
||||||
|
Only one of ``o``, ``ifp``, or ``ifnotp`` can be set per call.
|
||||||
|
|
||||||
|
``ret`` causes the assert to issue a ``ret;`` instruction in place of
|
||||||
|
the trap. This causes the current thread to terminate, but does not
|
||||||
|
cause the other threads to do so. Be cautious, as barriers can cause a
|
||||||
|
kernel to hang using this instruction.
|
||||||
|
|
||||||
|
``ign`` causes the error code to be stored, but does not terminate
|
||||||
|
thread execution ("ignores" the error). This is useful to identify the
|
||||||
|
location of all threads in case of an abnormal termination caused by
|
||||||
|
another thread, and is used to set up the entry-wide "early
|
||||||
|
termination" error. ``ign`` overrides ``ret``.
|
||||||
|
|
||||||
|
This code calculates the gtid unconditionally, and so can be relatively
|
||||||
|
expensive to insert into a tight loop. As a result, assert
|
||||||
|
statements will only be added if the debug value ``assert_level`` is
|
||||||
|
at least as large as the ``lvl`` argument.
|
||||||
|
"""
|
||||||
|
# TODO: debug level checking
|
||||||
|
if np.sum(map(bool, (o, p, notp))) > 1:
|
||||||
|
raise ValueError("Can only use one of o, ifp, ifnotp.")
|
||||||
|
if msg not in self.asserts:
|
||||||
|
self.asserts.append(msg)
|
||||||
|
err_code = self.asserts.index(msg)
|
||||||
|
with block("Assertion: " + msg):
|
||||||
|
reg.u32('asrt_base asrt_off')
|
||||||
|
op.mov.u32(asrt_base, g_std_exit_err)
|
||||||
|
self.get_gtid(asrt_off)
|
||||||
|
op.mad.lo.u32(asrt_base, asrt_off, 4, asrt_base)
|
||||||
|
realp = None
|
||||||
|
if o:
|
||||||
|
realp = self.not_(reg.pred('p_asrt_fail'))
|
||||||
|
if a is None or b is None:
|
||||||
|
raise ValueError("Must specify ``a`` and ``b`` with ``o``.")
|
||||||
|
op._call(['setp.'+o], p_asrt_fail, a, b)
|
||||||
|
if p:
|
||||||
|
realp = self.not_(p)
|
||||||
|
if notp:
|
||||||
|
realp = notp
|
||||||
|
op.st.global_.u32(addr(asrt_base), err_code, ifp=realp)
|
||||||
|
if not ign:
|
||||||
|
if ret:
|
||||||
|
op.ret(ifp=realp)
|
||||||
|
else:
|
||||||
|
op.trap(ifp=realp)
|
||||||
|
|
||||||
|
@ptx_func
|
||||||
|
def entry_setup(self):
|
||||||
|
self.asrt("Unexpected thread exit", ign=True, lvl=0)
|
||||||
|
|
||||||
|
@ptx_func
|
||||||
|
def entry_teardown(self):
|
||||||
|
self.asrt(self.asserts[0], ret=True, lvl=0)
|
||||||
|
|
||||||
|
def call_teardown(self, ctx):
|
||||||
|
"""
|
||||||
|
This function raises an exception if all cleanup code wasn't called on
|
||||||
|
the device. To suppress this - for instance, to inspect data from a
|
||||||
|
partially-executed thread - do
|
||||||
|
|
||||||
|
>>> std.asrt(std.asserts[0], ign=True, lvl=0)
|
||||||
|
|
||||||
|
at the start of your entry. Yes, it's a hacky solution.
|
||||||
|
"""
|
||||||
|
dp, l = ctx.mod.get_global('g_std_exit_err')
|
||||||
|
errs = cuda.from_device(dp, ctx.threads, np.uint32)
|
||||||
|
if np.sum(errs) != 0:
|
||||||
|
print "Some threads terminated unsuccessfully."
|
||||||
|
for i, msg in enumerate(self.asserts):
|
||||||
|
count = sum(np.equal(errs, i))
|
||||||
|
if count:
|
||||||
|
print '%6d said "%s".' % (count, msg)
|
||||||
|
print
|
||||||
|
raise EnvironmentError("Abnormal thread termination")
|
||||||
|
|
||||||
def to_inject(self):
|
def to_inject(self):
|
||||||
# Set up the initial namespace
|
# Set up the initial namespace
|
||||||
return dict(
|
return dict(
|
||||||
|
Loading…
Reference in New Issue
Block a user