Use pycuda SourceModule to work around crashes, and a few invocation touchups.

This commit is contained in:
Steven Robertson 2010-09-10 18:02:37 -04:00
parent c3d12d07c2
commit 943e92b80c
2 changed files with 17 additions and 12 deletions

View File

@@ -4,6 +4,7 @@ import pyglet.gl as gl
gl.get_current_context() gl.get_current_context()
import pycuda.driver as cuda import pycuda.driver as cuda
from pycuda.compiler import SourceModule
import pycuda.tools import pycuda.tools
import pycuda.gl as cudagl import pycuda.gl as cudagl
import pycuda.gl.autoinit import pycuda.gl.autoinit
@@ -57,15 +58,19 @@ class LaunchContext(object):
def compile(self, verbose=False, **kwargs): def compile(self, verbose=False, **kwargs):
kwargs['ctx'] = self kwargs['ctx'] = self
self.ptx = PTXModule(self.entry_types, kwargs, self.build_tests) self.ptx = PTXModule(self.entry_types, kwargs, self.build_tests)
# TODO: make this optional and let user choose path
with open('/tmp/cuburn.ptx', 'w') as f: f.write(self.ptx.source)
try: try:
self.mod = cuda.module_from_buffer(self.ptx.source) # TODO: detect/customize arch, code; verbose setting;
# keep directory enable/disable via debug
self.mod = SourceModule(self.ptx.source, no_extern_c=True,
options=['--keep', '-v', '-G'])
except (cuda.CompileError, cuda.RuntimeError), e: except (cuda.CompileError, cuda.RuntimeError), e:
print "Aww, dang, compile error. Here's the source:" # TODO: if output not written above, print different message
self.ptx.print_source() print "Compile error. Source is at /tmp/cuburn.ptx"
print e
raise e raise e
if verbose: if verbose:
if verbose >= 3:
self.ptx.print_source()
for entry in self.ptx.entries: for entry in self.ptx.entries:
func = self.mod.get_function(entry.entry_name) func = self.mod.get_function(entry.entry_name)
print "Compiled %s: used %d regs, %d sm, %d local" % ( print "Compiled %s: used %d regs, %d sm, %d local" % (
@@ -83,7 +88,9 @@ class LaunchContext(object):
try: try:
inst.call_teardown(self) inst.call_teardown(self)
except PTXTestFailure, e: except PTXTestFailure, e:
print "PTX Test %s failed!" % inst.entry_name, e print "\nTest %s FAILED!" % inst.entry_name
print "Reason:", e
print
okay = False okay = False
else: else:
inst.call_teardown(self) inst.call_teardown(self)
@@ -99,7 +106,6 @@ class LaunchContext(object):
if test.call(self): if test.call(self):
print "Test %s passed." % test.entry_name print "Test %s passed." % test.entry_name
else: else:
print "Test %s FAILED." % test.entry_name
all_okay = False all_okay = False
return all_okay return all_okay

View File

@@ -628,7 +628,7 @@ class PTXEntryPoint(PTXFragment):
Override this if you need to change how a function is called. Override this if you need to change how a function is called.
""" """
# TODO: global debugging / verbosity # TODO: global debugging / verbosity
print "Invoking PTX function '%s' on device" % self.entry_name print "\nInvoking PTX function '%s' on device" % self.entry_name
kwargs.setdefault('block', ctx.block) kwargs.setdefault('block', ctx.block)
kwargs.setdefault('grid', ctx.grid) kwargs.setdefault('grid', ctx.grid)
dtime = func(time_kernel=True, *args, **kwargs) dtime = func(time_kernel=True, *args, **kwargs)
@@ -653,8 +653,7 @@ class PTXTest(PTXEntryPoint):
* The active context will be synchronized before each call, * The active context will be synchronized before each call,
* call_teardown() should raise ``PTXTestFailure`` if a test failed. * call_teardown() should raise ``PTXTestFailure`` if a test failed.
This exception will be caught and cleanup will be completed This exception will be caught and cleanup will be completed.
(unless another exception is raised).
""" """
pass pass
@@ -673,7 +672,7 @@ class _PTXStdLib(PTXFragment):
# TODO: make this modular, maybe? of course, we'd have to support # TODO: make this modular, maybe? of course, we'd have to support
# multiple devices first, which we definitely do not yet do # multiple devices first, which we definitely do not yet do
self.block.code(prefix='.version 2.1', semi=False) self.block.code(prefix='.version 2.1', semi=False)
self.block.code(prefix='.target sm_20', semi=False) self.block.code(prefix='.target sm_21', semi=False)
@ptx_func @ptx_func
def get_gtid(self, dst): def get_gtid(self, dst):