mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 03:30:05 -05:00
PTX DSL working, at least well enough to pass MWCRNGTest
This commit is contained in:
parent
5f8c2bbf08
commit
a3660ec6e4
@ -10,7 +10,7 @@ import pycuda.gl.autoinit
|
||||
|
||||
import numpy as np
|
||||
|
||||
from cuburnlib.ptx import PTXAssembler
|
||||
from cuburnlib.ptx import PTXModule
|
||||
|
||||
class LaunchContext(object):
|
||||
"""
|
||||
@ -44,8 +44,10 @@ class LaunchContext(object):
|
||||
def threads(self):
|
||||
return reduce(lambda a, b: a*b, self.block + self.grid)
|
||||
|
||||
def compile(self, verbose=False):
|
||||
self.ptx = PTXAssembler(self, self.entry_types, self.build_tests)
|
||||
def compile(self, to_inject={}, verbose=False):
|
||||
inj = dict(to_inject)
|
||||
inj['ctx'] = self
|
||||
self.ptx = PTXModule(self.entry_types, inj, self.build_tests)
|
||||
try:
|
||||
self.mod = cuda.module_from_buffer(self.ptx.source)
|
||||
except (cuda.CompileError, cuda.RuntimeError), e:
|
||||
@ -54,15 +56,16 @@ class LaunchContext(object):
|
||||
enumerate(self.ptx.source.split('\n'))])
|
||||
raise e
|
||||
if verbose:
|
||||
for name in self.ptx.entry_names.values():
|
||||
func = self.mod.get_function(name)
|
||||
print "Compiled %s: used %d regs, %d sm, %d local" % (func,
|
||||
func.num_regs, func.shared_size_bytes, func.local_size_bytes)
|
||||
for entry in self.ptx.entries:
|
||||
func = self.mod.get_function(entry.entry_name)
|
||||
print "Compiled %s: used %d regs, %d sm, %d local" % (
|
||||
entry.entry_name, func.num_regs,
|
||||
func.shared_size_bytes, func.local_size_bytes)
|
||||
|
||||
def set_up(self):
|
||||
for inst in self.ptx.deporder(self.ptx.instances.values(),
|
||||
self.ptx.instances, self):
|
||||
inst.set_up(self)
|
||||
self.ptx.instances):
|
||||
inst.device_init(self)
|
||||
|
||||
def run(self):
|
||||
if not self.setup_done: self.set_up()
|
||||
|
@ -8,7 +8,7 @@ import time
|
||||
import pycuda.driver as cuda
|
||||
import numpy as np
|
||||
|
||||
from cuburnlib.ptx import PTXFragment, PTXEntryPoint, PTXTest
|
||||
from cuburnlib.ptx import *
|
||||
|
||||
"""
|
||||
Here's the current draft of the full algorithm implementation.
|
||||
@ -113,10 +113,12 @@ class MWCRNG(PTXFragment):
|
||||
if not os.path.isfile('primes.bin'):
|
||||
raise EnvironmentError('primes.bin not found')
|
||||
|
||||
@ptx_func
|
||||
def module_setup(self):
|
||||
mem.global_.u32('mwc_rng_mults', ctx.threads)
|
||||
mem.global_.u32('mwc_rng_state', ctx.threads)
|
||||
mem.global_.u64('mwc_rng_state', ctx.threads)
|
||||
|
||||
@ptx_func
|
||||
def entry_setup(self):
|
||||
reg.u32('mwc_st mwc_mult mwc_car')
|
||||
with block('Load MWC multipliers and states'):
|
||||
@ -130,6 +132,7 @@ class MWCRNG(PTXFragment):
|
||||
op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr)
|
||||
op.ld.global_.v2.u32(vec(mwc_st, mwc_car), addr(mwc_addr))
|
||||
|
||||
@ptx_func
|
||||
def entry_teardown(self):
|
||||
with block('Save MWC states'):
|
||||
reg.u32('mwc_off mwc_addr')
|
||||
@ -138,15 +141,19 @@ class MWCRNG(PTXFragment):
|
||||
op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr)
|
||||
op.st.global_.v2.u32(addr(mwc_addr), vec(mwc_st, mwc_car))
|
||||
|
||||
@ptx_func
|
||||
def next_b32(self, dst_reg):
|
||||
with block('Load next random into ' + dst_reg.name):
|
||||
reg.u64('mwc_out')
|
||||
op.cvt.u64.u32(mwc_out, mwc_car)
|
||||
mad.wide.u32(mwc_out, mwc_st)
|
||||
mov.b64(vec(mwc_st, mwc_car), mwc_out)
|
||||
mov.u32(dst_reg, mwc_st)
|
||||
op.mad.wide.u32(mwc_out, mwc_st, mwc_mult, mwc_out)
|
||||
op.mov.b64(vec(mwc_st, mwc_car), mwc_out)
|
||||
op.mov.u32(dst_reg, mwc_st)
|
||||
|
||||
def set_up(self, ctx):
|
||||
def to_inject(self):
    """
    Names this fragment adds to the DSL namespace: `mwc_next_b32` is bound
    to this instance's `next_b32` method.
    """
    return {'mwc_next_b32': self.next_b32}
|
||||
|
||||
def device_init(self, ctx):
|
||||
if self.threads_ready >= ctx.threads:
|
||||
# Already set up enough random states, don't push again
|
||||
return
|
||||
@ -168,21 +175,25 @@ class MWCRNG(PTXFragment):
|
||||
states = np.array(ctx.rand.randint(1, 0xffffffff, size=2*ctx.threads),
|
||||
dtype=np.uint32)
|
||||
statedp, statel = ctx.mod.get_global('mwc_rng_state')
|
||||
print states, len(states.tostring())
|
||||
cuda.memcpy_htod_async(statedp, states.tostring())
|
||||
self.threads_ready = ctx.threads
|
||||
|
||||
def tests(self, ctx):
|
||||
def tests(self):
|
||||
return [MWCRNGTest]
|
||||
|
||||
class MWCRNGTest(PTXTest):
|
||||
name = "MWC RNG sum-of-threads"
|
||||
deps = [MWCRNG]
|
||||
rounds = 10000
|
||||
entry_name = 'MWC_RNG_test'
|
||||
entry_params = ''
|
||||
|
||||
def deps(self):
|
||||
return [MWCRNG]
|
||||
|
||||
@ptx_func
|
||||
def module_setup(self):
|
||||
mem.global_.u64(mwc_rng_test_sums, ctx.threads)
|
||||
mem.global_.u64('mwc_rng_test_sums', ctx.threads)
|
||||
|
||||
@ptx_func
|
||||
def entry(self):
|
||||
@ -191,7 +202,7 @@ class MWCRNGTest(PTXTest):
|
||||
op.mov.u64(sum, 0)
|
||||
with block('Sum next %d random numbers' % self.rounds):
|
||||
reg.u32('loopct')
|
||||
pred('p')
|
||||
reg.pred('p')
|
||||
op.mov.u32(loopct, self.rounds)
|
||||
label('loopstart')
|
||||
mwc_next_b32(addend)
|
||||
@ -206,7 +217,7 @@ class MWCRNGTest(PTXTest):
|
||||
get_gtid(offset)
|
||||
op.mov.u32(adr, mwc_rng_test_sums)
|
||||
op.mad.lo.u32(adr, offset, 8, adr)
|
||||
st.global_.u64(addr(adr), sum)
|
||||
op.st.global_.u64(addr(adr), sum)
|
||||
|
||||
def call(self, ctx):
|
||||
# Get current multipliers and seeds from the device
|
||||
|
549
cuburnlib/ptx.py
549
cuburnlib/ptx.py
@ -10,7 +10,8 @@ easier to maintain using this system.
|
||||
|
||||
# If you see 'import inspect', you know you're in for a good time
|
||||
import inspect
|
||||
import ctypes
|
||||
import types
|
||||
import traceback
|
||||
from collections import namedtuple
|
||||
|
||||
# Okay, so here's what's going on.
|
||||
@ -23,7 +24,7 @@ from collections import namedtuple
|
||||
# splitting things up at the level of PTX will greatly reduce performance, as
|
||||
# the cost of accessing the stack, spilling registers, and reloading data from
|
||||
# system memory is unacceptably high even on Fermi GPUs. So we want to split
|
||||
# code up into functions within Python, but not within the PTX.
|
||||
# code up into functions within Python, but not within the PTX source.
|
||||
#
|
||||
# The challenge here is variable lifetime. A PTX function might declare a
|
||||
# register at the top of the main block and use it several times throughout the
|
||||
@ -50,10 +51,10 @@ from collections import namedtuple
|
||||
# reg.u32('hooray_reg')
|
||||
# load_zero(hooray_reg)
|
||||
#
|
||||
# But using blocks to track state, it would turn in to this ugliness::
|
||||
# But using blocks alone to track names, it would turn into this ugliness::
|
||||
#
|
||||
# def load_zero(block, dest_reg):
|
||||
# block.op.mov.u32(op.dest_reg, 0)
|
||||
# block.op.mov.u32(block.op.dest_reg, 0)
|
||||
# def init_module():
|
||||
# with Block() as block:
|
||||
# block.regs.hooray_reg = block.reg.u32('hooray_reg')
|
||||
@ -70,6 +71,9 @@ from collections import namedtuple
|
||||
# below give a clear picture of how to use it, but now you know why this
|
||||
# abomination was crafted to begin with.
|
||||
|
||||
def _softjoin(args, sep):
    """
    Intersperses 'sep' between 'args' without coercing to string.

    Returns a new list ``[args[0], sep, args[1], sep, ..., args[-1]]``; an
    empty 'args' yields an empty list. Items are passed through untouched so
    that values which are only stringified later keep working.
    """
    # Built with a single linear pass; the previous reduce()-based version
    # copied the accumulated list on every step and relied on the Python 2
    # builtin `reduce`.
    joined = []
    for arg in args:
        if joined:
            joined.append(sep)
        joined.append(arg)
    return joined
|
||||
|
||||
BlockCtx = namedtuple('BlockCtx', 'locals code injectors')
|
||||
PTXStmt = namedtuple('PTXStmt', 'prefix op vars semi indent')
|
||||
@ -100,6 +104,13 @@ class _BlockInjector(object):
|
||||
else:
|
||||
self.inject_into[k] = v
|
||||
self.injected.add(k)
|
||||
def pop(self, keys):
    """
    Un-inject the given names from the target dictionary.

    Only names recorded in `self.injected` (i.e. ones this injector added)
    are removed; any other key is left alone. Must not be called on a dead
    injector.
    """
    assert not self.dead
    for key in keys:
        if key not in self.injected:
            # Not ours; leave whatever is in the target dict untouched.
            continue
        self.inject_into.pop(key)
        self.injected.remove(key)
|
||||
def __enter__(self):
|
||||
self.dead = False
|
||||
map(self.inject, self.to_inject.items())
|
||||
@ -115,24 +126,49 @@ class _Block(object):
|
||||
|
||||
For important reasons, the instance must be bound locally as "_block".
|
||||
"""
|
||||
name = '_block'
|
||||
name = '_block' # For retrieving from parent scope on first call
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
def reset(self):
|
||||
self.outer_ctx = BlockCtx({self.name: self}, [], [])
|
||||
self.stack = [self.outer_ctx]
|
||||
def clean_injectors(self):
    """
    Drop every dead injector from the innermost context's injector list.

    The list is filtered in place (slice assignment) because the context
    holding it is shared; the previous implementation called `remove()`
    on the list while iterating over it, which skips the element after
    each removal and so left consecutive dead injectors behind.
    """
    inj = self.stack[-1].injectors
    inj[:] = [i for i in inj if not i.dead]
|
||||
def push_ctx(self):
|
||||
self.stack.append(BlockCtx(dict(self.stack[-1].locals), [], []))
|
||||
# Move most recent active injector to new context
|
||||
self.clean_injectors()
|
||||
last_inj = self.stack[-1].injectors.pop()
|
||||
self.stack.append(BlockCtx(dict(self.stack[-1].locals), [],
|
||||
[last_inj]))
|
||||
def pop_ctx(self):
|
||||
self.clean_injectors()
|
||||
bs = self.stack.pop()
|
||||
self.stack[-1].code.append(bs.code)
|
||||
self.stack[-1].code.extend(bs.code)
|
||||
if len(self.stack) == 1:
|
||||
# We're on outer_ctx, so all injectors should be gone
|
||||
assert len(bs.injectors) == 0, "Injector/context mismatch"
|
||||
return
|
||||
# The only injector should be the one added in push_ctx
|
||||
assert len(bs.injectors) == 1, "Injector/context mismatch"
|
||||
# Find out which keys were injected while in this context
|
||||
diff = set(bs.locals.keys()).difference(
|
||||
set(self.stack[-1].locals.keys()))
|
||||
# Pop keys and move current injector back down to last context
|
||||
last_inj = bs.injectors.pop()
|
||||
last_inj.pop(diff)
|
||||
self.stack[-1].injectors.append(last_inj)
|
||||
def injector(self, func_globals):
|
||||
inj = BlockInjector(self.stack[-1].locals, func_globals)
|
||||
inj = _BlockInjector(dict(self.stack[-1].locals), func_globals)
|
||||
self.stack[-1].injectors.append(inj)
|
||||
return inj
|
||||
def inject(self, name, object):
|
||||
if name in self.stack[-1].locals:
|
||||
raise KeyError("Duplicate name already exists in this scope.")
|
||||
self.stack[-1].locals[name] = object
|
||||
[inj.inject(name, object) for inj in self.stack[-1].injectors]
|
||||
if self.stack[-1].locals[name] is not object:
|
||||
raise KeyError("'%s' already exists in this scope." % name)
|
||||
else:
|
||||
self.stack[-1].locals[name] = object
|
||||
[inj.inject(name, object) for inj in self.stack[-1].injectors]
|
||||
def code(self, prefix='', op='', vars=[], semi=True, indent=0):
|
||||
"""
|
||||
Append a PTX statement (or thereabouts) to the current block.
|
||||
@ -157,7 +193,7 @@ class _Block(object):
|
||||
yes, the only real difference between `prefix`, `op`, and `vars` is in
|
||||
final appearance, but it is in fact quite helpful for debugging.
|
||||
"""
|
||||
self.stack[-1].append(PTXStmt(prefix, op, vars, indent))
|
||||
self.stack[-1].code.append(PTXStmt(prefix, op, vars, semi, indent))
|
||||
|
||||
class StrVar(object):
|
||||
"""
|
||||
@ -168,28 +204,50 @@ class StrVar(object):
|
||||
def __str__(self):
|
||||
return str(val)
|
||||
|
||||
class _PTXFuncWrapper(object):
    """
    Enables ptx_func.

    Wraps a DSL function so that, on every call, it runs against a private
    copy of its globals into which the active `_Block` injects the current
    DSL namespace.
    """
    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        """
        Locate the active `_Block`, then invoke the wrapped function inside
        one of its injector contexts. The wrapped function's return value is
        discarded.

        Raises SyntaxError if no `_Block` can be found.
        """
        if _Block.name in globals():
            # Look up the same key that was just tested; the old code read
            # globals()['block'] here, which is a different name.
            block = globals()[_Block.name]
        else:
            # Find the '_block' from the enclosing scope. Frame 0 is this
            # method and frame 1 is the closure returned by ptx_func, so
            # frame 2 is whoever called the decorated function.
            parent = inspect.stack()[2][0]
            if _Block.name in parent.f_locals:
                block = parent.f_locals[_Block.name]
            elif _Block.name in parent.f_globals:
                block = parent.f_globals[_Block.name]
            else:
                # Couldn't find the _block instance. Fail cryptically to
                # encourage users to read the source (for now)
                raise SyntaxError("Black magic")

        # Create a new function with the modified scope and call it. We could
        # do this in __init__, but it would hide any changes to globals from
        # the module's original scope. Still an option if performance sucks.
        newglobals = dict(self.func.func_globals)
        func = types.FunctionType(self.func.func_code, newglobals,
                                  self.func.func_name, self.func.func_defaults,
                                  self.func.func_closure)

        # TODO: if we generate a new dict every time, we can kill the
        # _BlockInjector and move BI.inject() back to _Block, but I don't want
        # to delete working code just yet
        with block.injector(func.func_globals):
            func(*args, **kwargs)
|
||||
|
||||
def ptx_func(func):
|
||||
"""
|
||||
Decorator function for code in the DSL. Any function which accesses the DSL
|
||||
namespace, including declared device variables and objects such as "reg"
|
||||
or "op", should be wrapped with this. See Block for some examples.
|
||||
|
||||
Note that writes to global variables will silently fail for now.
|
||||
"""
|
||||
def ptx_eval(*args, **kwargs):
|
||||
if self.name not in globals():
|
||||
parent = inspect.stack()[-2][0]
|
||||
if self.name in parent.f_locals:
|
||||
block = parent.f_locals[self.name]
|
||||
elif self.name in parent.f_globals:
|
||||
block = parent.f_globals[self.name]
|
||||
else:
|
||||
# Couldn't find the _block instance. Fail cryptically to
|
||||
# encourage users to read the source (for now)
|
||||
raise SyntaxError("Black magic")
|
||||
else:
|
||||
block = globals()['block']
|
||||
with block.injector(func.func_globals):
|
||||
func(*args, **kwargs)
|
||||
return ptx_eval
|
||||
# Attach most of the code to the wrapper class
|
||||
fw = _PTXFuncWrapper(func)
|
||||
def wr(*args, **kwargs):
|
||||
fw(*args, **kwargs)
|
||||
return wr
|
||||
|
||||
class Block(object):
|
||||
"""
|
||||
@ -231,14 +289,17 @@ class Block(object):
|
||||
# `block` is the real _block
|
||||
self.block = block
|
||||
self.comment = None
|
||||
def __call__(self, comment=None)
|
||||
def __call__(self, comment=None):
|
||||
self.comment = comment
|
||||
return self
|
||||
def __enter__(self):
|
||||
self.block.push_ctx()
|
||||
self.block.code(op='{', indent=4)
|
||||
self.block.code(op='{', indent=1, semi=False)
|
||||
if self.comment:
|
||||
self.block.code(op=['// ', self.comment], semi=False)
|
||||
self.comment = None
|
||||
def __exit__(self, exc_type, exc_value, tb):
|
||||
self.block.code(op='}', indent=-4)
|
||||
self.block.code(op='}', indent=-1, semi=False)
|
||||
self.block.pop_ctx()
|
||||
|
||||
class _CallChain(object):
|
||||
@ -248,12 +309,12 @@ class _CallChain(object):
|
||||
self.__chain = []
|
||||
def __call__(self, *args, **kwargs):
|
||||
assert(self.__chain)
|
||||
self._call(chain, *args, **kwargs)
|
||||
self._call(self.__chain, *args, **kwargs)
|
||||
self.__chain = []
|
||||
def __getattr__(self, name):
|
||||
if name == 'global_':
|
||||
name = 'global'
|
||||
self.chain.append(name)
|
||||
self.__chain.append(name)
|
||||
# Another great crime against the universe:
|
||||
return self
|
||||
|
||||
@ -284,7 +345,7 @@ class _RegFactory(_CallChain):
|
||||
type = type[0]
|
||||
names = names.split()
|
||||
regs = map(lambda n: Reg(type, n), names)
|
||||
self.block.code(op='.reg .' + type, vars=names)
|
||||
self.block.code(op='.reg .' + type, vars=_softjoin(names, ', '))
|
||||
[self.block.inject(r.name, r) for r in regs]
|
||||
|
||||
# Pending resolution of the op(regs, guard=x) debate
|
||||
@ -318,8 +379,6 @@ class _RegFactory(_CallChain):
|
||||
#self.name = name
|
||||
#def is_set(self, isnot=False):
|
||||
|
||||
|
||||
|
||||
class Op(_CallChain):
|
||||
"""
|
||||
Performs an operation.
|
||||
@ -340,15 +399,15 @@ class Op(_CallChain):
|
||||
|
||||
This constructor is available as 'op' in DSL blocks.
|
||||
"""
|
||||
def _call(self, op, *args, ifp=None, ifnotp=None):
|
||||
def _call(self, op, *args, **kwargs):
|
||||
pred = ''
|
||||
if ifp:
|
||||
if ifnotp:
|
||||
if 'ifp' in kwargs:
|
||||
if 'ifnotp' in kwargs:
|
||||
raise SyntaxError("can't use both, fool")
|
||||
pred = ['@', ifp]
|
||||
if ifnotp:
|
||||
pred = ['@!', ifnotp]
|
||||
self.block.append_code(pred, '.'.join(op), map(str, args))
|
||||
pred = ['@', kwargs['ifp']]
|
||||
if 'ifnotp' in kwargs:
|
||||
pred = ['@!', kwargs['ifnotp']]
|
||||
self.block.code(pred, '.'.join(op), _softjoin(args, ', '))
|
||||
|
||||
class Mem(object):
|
||||
"""
|
||||
@ -381,7 +440,7 @@ class Mem(object):
|
||||
|
||||
>>> op.ld.global.v2.u32(vec(reg1, reg2), addr(areg))
|
||||
"""
|
||||
return ['{', [(a, ', ') for a in args][:-1], '}']
|
||||
return ['{', _softjoin(args, ', '), '}']
|
||||
|
||||
@staticmethod
|
||||
def addr(areg, aoffset=''):
|
||||
@ -397,8 +456,7 @@ class _MemFactory(_CallChain):
|
||||
"""Actual `mem` object"""
|
||||
def _call(self, type, name, array=False, initializer=None):
|
||||
assert len(type) == 2
|
||||
memobj = Mem(type, name, array)
|
||||
self.dsl.inject(name, memobj)
|
||||
memobj = Mem(type, name, array, initializer)
|
||||
if array is True:
|
||||
array = ['[]']
|
||||
elif array:
|
||||
@ -407,11 +465,12 @@ class _MemFactory(_CallChain):
|
||||
array = []
|
||||
if initializer:
|
||||
array += [' = ', initializer]
|
||||
self.block.code(op=['.%s.%s ' % type, name, array])
|
||||
self.block.code(op=['.%s.%s ' % (type[0], type[1]), name, array])
|
||||
self.block.inject(name, memobj)
|
||||
|
||||
class Label(object):
|
||||
"""
|
||||
Specifies the target for a branch. Scoped in PTX? TODO: test.
|
||||
Specifies the target for a branch. Scoped in PTX? TODO: test that it is.
|
||||
|
||||
>>> label('infinite_loop')
|
||||
>>> op.bra.uni('label')
|
||||
@ -426,25 +485,7 @@ class _LabelFactory(object):
|
||||
self.block = block
|
||||
def __call__(self, name):
|
||||
self.block.inject(name, Label(name))
|
||||
|
||||
class PTXFragment(object):
|
||||
def module_setup(self):
|
||||
pass
|
||||
|
||||
def entry_setup(self):
|
||||
pass
|
||||
|
||||
def entry_teardown(self):
|
||||
pass
|
||||
|
||||
def globals(self):
|
||||
pass
|
||||
|
||||
def tests(self):
|
||||
pass
|
||||
|
||||
def device_init(self, ctx):
|
||||
pass
|
||||
self.block.code(prefix='%s:' % name, semi=False)
|
||||
|
||||
class PTXFragment(object):
|
||||
"""
|
||||
@ -470,12 +511,14 @@ class PTXFragment(object):
|
||||
for successful compilation. Circular dependencies are forbidden,
|
||||
but multi-level dependencies should be fine.
|
||||
"""
|
||||
return [DeviceHelpers]
|
||||
return [_PTXStdLib]
|
||||
|
||||
def inject(self):
|
||||
def to_inject(self):
|
||||
"""
|
||||
Returns a dict of items to add to the DSL namespace. The namespace will
|
||||
be assembled in dependency order before any ptx_funcs are called.
|
||||
|
||||
This is only called once per PTXModule (== once per instance).
|
||||
"""
|
||||
return {}
|
||||
|
||||
@ -483,8 +526,8 @@ class PTXFragment(object):
|
||||
"""
|
||||
PTX function to declare things at module scope. It's a PTX syntax error
|
||||
to perform operations at this scope, but we don't yet validate that at
|
||||
the Python level. A module will call this function on all fragments in
|
||||
dependency order.
|
||||
the Python level. A module will call this function on all fragments
|
||||
used in that module in dependency order.
|
||||
|
||||
If implemented, this function should use an @ptx_func decorator.
|
||||
"""
|
||||
@ -513,131 +556,47 @@ class PTXFragment(object):
|
||||
"""
|
||||
pass
|
||||
|
||||
def tests(self, ctx):
|
||||
def finalize_code(self):
|
||||
"""
|
||||
Returns a list of PTXTest classes which will test this fragment.
|
||||
Called after running all PTX DSL functions, but before code generation,
|
||||
to allow fragments which postponed variable evaluation (e.g. using
|
||||
`StrVar`) to fill in the resulting values. Most fragments should not
|
||||
use this.
|
||||
|
||||
If implemented, this function *may* use an @ptx_func decorator to
|
||||
access the global DSL scope, but pretty please don't emit any code
|
||||
while you're in there.
|
||||
"""
|
||||
pass
|
||||
|
||||
def tests(self):
|
||||
"""
|
||||
Returns a list of PTXTest types which will test this fragment.
|
||||
"""
|
||||
return []
|
||||
|
||||
def set_up(self, ctx):
|
||||
def device_init(self, ctx):
|
||||
"""
|
||||
Do start-of-stream initialization, such as copying data to the device.
|
||||
Do stuff on the host to prepare the device for execution. 'ctx' is a
|
||||
LaunchContext or similar. This will get called (in dependency order, of
|
||||
course) *either* before any entry point invocation, or before *each*
|
||||
invocation, I'm not sure which yet. (For now it's "each".)
|
||||
"""
|
||||
pass
|
||||
class PTXModule(object):
|
||||
"""
|
||||
Assembles PTX fragments into a module.
|
||||
"""
|
||||
|
||||
def __init__(self, entries, inject={}, build_tests=False):
|
||||
self._block = b = _Block()
|
||||
self.initial_inject = dict(inject)
|
||||
self._safeupdate(self.initial_inject, dict(block=Block(b),
|
||||
mem=_MemFactory(b), reg=_RegFactory(b), op=Op(b),
|
||||
label=_LabelFactory(b), _block=b)
|
||||
self.needs_recompilation = True
|
||||
self.max_compiles = 10
|
||||
while self.needs_recompilation:
|
||||
self.assemble(entries, build_tests)
|
||||
self.max_compiles -= 1
|
||||
|
||||
def deporder(self, unsorted_instances, instance_map, ctx):
|
||||
"""
|
||||
Do a DFS on PTXFragment dependencies, and return an ordered list of
|
||||
instances where no fragment depends on any before it in the list.
|
||||
|
||||
`unsorted_instances` is the list of instances to sort.
|
||||
`instance_map` is a dict of types to instances.
|
||||
"""
|
||||
seen = {}
|
||||
def rec(inst):
|
||||
if inst in seen: return seen[inst]
|
||||
deps = filter(lambda d: d is not inst, map(instance_map.get,
|
||||
callable(inst.deps) and inst.deps(self) or inst.deps))
|
||||
return seen.setdefault(inst, 1+max([0]+map(rec, deps)))
|
||||
map(rec, unsorted_instances)
|
||||
return sorted(unsorted_instances, key=seen.get)
|
||||
|
||||
def _safeupdate(self, dst, src):
|
||||
"""dst.update(src), but no duplicates allowed"""
|
||||
non_uniq = [k for k in src if k in dst]
|
||||
if non_uniq: raise KeyError("Duplicate keys: %s" % ','.join(key))
|
||||
dst.update(src)
|
||||
|
||||
def assemble(self, entries, build_tests):
|
||||
"""
|
||||
Build the PTX source for the given set of entries.
|
||||
"""
|
||||
# Get a property, dealing with the callable-or-data thing. This is
|
||||
# cumbersome, but flexible; when finished, it may be simplified.
|
||||
def pget(prop):
|
||||
if callable(prop): return prop(ctx)
|
||||
return prop
|
||||
|
||||
instances = {}
|
||||
unvisited_entries = list(entries)
|
||||
entry_names = {}
|
||||
tests = []
|
||||
parsed_entries = []
|
||||
while unvisited_entries:
|
||||
ent = unvisited_entries.pop(0)
|
||||
seen, unvisited = set(), [ent]
|
||||
while unvisited:
|
||||
frag = unvisited.pop(0)
|
||||
seen.add(frag)
|
||||
inst = instances.setdefault(frag, frag())
|
||||
for dep in pget(inst.deps):
|
||||
if dep not in seen:
|
||||
unvisited.append(dep)
|
||||
if build_tests:
|
||||
for test in pget(inst.tests):
|
||||
if test not in tests:
|
||||
if test not in instances:
|
||||
unvisited_entries.append(test)
|
||||
tests.append(test)
|
||||
|
||||
tmpl_namespace = {'ctx': ctx}
|
||||
entry_start, entry_end = [], []
|
||||
for inst in self.deporder(map(instances.get, seen), instances, ctx):
|
||||
self._safeupdate(tmpl_namespace, pget(inst.subs))
|
||||
entry_start.append(pget(inst.entry_start))
|
||||
entry_end.append(pget(inst.entry_end))
|
||||
entry_start_tmpl = '\n'.join(filter(None, entry_start))
|
||||
entry_end_tmpl = '\n'.join(filter(None, reversed(entry_end)))
|
||||
name, args, body = pget(instances[ent].entry)
|
||||
tmpl_namespace.update({'_entry_name_': name, '_entry_args_': args,
|
||||
'_entry_body_': body, '_entry_start_': entry_start_tmpl,
|
||||
'_entry_end_': entry_end_tmpl})
|
||||
|
||||
entry_tmpl = (".entry {{ _entry_name_ }} ({{ _entry_args_ }})\n"
|
||||
"{\n{{_entry_start_}}\n{{_entry_body_}}\n{{_entry_end_}}\n}\n")
|
||||
parsed_entries.append(multisub(entry_tmpl, tmpl_namespace))
|
||||
entry_names[ent] = name
|
||||
|
||||
prelude = []
|
||||
tmpl_namespace = {'ctx': ctx}
|
||||
for inst in self.deporder(instances.values(), instances, ctx):
|
||||
prelude.append(pget(inst.prelude))
|
||||
self._safeupdate(tmpl_namespace, pget(inst.subs))
|
||||
tmpl_namespace['_prelude_'] = '\n'.join(filter(None, prelude))
|
||||
tmpl_namespace['_entries_'] = '\n\n'.join(parsed_entries)
|
||||
tmpl = "{{ _prelude_ }}\n{{ _entries_ }}"
|
||||
|
||||
self.entry_names = entry_names
|
||||
self.source = ppr_ptx(multisub(tmpl, tmpl_namespace))
|
||||
self.instances = instances
|
||||
self.tests = tests
|
||||
|
||||
|
||||
|
||||
class PTXEntryPoint(PTXFragment):
|
||||
# Human-readable entry point name
|
||||
name = ""
|
||||
# Device code entry name
|
||||
entry_name = ""
|
||||
# List of (type, name) pairs for entry params, e.g. [('u32', 'thing')]
|
||||
entry_params = []
|
||||
|
||||
def entry(self, ctx):
|
||||
"""
|
||||
Returns a 3-tuple of (name, args, body), which will be assembled into
|
||||
a function.
|
||||
PTX DSL function that comprises the body of the PTX statement.
|
||||
|
||||
Must be implemented and decorated with ptx_func.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -660,32 +619,216 @@ class PTXTest(PTXEntryPoint):
|
||||
"""
|
||||
pass
|
||||
|
||||
class DeviceHelpers(PTXFragment):
|
||||
def __init__(self):
|
||||
self._forstack = []
|
||||
class _PTXStdLib(PTXFragment):
|
||||
def __init__(self, block):
|
||||
# Only module that gets the privilege of seeing 'block' directly.
|
||||
self.block = block
|
||||
|
||||
prelude = ".version 2.1\n.target sm_20\n\n"
|
||||
def deps(self):
|
||||
return []
|
||||
|
||||
@ptx_func
|
||||
def module_setup(self):
|
||||
# TODO: make this modular, maybe? of course, we'd have to support
|
||||
# multiple devices first, which we definitely do not yet do
|
||||
self.block.code(prefix='.version 2.1', semi=False)
|
||||
self.block.code(prefix='.target sm_20', semi=False)
|
||||
|
||||
@ptx_func
|
||||
def _get_gtid(self, dst):
|
||||
return "{\n// Load GTID into " + dst + """
|
||||
.reg .u16 tmp;
|
||||
.reg .u32 cta, ncta, tid, gtid;
|
||||
with block("Load GTID into %s" % str(dst)):
|
||||
reg.u16('tmp')
|
||||
reg.u32('cta ncta tid gtid')
|
||||
|
||||
mov.u16 tmp, %ctaid.x;
|
||||
cvt.u32.u16 cta, tmp;
|
||||
mov.u16 tmp, %ntid.x;
|
||||
cvt.u32.u16 ncta, tmp;
|
||||
mul.lo.u32 gtid, cta, ncta;
|
||||
op.mov.u16(tmp, '%ctaid.x')
|
||||
op.cvt.u32.u16(cta, tmp)
|
||||
op.mov.u16(tmp, '%ntid.x')
|
||||
op.cvt.u32.u16(ncta, tmp)
|
||||
op.mul.lo.u32(gtid, cta, ncta)
|
||||
|
||||
mov.u16 tmp, %tid.x;
|
||||
cvt.u32.u16 tid, tmp;
|
||||
add.u32 gtid, gtid, tid;
|
||||
mov.b32 """ + dst + ", gtid;\n}"
|
||||
op.mov.u16(tmp, '%tid.x')
|
||||
op.cvt.u32.u16(tid, tmp)
|
||||
op.add.u32(gtid, gtid, tid)
|
||||
op.mov.b32(dst, gtid)
|
||||
|
||||
def subs(self, ctx):
|
||||
return {
|
||||
'PTRT': ctypes.sizeof(ctypes.c_void_p) == 8 and '.u64' or '.u32',
|
||||
'get_gtid': self._get_gtid
|
||||
}
|
||||
def to_inject(self):
|
||||
return dict(
|
||||
_block=self.block,
|
||||
block=Block(self.block),
|
||||
op=Op(self.block),
|
||||
reg=_RegFactory(self.block),
|
||||
mem=_MemFactory(self.block),
|
||||
addr=Mem.addr,
|
||||
vec=Mem.vec,
|
||||
label=_LabelFactory(self.block),
|
||||
get_gtid=self._get_gtid)
|
||||
|
||||
class PTXModule(object):
    """
    Assembles PTX fragments into a module. The following properties are
    available:

    `instances`:    Mapping of type to instance for the PTXFragments used in
                    the creation of this PTXModule (the internal standard
                    library fragment is removed once assembly is done).
    `entries`:      List of PTXEntry types in this module, including any tests.
    `tests`:        List of PTXTest types in this module.
    `source`:       PTX source code for this module.
    """
    # Limit checked by set_needs_recompilation() before it schedules
    # another assemble() pass.
    max_compiles = 10

    def __init__(self, entries, inject={}, build_tests=False, formatter=None):
        """
        Construct a PTXModule.

        `entries`:      List of PTXEntry types to include in this module.
        `inject`:       Dict of items to inject into the DSL namespace.
        `build_tests`:  If true, build tests into the module.
        `formatter`:    PTXFormatter instance, or None to use defaults.

        Raises KeyError if two sources try to inject the same name.
        """
        block = _Block()
        insts, tests, all_deps, entry_deps = (
                self.deptrace(block, entries, build_tests))
        self.instances = insts
        self.tests = tests

        # Work on a copy so neither the caller's dict nor the shared default
        # argument is ever mutated.
        inject = dict(inject)
        self._safeupdate(inject, {'module': self})
        for inst in all_deps:
            self._safeupdate(inject, inst.to_inject())
        for name, obj in inject.items():
            block.inject(name, obj)

        self.__needs_recompilation = True
        self.compiles = 0
        while self.__needs_recompilation:
            self.compiles += 1
            self.__needs_recompilation = False
            self.assemble(block, all_deps, entry_deps)
        self.instances.pop(_PTXStdLib)

        if not formatter:
            formatter = PTXFormatter()
        self.source = formatter.format(block.outer_ctx.code)
        self.entries = list(set(entries) | set(tests))

    def deporder(self, unsorted_instances, instance_map):
        """
        Do a DFS on PTXFragment dependencies, and return an ordered list of
        instances in which every fragment appears after the fragments it
        depends on.

        `unsorted_instances` is the list of instances to sort.
        `instance_map` is a dict of types to instances.

        Dependencies missing from `instance_map` are ignored. Circular
        dependencies are not detected and will recurse without bound.
        """
        # Materialise first: the input is walked twice (once to compute
        # depths, once to sort) and may be a one-shot iterable.
        unsorted_instances = list(unsorted_instances)
        depth = {}

        def rec(inst):
            if inst in depth:
                return depth[inst]
            if inst is None:
                return 0
            deps = [instance_map.get(d) for d in inst.deps()]
            deps = [d for d in deps if d is not inst]
            return depth.setdefault(
                    inst, 1 + max([0] + [rec(d) for d in deps]))

        for inst in unsorted_instances:
            rec(inst)
        return sorted(unsorted_instances, key=depth.get)

    def _safeupdate(self, dst, src):
        """
        dst.update(src), but no duplicates allowed.

        Raises KeyError naming the clashing keys, and leaves `dst` unmodified
        in that case.
        """
        non_uniq = [k for k in src if k in dst]
        if non_uniq:
            # The old message joined an undefined name, so a clash surfaced
            # as a NameError instead of this KeyError.
            raise KeyError("Duplicate keys: %s" %
                           ','.join(str(k) for k in non_uniq))
        dst.update(src)

    def deptrace(self, block, entries, build_tests):
        """
        Instantiate every fragment reachable from `entries`.

        Returns a 4-tuple ``(instances, tests, all_deps, entry_deps)``:
        the type-to-instance map, the discovered test types sorted by their
        string form, all instances in dependency order, and a dict mapping
        each entry type to its own dependency-ordered instance list.
        """
        # _PTXStdLib is pre-seeded because, unlike other fragments, its
        # constructor takes the block.
        instances = {_PTXStdLib: _PTXStdLib(block)}
        unvisited_entries = list(entries)
        tests = set()
        entry_deps = {}

        # For each PTXEntry or PTXTest, use a BFS to recursively find and
        # instantiate all fragments that are dependencies. If tests are
        # discovered, add those to the list of entries.
        while unvisited_entries:
            ent = unvisited_entries.pop(0)
            seen, unvisited = set(), [ent]
            while unvisited:
                frag = unvisited.pop(0)
                seen.add(frag)
                # setdefault doesn't work because of _PTXStdLib
                if frag not in instances:
                    inst = frag()
                    instances[frag] = inst
                else:
                    inst = instances[frag]
                for dep in inst.deps():
                    if dep not in seen:
                        unvisited.append(dep)
                if build_tests:
                    for test in inst.tests():
                        if test not in tests:
                            tests.add(test)
                            if test not in instances:
                                # Queue the newly found test as an entry of
                                # its own (previously a misspelled name and
                                # the whole set were used here).
                                unvisited_entries.append(test)
            # For this entry, store insts of all dependencies in order.
            entry_deps[ent] = self.deporder(
                    [instances[frag] for frag in seen], instances)
        # Find the order for all dependencies in the program.
        all_deps = self.deporder(list(instances.values()), instances)

        return instances, sorted(tests, key=str), all_deps, entry_deps

    def assemble(self, block, all_deps, entry_deps):
        """
        Run the DSL functions of every fragment, emitting code into `block`:
        module-level setup first, then one `.entry` per entry point (with
        each dependency's entry setup before and teardown after the entry
        body), then the finalize hooks.
        """
        # Rebind to local namespace to allow proper retrieval. The ptx_func
        # wrapper looks for a local called '_block' in the frame that calls
        # it, so this name must not change and the DSL calls below must stay
        # directly in this function body.
        _block = block
        for inst in all_deps:
            inst.module_setup()

        for ent, insts in entry_deps.items():
            # This is kind of hackish compared to everything else
            params = [Reg('.param.' + str(type), name)
                      for (type, name) in ent.entry_params]
            # Separate parameters with commas; concatenating them bare only
            # produced valid output for zero or one parameter.
            param_list = ', '.join(['%s %s' % (r.type, r.name)
                                    for r in params])
            _block.code(op='.entry %s ' % ent.entry_name, semi=False,
                        vars=['(', param_list, ')'])
            with Block(_block):
                for r in params:
                    _block.inject(r.name, r)
                for dep in insts:
                    dep.entry_setup()
                self.instances[ent].entry()
                for dep in reversed(insts):
                    dep.entry_teardown()

        for inst in all_deps:
            inst.finalize_code()

    def set_needs_recompilation(self):
        """
        Request another assemble() pass once the current one finishes.

        Raises ValueError if `max_compiles` passes have already been run.
        Calling this again while a recompilation is already pending is a
        no-op.
        """
        if not self.__needs_recompilation:
            if self.compiles >= self.max_compiles:
                raise ValueError("Too many recompiles scheduled!")
            self.__needs_recompilation = True
|
||||
|
||||
class PTXFormatter(object):
    """
    Formats PTXStmt items into beautiful code. Well, the beautiful part is
    postponed for now.
    """
    def __init__(self, indent=4):
        """
        `indent`: number of columns added per indentation level. (This
        argument used to be ignored in favour of a hard-coded 4.)
        """
        self.indent_amt = indent

    def _flatten(self, val):
        """Recursively concatenate nested lists/tuples into one string."""
        if isinstance(val, (list, tuple)):
            return ''.join(map(self._flatten, val))
        return str(val)

    def format(self, code):
        """
        Render an iterable of ``(prefix, op, vars, semi, indent)`` statements
        as newline-joined source text.

        A negative `indent` takes effect on the statement carrying it, a
        positive one on the statements that follow; the running indentation
        never drops below zero. `prefix` is left-justified in the current
        indentation width, and statements with `semi` set have trailing
        whitespace stripped and a ';' appended.
        """
        out = []
        indent = 0
        for (pfx, op, operands, semi, indent_change) in code:
            pfx = self._flatten(pfx)
            op = self._flatten(op)
            operands = ''.join([self._flatten(v) for v in operands])
            if indent_change < 0:
                indent = max(0, indent + self.indent_amt * indent_change)
            # TODO: make this a lot prettier
            line = pfx.ljust(indent) + op + ' ' + operands
            if semi:
                line = line.rstrip() + ';'
            out.append(line)
            if indent_change > 0:
                indent += self.indent_amt * indent_change
        return '\n'.join(out)
|
||||
|
||||
|
||||
|
10
main.py
10
main.py
@ -15,16 +15,16 @@ from ctypes import *
|
||||
|
||||
import numpy as np
|
||||
|
||||
#from cuburnlib.device_code import MWCRNGTest
|
||||
#from cuburnlib.cuda import LaunchContext
|
||||
from cuburnlib.device_code import MWCRNGTest
|
||||
from cuburnlib.cuda import LaunchContext
|
||||
from fr0stlib.pyflam3 import *
|
||||
from fr0stlib.pyflam3._flam3 import *
|
||||
from cuburnlib.render import *
|
||||
|
||||
def main(genome_path):
|
||||
#ctx = LaunchContext([MWCRNGTest], block=(256,1,1), grid=(64,1), tests=True)
|
||||
#ctx.compile(True)
|
||||
#ctx.run_tests()
|
||||
ctx = LaunchContext([MWCRNGTest], block=(256,1,1), grid=(64,1), tests=True)
|
||||
ctx.compile(verbose=True)
|
||||
ctx.run_tests()
|
||||
|
||||
with open(genome_path) as fp:
|
||||
genomes = Genome.from_string(fp.read())
|
||||
|
Loading…
Reference in New Issue
Block a user