Add device assertions to standard library.

This commit is contained in:
Steven Robertson 2010-09-11 00:12:02 -04:00
parent 3932412539
commit 56404b629f
2 changed files with 114 additions and 7 deletions

View File

@ -579,22 +579,22 @@ class MWCRNGFloatsTest(PTXTest):
@ptx_func
def loop(self, kind):
with block('Sum %d floats in %s' % (self.rounds, kind)):
reg.f32('loopct val sum rmin rmax')
reg.f32('loopct val rsum rmin rmax')
reg.pred('p_done')
op.mov.f32(loopct, 0.)
op.mov.f32(sum, 0.)
op.mov.f32(rsum, 0.)
op.mov.f32(rmin, 2.)
op.mov.f32(rmax, -2.)
label('loopstart' + kind)
getattr(mwc, 'next_f32_' + kind)(val)
op.add.f32(sum, sum, val)
op.add.f32(rsum, rsum, val)
op.min.f32(rmin, rmin, val)
op.max.f32(rmax, rmax, val)
op.add.f32(loopct, loopct, 1.)
op.setp.ge.f32(p_done, loopct, float(self.rounds))
op.bra('loopstart' + kind, ifnotp=p_done)
op.mul.f32(sum, sum, 1./self.rounds)
std.store_per_thread('mwc_rng_float_%s_test_sums' % kind, sum,
op.mul.f32(rsum, rsum, 1./self.rounds)
std.store_per_thread('mwc_rng_float_%s_test_sums' % kind, rsum,
'mwc_rng_float_%s_test_mins' % kind, rmin,
'mwc_rng_float_%s_test_maxs' % kind, rmax)

View File

@ -16,6 +16,9 @@ from cStringIO import StringIO
from collections import namedtuple
from math import *
import numpy as np
import pycuda.driver as cuda
# Okay, so here's what's going on.
#
# We're using Python to create PTX. If we just use Python to make one giant PTX
@ -642,8 +645,10 @@ class PTXEntryPoint(PTXFragment):
"""
ctx.call_setup(self)
func = ctx.mod.get_function(self.entry_name)
self._call(ctx, func, *args, **kwargs)
return ctx.call_teardown(self)
try:
self._call(ctx, func, *args, **kwargs)
finally:
return ctx.call_teardown(self)
class PTXTestFailure(Exception): pass
@ -663,6 +668,7 @@ class _PTXStdLib(PTXFragment):
def __init__(self, block):
# Only module that gets the privilege of seeing 'block' directly.
self.block = block
self.asserts = ["Success"]
def deps(self):
return []
@ -673,6 +679,7 @@ class _PTXStdLib(PTXFragment):
# multiple devices first, which we definitely do not yet do
self.block.code(prefix='.version 2.1', semi=False)
self.block.code(prefix='.target sm_21', semi=False)
mem.global_.u32('g_std_exit_err', ctx.threads)
@ptx_func
def get_gtid(self, dst):
@ -716,6 +723,106 @@ class _PTXStdLib(PTXFragment):
def not_(self, pred):
return ['!', pred]
@ptx_func
def asrt(self, msg, o=None, a=None, b=None, p=None, notp=None,
ret=False, ign=False, lvl=1):
"""
Device assertion.
Without arguments, a thread will log the error code associated with
``msg`` and issue a trap instruction, which will cause the device to
terminate execution in all threads immediately. Any of the options
below modify that behavior, as described.
``o``, ``a`` and ``b``, when set together, will be used to create a
``setp`` instruction to test a condition. They're the first three
arguments, to make usage a bit more natural:
>>> std.asrt('lt.u32', val, 0)
This would generate the instruction ``setp.lt.u32 <p>, val, 0;``
(<p> is created by this function). The thread would only store the
error code and exit if the condition were *false*.
``p`` is a predicate value; the store and trap will happen if it is
*not* set (same sense as ``o`` and Python's assert). ``notp`` is the
reverse.
Only one of ``o``, ``ifp``, or ``ifnotp`` can be set per call.
``ret`` causes the assert to issue a ``ret;`` instruction in place of
the trap. This causes the current thread to terminate, but does not
cause the other threads to do so. Be cautious, as barriers can cause a
kernel to hang using this instruction.
``ign`` causes the error code to be stored, but does not terminate
thread execution ("ignores" the error). This is useful to identify the
location of all threads in case of an abnormal termination caused by
another thread, and is used to set up the entry-wide "early
termination" error. ``ign`` overrides ``ret``.
This code calculates the gtid unconditionally, and so can be relatively
expensive to insert into a tight loop. As a result, assert
statements will only be added if the debug value ``assert_level`` is
at least as large as the ``lvl`` argument.
"""
# TODO: debug level checking
if np.sum(map(bool, (o, p, notp))) > 1:
raise ValueError("Can only use one of o, ifp, ifnotp.")
if msg not in self.asserts:
self.asserts.append(msg)
err_code = self.asserts.index(msg)
with block("Assertion: " + msg):
reg.u32('asrt_base asrt_off')
op.mov.u32(asrt_base, g_std_exit_err)
self.get_gtid(asrt_off)
op.mad.lo.u32(asrt_base, asrt_off, 4, asrt_base)
realp = None
if o:
realp = self.not_(reg.pred('p_asrt_fail'))
if a is None or b is None:
raise ValueError("Must specify ``a`` and ``b`` with ``o``.")
op._call(['setp.'+o], p_asrt_fail, a, b)
if p:
realp = self.not_(p)
if notp:
realp = notp
op.st.global_.u32(addr(asrt_base), err_code, ifp=realp)
if not ign:
if ret:
op.ret(ifp=realp)
else:
op.trap(ifp=realp)
@ptx_func
def entry_setup(self):
self.asrt("Unexpected thread exit", ign=True, lvl=0)
@ptx_func
def entry_teardown(self):
self.asrt(self.asserts[0], ret=True, lvl=0)
def call_teardown(self, ctx):
"""
This function raises an exception if all cleanup code wasn't called on
the device. To suppress this - for instance, to inspect data from a
partially-executed thread - do
>>> std.asrt(std.asserts[0], ign=True, lvl=0)
at the start of your entry. Yes, it's a hacky solution.
"""
dp, l = ctx.mod.get_global('g_std_exit_err')
errs = cuda.from_device(dp, ctx.threads, np.uint32)
if np.sum(errs) != 0:
print "Some threads terminated unsuccessfully."
for i, msg in enumerate(self.asserts):
count = sum(np.equal(errs, i))
if count:
print '%6d said "%s".' % (count, msg)
print
raise EnvironmentError("Abnormal thread termination")
def to_inject(self):
# Set up the initial namespace
return dict(