mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 11:40:04 -05:00
Add device assertions to standard library.
This commit is contained in:
parent
3932412539
commit
56404b629f
@ -579,22 +579,22 @@ class MWCRNGFloatsTest(PTXTest):
|
||||
@ptx_func
|
||||
def loop(self, kind):
|
||||
with block('Sum %d floats in %s' % (self.rounds, kind)):
|
||||
reg.f32('loopct val sum rmin rmax')
|
||||
reg.f32('loopct val rsum rmin rmax')
|
||||
reg.pred('p_done')
|
||||
op.mov.f32(loopct, 0.)
|
||||
op.mov.f32(sum, 0.)
|
||||
op.mov.f32(rsum, 0.)
|
||||
op.mov.f32(rmin, 2.)
|
||||
op.mov.f32(rmax, -2.)
|
||||
label('loopstart' + kind)
|
||||
getattr(mwc, 'next_f32_' + kind)(val)
|
||||
op.add.f32(sum, sum, val)
|
||||
op.add.f32(rsum, rsum, val)
|
||||
op.min.f32(rmin, rmin, val)
|
||||
op.max.f32(rmax, rmax, val)
|
||||
op.add.f32(loopct, loopct, 1.)
|
||||
op.setp.ge.f32(p_done, loopct, float(self.rounds))
|
||||
op.bra('loopstart' + kind, ifnotp=p_done)
|
||||
op.mul.f32(sum, sum, 1./self.rounds)
|
||||
std.store_per_thread('mwc_rng_float_%s_test_sums' % kind, sum,
|
||||
op.mul.f32(rsum, rsum, 1./self.rounds)
|
||||
std.store_per_thread('mwc_rng_float_%s_test_sums' % kind, rsum,
|
||||
'mwc_rng_float_%s_test_mins' % kind, rmin,
|
||||
'mwc_rng_float_%s_test_maxs' % kind, rmax)
|
||||
|
||||
|
111
cuburn/ptx.py
111
cuburn/ptx.py
@ -16,6 +16,9 @@ from cStringIO import StringIO
|
||||
from collections import namedtuple
|
||||
from math import *
|
||||
|
||||
import numpy as np
|
||||
import pycuda.driver as cuda
|
||||
|
||||
# Okay, so here's what's going on.
|
||||
#
|
||||
# We're using Python to create PTX. If we just use Python to make one giant PTX
|
||||
@ -642,8 +645,10 @@ class PTXEntryPoint(PTXFragment):
|
||||
"""
|
||||
ctx.call_setup(self)
|
||||
func = ctx.mod.get_function(self.entry_name)
|
||||
self._call(ctx, func, *args, **kwargs)
|
||||
return ctx.call_teardown(self)
|
||||
try:
|
||||
self._call(ctx, func, *args, **kwargs)
|
||||
finally:
|
||||
return ctx.call_teardown(self)
|
||||
|
||||
class PTXTestFailure(Exception): pass
|
||||
|
||||
@ -663,6 +668,7 @@ class _PTXStdLib(PTXFragment):
|
||||
def __init__(self, block):
|
||||
# Only module that gets the privilege of seeing 'block' directly.
|
||||
self.block = block
|
||||
self.asserts = ["Success"]
|
||||
|
||||
def deps(self):
|
||||
return []
|
||||
@ -673,6 +679,7 @@ class _PTXStdLib(PTXFragment):
|
||||
# multiple devices first, which we definitely do not yet do
|
||||
self.block.code(prefix='.version 2.1', semi=False)
|
||||
self.block.code(prefix='.target sm_21', semi=False)
|
||||
mem.global_.u32('g_std_exit_err', ctx.threads)
|
||||
|
||||
@ptx_func
|
||||
def get_gtid(self, dst):
|
||||
@ -716,6 +723,106 @@ class _PTXStdLib(PTXFragment):
|
||||
def not_(self, pred):
|
||||
return ['!', pred]
|
||||
|
||||
@ptx_func
|
||||
def asrt(self, msg, o=None, a=None, b=None, p=None, notp=None,
|
||||
ret=False, ign=False, lvl=1):
|
||||
"""
|
||||
Device assertion.
|
||||
|
||||
Without arguments, a thread will log the error code associated with
|
||||
``msg`` and issue a trap instruction, which will cause the device to
|
||||
terminate execution in all threads immediately. Any of the options
|
||||
below modify that behavior, as described.
|
||||
|
||||
``o``, ``a`` and ``b``, when set together, will be used to create a
|
||||
``setp`` instruction to test a condition. They're the first three
|
||||
arguments, to make usage a bit more natural:
|
||||
|
||||
>>> std.asrt('lt.u32', val, 0)
|
||||
|
||||
This would generate the instruction ``setp.lt.u32 <p>, val, 0;``
|
||||
(<p> is created by this function). The thread would only store the
|
||||
error code and exit if the condition were *false*.
|
||||
|
||||
``p`` is a predicate value; the store and trap will happen if it is
|
||||
*not* set (same sense as ``o`` and Python's assert). ``notp`` is the
|
||||
reverse.
|
||||
|
||||
Only one of ``o``, ``ifp``, or ``ifnotp`` can be set per call.
|
||||
|
||||
``ret`` causes the assert to issue a ``ret;`` instruction in place of
|
||||
the trap. This causes the current thread to terminate, but does not
|
||||
cause the other threads to do so. Be cautious, as barriers can cause a
|
||||
kernel to hang using this instruction.
|
||||
|
||||
``ign`` causes the error code to be stored, but does not terminate
|
||||
thread execution ("ignores" the error). This is useful to identify the
|
||||
location of all threads in case of an abnormal termination caused by
|
||||
another thread, and is used to set up the entry-wide "early
|
||||
termination" error. ``ign`` overrides ``ret``.
|
||||
|
||||
This code calculates the gtid unconditionally, and so can be relatively
|
||||
expensive to insert into a tight loop. As a result, assert
|
||||
statements will only be added if the debug value ``assert_level`` is
|
||||
at least as large as the ``lvl`` argument.
|
||||
"""
|
||||
# TODO: debug level checking
|
||||
if np.sum(map(bool, (o, p, notp))) > 1:
|
||||
raise ValueError("Can only use one of o, ifp, ifnotp.")
|
||||
if msg not in self.asserts:
|
||||
self.asserts.append(msg)
|
||||
err_code = self.asserts.index(msg)
|
||||
with block("Assertion: " + msg):
|
||||
reg.u32('asrt_base asrt_off')
|
||||
op.mov.u32(asrt_base, g_std_exit_err)
|
||||
self.get_gtid(asrt_off)
|
||||
op.mad.lo.u32(asrt_base, asrt_off, 4, asrt_base)
|
||||
realp = None
|
||||
if o:
|
||||
realp = self.not_(reg.pred('p_asrt_fail'))
|
||||
if a is None or b is None:
|
||||
raise ValueError("Must specify ``a`` and ``b`` with ``o``.")
|
||||
op._call(['setp.'+o], p_asrt_fail, a, b)
|
||||
if p:
|
||||
realp = self.not_(p)
|
||||
if notp:
|
||||
realp = notp
|
||||
op.st.global_.u32(addr(asrt_base), err_code, ifp=realp)
|
||||
if not ign:
|
||||
if ret:
|
||||
op.ret(ifp=realp)
|
||||
else:
|
||||
op.trap(ifp=realp)
|
||||
|
||||
@ptx_func
|
||||
def entry_setup(self):
|
||||
self.asrt("Unexpected thread exit", ign=True, lvl=0)
|
||||
|
||||
@ptx_func
|
||||
def entry_teardown(self):
|
||||
self.asrt(self.asserts[0], ret=True, lvl=0)
|
||||
|
||||
def call_teardown(self, ctx):
|
||||
"""
|
||||
This function raises an exception if all cleanup code wasn't called on
|
||||
the device. To suppress this - for instance, to inspect data from a
|
||||
partially-executed thread - do
|
||||
|
||||
>>> std.asrt(std.asserts[0], ign=True, lvl=0)
|
||||
|
||||
at the start of your entry. Yes, it's a hacky solution.
|
||||
"""
|
||||
dp, l = ctx.mod.get_global('g_std_exit_err')
|
||||
errs = cuda.from_device(dp, ctx.threads, np.uint32)
|
||||
if np.sum(errs) != 0:
|
||||
print "Some threads terminated unsuccessfully."
|
||||
for i, msg in enumerate(self.asserts):
|
||||
count = sum(np.equal(errs, i))
|
||||
if count:
|
||||
print '%6d said "%s".' % (count, msg)
|
||||
print
|
||||
raise EnvironmentError("Abnormal thread termination")
|
||||
|
||||
def to_inject(self):
|
||||
# Set up the initial namespace
|
||||
return dict(
|
||||
|
Loading…
Reference in New Issue
Block a user