Use pycuda SourceModule to work around crashes, and a few invocation touchups.

This commit is contained in:
Steven Robertson 2010-09-10 18:02:37 -04:00
parent c3d12d07c2
commit 943e92b80c
2 changed files with 17 additions and 12 deletions

View File

@ -4,6 +4,7 @@ import pyglet.gl as gl
gl.get_current_context()
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
import pycuda.tools
import pycuda.gl as cudagl
import pycuda.gl.autoinit
@ -57,15 +58,19 @@ class LaunchContext(object):
def compile(self, verbose=False, **kwargs):
kwargs['ctx'] = self
self.ptx = PTXModule(self.entry_types, kwargs, self.build_tests)
# TODO: make this optional and let user choose path
with open('/tmp/cuburn.ptx', 'w') as f: f.write(self.ptx.source)
try:
self.mod = cuda.module_from_buffer(self.ptx.source)
# TODO: detect/customize arch, code; verbose setting;
# keep directory enable/disable via debug
self.mod = SourceModule(self.ptx.source, no_extern_c=True,
options=['--keep', '-v', '-G'])
except (cuda.CompileError, cuda.RuntimeError), e:
print "Aww, dang, compile error. Here's the source:"
self.ptx.print_source()
# TODO: if output not written above, print different message
print "Compile error. Source is at /tmp/cuburn.ptx"
print e
raise e
if verbose:
if verbose >= 3:
self.ptx.print_source()
for entry in self.ptx.entries:
func = self.mod.get_function(entry.entry_name)
print "Compiled %s: used %d regs, %d sm, %d local" % (
@ -83,7 +88,9 @@ class LaunchContext(object):
try:
inst.call_teardown(self)
except PTXTestFailure, e:
print "PTX Test %s failed!" % inst.entry_name, e
print "\nTest %s FAILED!" % inst.entry_name
print "Reason:", e
print
okay = False
else:
inst.call_teardown(self)
@ -99,7 +106,6 @@ class LaunchContext(object):
if test.call(self):
print "Test %s passed." % test.entry_name
else:
print "Test %s FAILED." % test.entry_name
all_okay = False
all_okay = False
return all_okay

View File

@ -628,7 +628,7 @@ class PTXEntryPoint(PTXFragment):
Override this if you need to change how a function is called.
"""
# TODO: global debugging / verbosity
print "Invoking PTX function '%s' on device" % self.entry_name
print "\nInvoking PTX function '%s' on device" % self.entry_name
kwargs.setdefault('block', ctx.block)
kwargs.setdefault('grid', ctx.grid)
dtime = func(time_kernel=True, *args, **kwargs)
@ -653,8 +653,7 @@ class PTXTest(PTXEntryPoint):
* The active context will be synchronized before each call,
* call_teardown() should raise ``PTXTestFailure`` if a test failed.
This exception will be caught and cleanup will be completed
(unless another exception is raised).
This exception will be caught and cleanup will be completed.
"""
pass
@ -673,7 +672,7 @@ class _PTXStdLib(PTXFragment):
# TODO: make this modular, maybe? of course, we'd have to support
# multiple devices first, which we definitely do not yet do
self.block.code(prefix='.version 2.1', semi=False)
self.block.code(prefix='.target sm_20', semi=False)
self.block.code(prefix='.target sm_21', semi=False)
@ptx_func
def get_gtid(self, dst):