PTX DSL working, at least well enough to pass MWCRNGTest

This commit is contained in:
Steven Robertson 2010-09-01 21:09:40 -04:00
parent 5f8c2bbf08
commit a3660ec6e4
4 changed files with 385 additions and 228 deletions

View File

@ -10,7 +10,7 @@ import pycuda.gl.autoinit
import numpy as np
from cuburnlib.ptx import PTXAssembler
from cuburnlib.ptx import PTXModule
class LaunchContext(object):
"""
@ -44,8 +44,10 @@ class LaunchContext(object):
def threads(self):
return reduce(lambda a, b: a*b, self.block + self.grid)
def compile(self, verbose=False):
self.ptx = PTXAssembler(self, self.entry_types, self.build_tests)
def compile(self, to_inject={}, verbose=False):
inj = dict(to_inject)
inj['ctx'] = self
self.ptx = PTXModule(self.entry_types, inj, self.build_tests)
try:
self.mod = cuda.module_from_buffer(self.ptx.source)
except (cuda.CompileError, cuda.RuntimeError), e:
@ -54,15 +56,16 @@ class LaunchContext(object):
enumerate(self.ptx.source.split('\n'))])
raise e
if verbose:
for name in self.ptx.entry_names.values():
func = self.mod.get_function(name)
print "Compiled %s: used %d regs, %d sm, %d local" % (func,
func.num_regs, func.shared_size_bytes, func.local_size_bytes)
for entry in self.ptx.entries:
func = self.mod.get_function(entry.entry_name)
print "Compiled %s: used %d regs, %d sm, %d local" % (
entry.entry_name, func.num_regs,
func.shared_size_bytes, func.local_size_bytes)
def set_up(self):
for inst in self.ptx.deporder(self.ptx.instances.values(),
self.ptx.instances, self):
inst.set_up(self)
self.ptx.instances):
inst.device_init(self)
def run(self):
if not self.setup_done: self.set_up()

View File

@ -8,7 +8,7 @@ import time
import pycuda.driver as cuda
import numpy as np
from cuburnlib.ptx import PTXFragment, PTXEntryPoint, PTXTest
from cuburnlib.ptx import *
"""
Here's the current draft of the full algorithm implementation.
@ -113,10 +113,12 @@ class MWCRNG(PTXFragment):
if not os.path.isfile('primes.bin'):
raise EnvironmentError('primes.bin not found')
@ptx_func
def module_setup(self):
mem.global_.u32('mwc_rng_mults', ctx.threads)
mem.global_.u32('mwc_rng_state', ctx.threads)
mem.global_.u64('mwc_rng_state', ctx.threads)
@ptx_func
def entry_setup(self):
reg.u32('mwc_st mwc_mult mwc_car')
with block('Load MWC multipliers and states'):
@ -130,6 +132,7 @@ class MWCRNG(PTXFragment):
op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr)
op.ld.global_.v2.u32(vec(mwc_st, mwc_car), addr(mwc_addr))
@ptx_func
def entry_teardown(self):
with block('Save MWC states'):
reg.u32('mwc_off mwc_addr')
@ -138,15 +141,19 @@ class MWCRNG(PTXFragment):
op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr)
op.st.global_.v2.u32(addr(mwc_addr), vec(mwc_st, mwc_car))
@ptx_func
def next_b32(self, dst_reg):
with block('Load next random into ' + dst_reg.name):
reg.u64('mwc_out')
op.cvt.u64.u32(mwc_out, mwc_car)
mad.wide.u32(mwc_out, mwc_st)
mov.b64(vec(mwc_st, mwc_car), mwc_out)
mov.u32(dst_reg, mwc_st)
op.mad.wide.u32(mwc_out, mwc_st, mwc_mult, mwc_out)
op.mov.b64(vec(mwc_st, mwc_car), mwc_out)
op.mov.u32(dst_reg, mwc_st)
def set_up(self, ctx):
def to_inject(self):
return dict(mwc_next_b32=self.next_b32)
def device_init(self, ctx):
if self.threads_ready >= ctx.threads:
# Already set up enough random states, don't push again
return
@ -168,21 +175,25 @@ class MWCRNG(PTXFragment):
states = np.array(ctx.rand.randint(1, 0xffffffff, size=2*ctx.threads),
dtype=np.uint32)
statedp, statel = ctx.mod.get_global('mwc_rng_state')
print states, len(states.tostring())
cuda.memcpy_htod_async(statedp, states.tostring())
self.threads_ready = ctx.threads
def tests(self, ctx):
def tests(self):
return [MWCRNGTest]
class MWCRNGTest(PTXTest):
name = "MWC RNG sum-of-threads"
deps = [MWCRNG]
rounds = 10000
entry_name = 'MWC_RNG_test'
entry_params = ''
def deps(self):
return [MWCRNG]
@ptx_func
def module_setup(self):
mem.global_.u64(mwc_rng_test_sums, ctx.threads)
mem.global_.u64('mwc_rng_test_sums', ctx.threads)
@ptx_func
def entry(self):
@ -191,7 +202,7 @@ class MWCRNGTest(PTXTest):
op.mov.u64(sum, 0)
with block('Sum next %d random numbers' % self.rounds):
reg.u32('loopct')
pred('p')
reg.pred('p')
op.mov.u32(loopct, self.rounds)
label('loopstart')
mwc_next_b32(addend)
@ -206,7 +217,7 @@ class MWCRNGTest(PTXTest):
get_gtid(offset)
op.mov.u32(adr, mwc_rng_test_sums)
op.mad.lo.u32(adr, offset, 8, adr)
st.global_.u64(addr(adr), sum)
op.st.global_.u64(addr(adr), sum)
def call(self, ctx):
# Get current multipliers and seeds from the device

View File

@ -10,7 +10,8 @@ easier to maintain using this system.
# If you see 'import inspect', you know you're in for a good time
import inspect
import ctypes
import types
import traceback
from collections import namedtuple
# Okay, so here's what's going on.
@ -23,7 +24,7 @@ from collections import namedtuple
# splitting things up at the level of PTX will greatly reduce performance, as
# the cost of accessing the stack, spilling registers, and reloading data from
# system memory is unacceptably high even on Fermi GPUs. So we want to split
# code up into functions within Python, but not within the PTX.
# code up into functions within Python, but not within the PTX source.
#
# The challenge here is variable lifetime. A PTX function might declare a
# register at the top of the main block and use it several times throughout the
@ -50,10 +51,10 @@ from collections import namedtuple
# reg.u32('hooray_reg')
# load_zero(hooray_reg)
#
# But using blocks to track state, it would turn in to this ugliness::
# But using blocks alone to track names, it would turn into this ugliness::
#
# def load_zero(block, dest_reg):
# block.op.mov.u32(op.dest_reg, 0)
# block.op.mov.u32(block.op.dest_reg, 0)
# def init_module():
# with Block() as block:
# block.regs.hooray_reg = block.reg.u32('hooray_reg')
@ -70,6 +71,9 @@ from collections import namedtuple
# below give a clear picture of how to use it, but now you know why this
# abomination was crafted to begin with.
def _softjoin(args, sep):
    """
    Return a new list holding the items of 'args' with 'sep' placed between
    each adjacent pair. Items are stored as-is, not coerced to strings.
    """
    pieces = []
    for item in args:
        pieces.extend((item, sep))
    # Drop the separator trailing the final item (a no-op for empty input)
    return pieces[:-1]
BlockCtx = namedtuple('BlockCtx', 'locals code injectors')
PTXStmt = namedtuple('PTXStmt', 'prefix op vars semi indent')
@ -100,6 +104,13 @@ class _BlockInjector(object):
else:
self.inject_into[k] = v
self.injected.add(k)
def pop(self, keys):
    """
    Take each name in 'keys' back out of the target dictionary, skipping any
    name that is not recorded as having been injected by this injector.
    """
    assert not self.dead
    for name in keys:
        if name not in self.injected:
            # Not ours; leave the target dictionary's entry alone
            continue
        del self.inject_into[name]
        self.injected.remove(name)
def __enter__(self):
self.dead = False
map(self.inject, self.to_inject.items())
@ -115,22 +126,47 @@ class _Block(object):
For important reasons, the instance must be bound locally as "_block".
"""
name = '_block'
name = '_block' # For retrieving from parent scope on first call
def __init__(self):
self.reset()
def reset(self):
self.outer_ctx = BlockCtx({self.name: self}, [], [])
self.stack = [self.outer_ctx]
def clean_injectors(self):
    """
    Drop dead injectors from the context on top of the stack.

    The list is filtered in place through slice assignment. Calling remove()
    on a list while looping over it skips the element following each removed
    one, which would leave some dead injectors behind.
    """
    inj = self.stack[-1].injectors
    inj[:] = [i for i in inj if not i.dead]
def push_ctx(self):
self.stack.append(BlockCtx(dict(self.stack[-1].locals), [], []))
# Move most recent active injector to new context
self.clean_injectors()
last_inj = self.stack[-1].injectors.pop()
self.stack.append(BlockCtx(dict(self.stack[-1].locals), [],
[last_inj]))
def pop_ctx(self):
self.clean_injectors()
bs = self.stack.pop()
self.stack[-1].code.append(bs.code)
self.stack[-1].code.extend(bs.code)
if len(self.stack) == 1:
# We're on outer_ctx, so all injectors should be gone
assert len(bs.injectors) == 0, "Injector/context mismatch"
return
# The only injector should be the one added in push_ctx
assert len(bs.injectors) == 1, "Injector/context mismatch"
# Find out which keys were injected while in this context
diff = set(bs.locals.keys()).difference(
set(self.stack[-1].locals.keys()))
# Pop keys and move current injector back down to last context
last_inj = bs.injectors.pop()
last_inj.pop(diff)
self.stack[-1].injectors.append(last_inj)
def injector(self, func_globals):
inj = BlockInjector(self.stack[-1].locals, func_globals)
inj = _BlockInjector(dict(self.stack[-1].locals), func_globals)
self.stack[-1].injectors.append(inj)
return inj
def inject(self, name, object):
if name in self.stack[-1].locals:
raise KeyError("Duplicate name already exists in this scope.")
if self.stack[-1].locals[name] is not object:
raise KeyError("'%s' already exists in this scope." % name)
else:
self.stack[-1].locals[name] = object
[inj.inject(name, object) for inj in self.stack[-1].injectors]
def code(self, prefix='', op='', vars=[], semi=True, indent=0):
@ -157,7 +193,7 @@ class _Block(object):
yes, the only real difference between `prefix`, `op`, and `vars` is in
final appearance, but it is in fact quite helpful for debugging.
"""
self.stack[-1].append(PTXStmt(prefix, op, vars, indent))
self.stack[-1].code.append(PTXStmt(prefix, op, vars, semi, indent))
class StrVar(object):
"""
@ -168,28 +204,50 @@ class StrVar(object):
def __str__(self):
return str(val)
class _PTXFuncWrapper(object):
    """
    Enables ptx_func. Calling an instance locates the active `_block`, then
    runs the wrapped function against a private copy of its globals while a
    block injector is active on that copy.
    """
    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        if _Block.name in globals():
            # Read back the same key that was just tested; fetching a
            # different key here could raise KeyError.
            block = globals()[_Block.name]
        else:
            # Find the '_block' from the enclosing scope. The index into the
            # stack is fixed, so this lookup must stay directly in __call__.
            parent = inspect.stack()[2][0]
            if _Block.name in parent.f_locals:
                block = parent.f_locals[_Block.name]
            elif _Block.name in parent.f_globals:
                block = parent.f_globals[_Block.name]
            else:
                # Couldn't find the _block instance. Fail cryptically to
                # encourage users to read the source (for now)
                raise SyntaxError("Black magic")

        # Create a new function with the modified scope and call it. We could
        # do this in __init__, but it would hide any changes to globals from
        # the module's original scope. Still an option if performance sucks.
        newglobals = dict(self.func.func_globals)
        func = types.FunctionType(self.func.func_code, newglobals,
                                  self.func.func_name,
                                  self.func.func_defaults,
                                  self.func.func_closure)

        # TODO: if we generate a new dict every time, we can kill the
        # _BlockInjector and move BI.inject() back to _Block, but I don't want
        # to delete working code just yet
        with block.injector(func.func_globals):
            # The wrapped function's return value is discarded
            func(*args, **kwargs)
def ptx_func(func):
"""
Decorator function for code in the DSL. Any function which accesses the DSL
namespace, including declared device variables and objects such as "reg"
or "op", should be wrapped with this. See Block for some examples.
Note that writes to global variables will silently fail for now.
"""
def ptx_eval(*args, **kwargs):
if self.name not in globals():
parent = inspect.stack()[-2][0]
if self.name in parent.f_locals:
block = parent.f_locals[self.name]
elif self.name in parent.f_globals:
block = parent.f_globals[self.name]
else:
# Couldn't find the _block instance. Fail cryptically to
# encourage users to read the source (for now)
raise SyntaxError("Black magic")
else:
block = globals()['block']
with block.injector(func.func_globals):
func(*args, **kwargs)
return ptx_eval
# Attach most of the code to the wrapper class
fw = _PTXFuncWrapper(func)
def wr(*args, **kwargs):
fw(*args, **kwargs)
return wr
class Block(object):
"""
@ -231,14 +289,17 @@ class Block(object):
# `block` is the real _block
self.block = block
self.comment = None
def __call__(self, comment=None)
def __call__(self, comment=None):
self.comment = comment
return self
def __enter__(self):
self.block.push_ctx()
self.block.code(op='{', indent=4)
self.block.code(op='{', indent=1, semi=False)
if self.comment:
self.block.code(op=['// ', self.comment], semi=False)
self.comment = None
def __exit__(self, exc_type, exc_value, tb):
self.block.code(op='}', indent=-4)
self.block.code(op='}', indent=-1, semi=False)
self.block.pop_ctx()
class _CallChain(object):
@ -248,12 +309,12 @@ class _CallChain(object):
self.__chain = []
def __call__(self, *args, **kwargs):
assert(self.__chain)
self._call(chain, *args, **kwargs)
self._call(self.__chain, *args, **kwargs)
self.__chain = []
def __getattr__(self, name):
if name == 'global_':
name = 'global'
self.chain.append(name)
self.__chain.append(name)
# Another great crime against the universe:
return self
@ -284,7 +345,7 @@ class _RegFactory(_CallChain):
type = type[0]
names = names.split()
regs = map(lambda n: Reg(type, n), names)
self.block.code(op='.reg .' + type, vars=names)
self.block.code(op='.reg .' + type, vars=_softjoin(names, ', '))
[self.block.inject(r.name, r) for r in regs]
# Pending resolution of the op(regs, guard=x) debate
@ -318,8 +379,6 @@ class _RegFactory(_CallChain):
#self.name = name
#def is_set(self, isnot=False):
class Op(_CallChain):
"""
Performs an operation.
@ -340,15 +399,15 @@ class Op(_CallChain):
This constructor is available as 'op' in DSL blocks.
"""
def _call(self, op, *args, ifp=None, ifnotp=None):
def _call(self, op, *args, **kwargs):
pred = ''
if ifp:
if ifnotp:
if 'ifp' in kwargs:
if 'ifnotp' in kwargs:
raise SyntaxError("can't use both, fool")
pred = ['@', ifp]
if ifnotp:
pred = ['@!', ifnotp]
self.block.append_code(pred, '.'.join(op), map(str, args))
pred = ['@', kwargs['ifp']]
if 'ifnotp' in kwargs:
pred = ['@!', kwargs['ifnotp']]
self.block.code(pred, '.'.join(op), _softjoin(args, ', '))
class Mem(object):
"""
@ -381,7 +440,7 @@ class Mem(object):
>>> op.ld.global.v2.u32(vec(reg1, reg2), addr(areg))
"""
return ['{', [(a, ', ') for a in args][:-1], '}']
return ['{', _softjoin(args, ', '), '}']
@staticmethod
def addr(areg, aoffset=''):
@ -397,8 +456,7 @@ class _MemFactory(_CallChain):
"""Actual `mem` object"""
def _call(self, type, name, array=False, initializer=None):
assert len(type) == 2
memobj = Mem(type, name, array)
self.dsl.inject(name, memobj)
memobj = Mem(type, name, array, initializer)
if array is True:
array = ['[]']
elif array:
@ -407,11 +465,12 @@ class _MemFactory(_CallChain):
array = []
if initializer:
array += [' = ', initializer]
self.block.code(op=['.%s.%s ' % type, name, array])
self.block.code(op=['.%s.%s ' % (type[0], type[1]), name, array])
self.block.inject(name, memobj)
class Label(object):
"""
Specifies the target for a branch. Scoped in PTX? TODO: test.
Specifies the target for a branch. Scoped in PTX? TODO: test that it is.
>>> label('infinite_loop')
>>> op.bra.uni('label')
@ -426,25 +485,7 @@ class _LabelFactory(object):
self.block = block
def __call__(self, name):
self.block.inject(name, Label(name))
class PTXFragment(object):
def module_setup(self):
pass
def entry_setup(self):
pass
def entry_teardown(self):
pass
def globals(self):
pass
def tests(self):
pass
def device_init(self, ctx):
pass
self.block.code(prefix='%s:' % name, semi=False)
class PTXFragment(object):
"""
@ -470,12 +511,14 @@ class PTXFragment(object):
for successful compilation. Circular dependencies are forbidden,
but multi-level dependencies should be fine.
"""
return [DeviceHelpers]
return [_PTXStdLib]
def inject(self):
def to_inject(self):
"""
Returns a dict of items to add to the DSL namespace. The namespace will
be assembled in dependency order before any ptx_funcs are called.
This is only called once per PTXModule (== once per instance).
"""
return {}
@ -483,8 +526,8 @@ class PTXFragment(object):
"""
PTX function to declare things at module scope. It's a PTX syntax error
to perform operations at this scope, but we don't yet validate that at
the Python level. A module will call this function on all fragments in
dependency order.
the Python level. A module will call this function on all fragments
used in that module in dependency order.
If implemented, this function should use an @ptx_func decorator.
"""
@ -513,131 +556,47 @@ class PTXFragment(object):
"""
pass
def tests(self, ctx):
def finalize_code(self):
"""
Returns a list of PTXTest classes which will test this fragment.
Called after running all PTX DSL functions, but before code generation,
to allow fragments which postponed variable evaluation (e.g. using
`StrVar`) to fill in the resulting values. Most fragments should not
use this.
If implemented, this function *may* use an @ptx_func decorator to
access the global DSL scope, but pretty please don't emit any code
while you're in there.
"""
pass
def tests(self):
"""
Returns a list of PTXTest types which will test this fragment.
"""
return []
def set_up(self, ctx):
def device_init(self, ctx):
"""
Do start-of-stream initialization, such as copying data to the device.
Do stuff on the host to prepare the device for execution. 'ctx' is a
LaunchContext or similar. This will get called (in dependency order, of
course) *either* before any entry point invocation, or before *each*
invocation, I'm not sure which yet. (For now it's "each".)
"""
pass
class PTXModule(object):
"""
Assembles PTX fragments into a module.
"""
def __init__(self, entries, inject={}, build_tests=False):
self._block = b = _Block()
self.initial_inject = dict(inject)
self._safeupdate(self.initial_inject, dict(block=Block(b),
mem=_MemFactory(b), reg=_RegFactory(b), op=Op(b),
label=_LabelFactory(b), _block=b)
self.needs_recompilation = True
self.max_compiles = 10
while self.needs_recompilation:
self.assemble(entries, build_tests)
self.max_compiles -= 1
def deporder(self, unsorted_instances, instance_map, ctx):
"""
Do a DFS on PTXFragment dependencies, and return an ordered list of
instances where no fragment depends on any before it in the list.
`unsorted_instances` is the list of instances to sort.
`instance_map` is a dict of types to instances.
"""
seen = {}
def rec(inst):
if inst in seen: return seen[inst]
deps = filter(lambda d: d is not inst, map(instance_map.get,
callable(inst.deps) and inst.deps(self) or inst.deps))
return seen.setdefault(inst, 1+max([0]+map(rec, deps)))
map(rec, unsorted_instances)
return sorted(unsorted_instances, key=seen.get)
def _safeupdate(self, dst, src):
"""dst.update(src), but no duplicates allowed"""
non_uniq = [k for k in src if k in dst]
if non_uniq: raise KeyError("Duplicate keys: %s" % ','.join(key))
dst.update(src)
def assemble(self, entries, build_tests):
"""
Build the PTX source for the given set of entries.
"""
# Get a property, dealing with the callable-or-data thing. This is
# cumbersome, but flexible; when finished, it may be simplified.
def pget(prop):
if callable(prop): return prop(ctx)
return prop
instances = {}
unvisited_entries = list(entries)
entry_names = {}
tests = []
parsed_entries = []
while unvisited_entries:
ent = unvisited_entries.pop(0)
seen, unvisited = set(), [ent]
while unvisited:
frag = unvisited.pop(0)
seen.add(frag)
inst = instances.setdefault(frag, frag())
for dep in pget(inst.deps):
if dep not in seen:
unvisited.append(dep)
if build_tests:
for test in pget(inst.tests):
if test not in tests:
if test not in instances:
unvisited_entries.append(test)
tests.append(test)
tmpl_namespace = {'ctx': ctx}
entry_start, entry_end = [], []
for inst in self.deporder(map(instances.get, seen), instances, ctx):
self._safeupdate(tmpl_namespace, pget(inst.subs))
entry_start.append(pget(inst.entry_start))
entry_end.append(pget(inst.entry_end))
entry_start_tmpl = '\n'.join(filter(None, entry_start))
entry_end_tmpl = '\n'.join(filter(None, reversed(entry_end)))
name, args, body = pget(instances[ent].entry)
tmpl_namespace.update({'_entry_name_': name, '_entry_args_': args,
'_entry_body_': body, '_entry_start_': entry_start_tmpl,
'_entry_end_': entry_end_tmpl})
entry_tmpl = (".entry {{ _entry_name_ }} ({{ _entry_args_ }})\n"
"{\n{{_entry_start_}}\n{{_entry_body_}}\n{{_entry_end_}}\n}\n")
parsed_entries.append(multisub(entry_tmpl, tmpl_namespace))
entry_names[ent] = name
prelude = []
tmpl_namespace = {'ctx': ctx}
for inst in self.deporder(instances.values(), instances, ctx):
prelude.append(pget(inst.prelude))
self._safeupdate(tmpl_namespace, pget(inst.subs))
tmpl_namespace['_prelude_'] = '\n'.join(filter(None, prelude))
tmpl_namespace['_entries_'] = '\n\n'.join(parsed_entries)
tmpl = "{{ _prelude_ }}\n{{ _entries_ }}"
self.entry_names = entry_names
self.source = ppr_ptx(multisub(tmpl, tmpl_namespace))
self.instances = instances
self.tests = tests
class PTXEntryPoint(PTXFragment):
# Human-readable entry point name
name = ""
# Device code entry name
entry_name = ""
# List of (type, name) pairs for entry params, e.g. [('u32', 'thing')]
entry_params = []
def entry(self, ctx):
"""
Returns a 3-tuple of (name, args, body), which will be assembled into
a function.
PTX DSL function that comprises the body of the PTX statement.
Must be implemented and decorated with ptx_func.
"""
raise NotImplementedError
@ -660,32 +619,216 @@ class PTXTest(PTXEntryPoint):
"""
pass
class DeviceHelpers(PTXFragment):
def __init__(self):
self._forstack = []
class _PTXStdLib(PTXFragment):
def __init__(self, block):
# Only module that gets the privilege of seeing 'block' directly.
self.block = block
prelude = ".version 2.1\n.target sm_20\n\n"
def deps(self):
return []
@ptx_func
def module_setup(self):
# TODO: make this modular, maybe? of course, we'd have to support
# multiple devices first, which we definitely do not yet do
self.block.code(prefix='.version 2.1', semi=False)
self.block.code(prefix='.target sm_20', semi=False)
@ptx_func
def _get_gtid(self, dst):
return "{\n// Load GTID into " + dst + """
.reg .u16 tmp;
.reg .u32 cta, ncta, tid, gtid;
with block("Load GTID into %s" % str(dst)):
reg.u16('tmp')
reg.u32('cta ncta tid gtid')
mov.u16 tmp, %ctaid.x;
cvt.u32.u16 cta, tmp;
mov.u16 tmp, %ntid.x;
cvt.u32.u16 ncta, tmp;
mul.lo.u32 gtid, cta, ncta;
op.mov.u16(tmp, '%ctaid.x')
op.cvt.u32.u16(cta, tmp)
op.mov.u16(tmp, '%ntid.x')
op.cvt.u32.u16(ncta, tmp)
op.mul.lo.u32(gtid, cta, ncta)
mov.u16 tmp, %tid.x;
cvt.u32.u16 tid, tmp;
add.u32 gtid, gtid, tid;
mov.b32 """ + dst + ", gtid;\n}"
op.mov.u16(tmp, '%tid.x')
op.cvt.u32.u16(tid, tmp)
op.add.u32(gtid, gtid, tid)
op.mov.b32(dst, gtid)
def subs(self, ctx):
return {
'PTRT': ctypes.sizeof(ctypes.c_void_p) == 8 and '.u64' or '.u32',
'get_gtid': self._get_gtid
}
def to_inject(self):
return dict(
_block=self.block,
block=Block(self.block),
op=Op(self.block),
reg=_RegFactory(self.block),
mem=_MemFactory(self.block),
addr=Mem.addr,
vec=Mem.vec,
label=_LabelFactory(self.block),
get_gtid=self._get_gtid)
class PTXModule(object):
    """
    Assembles PTX fragments into a module. The following properties are
    available:

    `instances`:    Mapping of type to instance for the PTXFragments used in
                    the creation of this PTXModule.
    `entries`:      List of PTXEntry types in this module, including any tests.
    `tests`:        List of PTXTest types in this module.
    `source`:       PTX source code for this module.
    """
    # Upper bound on assembly passes; enforced by set_needs_recompilation()
    max_compiles = 10

    def __init__(self, entries, inject={}, build_tests=False, formatter=None):
        """
        Construct a PTXModule.

        `entries`:      List of PTXEntry types to include in this module.
        `inject`:       Dict of items to inject into the DSL namespace.
        `build_tests`:  If true, build tests into the module.
        `formatter`:    PTXFormatter instance, or None to use defaults.

        Raises KeyError if two sources try to inject the same name.
        """
        block = _Block()
        insts, tests, all_deps, entry_deps = (
                self.deptrace(block, entries, build_tests))
        self.instances = insts
        self.tests = tests

        # Work on a copy so neither the caller's dict nor the shared default
        # argument is ever modified.
        inject = dict(inject)
        self._safeupdate(inject, {'module': self})
        for inst in all_deps:
            self._safeupdate(inject, inst.to_inject())
        for name, obj in inject.items():
            block.inject(name, obj)

        # Assemble at least once, and again for as long as something calls
        # set_needs_recompilation() during a pass.
        self.__needs_recompilation = True
        self.compiles = 0
        while self.__needs_recompilation:
            self.compiles += 1
            self.__needs_recompilation = False
            self.assemble(block, all_deps, entry_deps)
        self.instances.pop(_PTXStdLib)

        if not formatter:
            formatter = PTXFormatter()
        self.source = formatter.format(block.outer_ctx.code)
        self.entries = list(set(entries + tests))

    def deporder(self, unsorted_instances, instance_map):
        """
        Do a DFS on PTXFragment dependencies, and return the instances sorted
        so that each one comes after every listed instance it depends on.

        `unsorted_instances` is the list of instances to sort.
        `instance_map` is a dict of types to instances.
        """
        seen = {}
        def rec(inst):
            # Rank is one more than the deepest dependency; a dependency with
            # no known instance (None) counts as rank 0.
            if inst in seen: return seen[inst]
            if inst is None: return 0
            deps = [instance_map.get(d) for d in inst.deps()]
            rank = 1 + max([0] + [rec(d) for d in deps if d is not inst])
            return seen.setdefault(inst, rank)
        for inst in unsorted_instances:
            rec(inst)
        return sorted(unsorted_instances, key=seen.get)

    def _safeupdate(self, dst, src):
        """
        dst.update(src), but no duplicates allowed. Raises KeyError naming
        the clashing keys, and leaves `dst` untouched in that case.
        """
        non_uniq = [k for k in src if k in dst]
        if non_uniq:
            raise KeyError("Duplicate keys: %s" % ','.join(map(str, non_uniq)))
        dst.update(src)

    def deptrace(self, block, entries, build_tests):
        """
        Instantiate every fragment reachable from `entries` and work out the
        order in which fragments must be processed.

        Returns a 4-tuple of: the type-to-instance mapping, the list of
        discovered test types (sorted by their str()), all instances in
        dependency order, and a dict mapping each entry type to its own
        dependency-ordered instance list.
        """
        instances = {_PTXStdLib: _PTXStdLib(block)}
        unvisited_entries = list(entries)
        tests = set()
        entry_deps = {}

        # For each PTXEntry or PTXTest, use a BFS to recursively find and
        # instantiate all fragments that are dependencies. If tests are
        # discovered, add those to the list of entries.
        while unvisited_entries:
            ent = unvisited_entries.pop(0)
            seen, unvisited = set(), [ent]
            while unvisited:
                frag = unvisited.pop(0)
                seen.add(frag)
                # setdefault doesn't work because of _PTXStdLib
                if frag not in instances:
                    inst = frag()
                    instances[frag] = inst
                else:
                    inst = instances[frag]
                for dep in inst.deps():
                    if dep not in seen:
                        unvisited.append(dep)
                if build_tests:
                    for test in inst.tests():
                        if test not in tests:
                            tests.add(test)
                            if test not in instances:
                                # Queue the newly found test as an entry
                                unvisited_entries.append(test)
            # For this entry, store insts of all dependencies in order.
            entry_deps[ent] = self.deporder([instances[f] for f in seen],
                                            instances)
        # Find the order for all dependencies in the program.
        all_deps = self.deporder(list(instances.values()), instances)

        return instances, sorted(tests, key=str), all_deps, entry_deps

    def assemble(self, block, all_deps, entry_deps):
        """
        Run one pass of the DSL functions of every fragment, emitting
        module-level setup followed by one `.entry` body per entry point.
        """
        # Rebind to local namespace to allow proper retrieval. The ptx_func
        # wrapper looks for this exact name in its caller's frame, so the
        # fragment calls below must be made directly from this method.
        _block = block

        for inst in all_deps:
            inst.module_setup()

        for ent, insts in entry_deps.items():
            # This is kind of hackish compared to everything else
            params = [Reg('.param.' + str(type), name)
                      for (type, name) in ent.entry_params]
            # Parameters are comma-separated inside the parentheses
            param_decls = ', '.join(['%s %s' % (r.type, r.name)
                                     for r in params])
            _block.code(op='.entry %s ' % ent.entry_name, semi=False,
                        vars=['(', param_decls, ')'])
            with Block(_block):
                for r in params:
                    _block.inject(r.name, r)
                for dep in insts:
                    dep.entry_setup()
                self.instances[ent].entry()
                # Tear down in the reverse of setup order
                for dep in reversed(insts):
                    dep.entry_teardown()

        for inst in all_deps:
            inst.finalize_code()

    def set_needs_recompilation(self):
        """
        Request another assembly pass. Raises ValueError if `max_compiles`
        passes have already been run.
        """
        if not self.__needs_recompilation:
            if self.compiles >= self.max_compiles:
                raise ValueError("Too many recompiles scheduled!")
            self.__needs_recompilation = True
class PTXFormatter(object):
    """
    Formats PTXStmt items into beautiful code. Well, the beautiful part is
    postponed for now.
    """
    def __init__(self, indent=4):
        """`indent`: Number of columns added per level of indentation."""
        self.indent_amt = indent

    def _flatten(self, val):
        """Concatenate arbitrarily nested lists/tuples into one string."""
        if isinstance(val, (list, tuple)):
            return ''.join(map(self._flatten, val))
        return str(val)

    def format(self, code):
        """
        Render `code`, an iterable of (prefix, op, vars, semi, indent)
        5-tuples, as newline-joined source text.

        The prefix is left-justified within the current indent width, then
        followed by the op, a space, and the concatenated vars. Statements
        with `semi` set lose trailing whitespace and gain a ';'. A negative
        indent change applies before its own line is emitted, and a positive
        one after.
        """
        out = []
        indent = 0
        for (pfx, op, operands, semi, indent_change) in code:
            pfx = self._flatten(pfx)
            op = self._flatten(op)
            operands = ''.join([self._flatten(v) for v in operands])
            if indent_change < 0:
                indent = max(0, indent + self.indent_amt * indent_change)
            # TODO: make this a lot prettier
            line = pfx.ljust(indent) + op + ' ' + operands
            if semi:
                line = line.rstrip() + ';'
            out.append(line)
            if indent_change > 0:
                indent += self.indent_amt * indent_change
        return '\n'.join(out)

10
main.py
View File

@ -15,16 +15,16 @@ from ctypes import *
import numpy as np
#from cuburnlib.device_code import MWCRNGTest
#from cuburnlib.cuda import LaunchContext
from cuburnlib.device_code import MWCRNGTest
from cuburnlib.cuda import LaunchContext
from fr0stlib.pyflam3 import *
from fr0stlib.pyflam3._flam3 import *
from cuburnlib.render import *
def main(genome_path):
#ctx = LaunchContext([MWCRNGTest], block=(256,1,1), grid=(64,1), tests=True)
#ctx.compile(True)
#ctx.run_tests()
ctx = LaunchContext([MWCRNGTest], block=(256,1,1), grid=(64,1), tests=True)
ctx.compile(verbose=True)
ctx.run_tests()
with open(genome_path) as fp:
genomes = Genome.from_string(fp.read())