PTX DSL working, at least well enough to pass MWCRNGTest

This commit is contained in:
Steven Robertson 2010-09-01 21:09:40 -04:00
parent 5f8c2bbf08
commit a3660ec6e4
4 changed files with 385 additions and 228 deletions

View File

@ -10,7 +10,7 @@ import pycuda.gl.autoinit
import numpy as np import numpy as np
from cuburnlib.ptx import PTXAssembler from cuburnlib.ptx import PTXModule
class LaunchContext(object): class LaunchContext(object):
""" """
@ -44,8 +44,10 @@ class LaunchContext(object):
def threads(self): def threads(self):
return reduce(lambda a, b: a*b, self.block + self.grid) return reduce(lambda a, b: a*b, self.block + self.grid)
def compile(self, verbose=False): def compile(self, to_inject={}, verbose=False):
self.ptx = PTXAssembler(self, self.entry_types, self.build_tests) inj = dict(to_inject)
inj['ctx'] = self
self.ptx = PTXModule(self.entry_types, inj, self.build_tests)
try: try:
self.mod = cuda.module_from_buffer(self.ptx.source) self.mod = cuda.module_from_buffer(self.ptx.source)
except (cuda.CompileError, cuda.RuntimeError), e: except (cuda.CompileError, cuda.RuntimeError), e:
@ -54,15 +56,16 @@ class LaunchContext(object):
enumerate(self.ptx.source.split('\n'))]) enumerate(self.ptx.source.split('\n'))])
raise e raise e
if verbose: if verbose:
for name in self.ptx.entry_names.values(): for entry in self.ptx.entries:
func = self.mod.get_function(name) func = self.mod.get_function(entry.entry_name)
print "Compiled %s: used %d regs, %d sm, %d local" % (func, print "Compiled %s: used %d regs, %d sm, %d local" % (
func.num_regs, func.shared_size_bytes, func.local_size_bytes) entry.entry_name, func.num_regs,
func.shared_size_bytes, func.local_size_bytes)
def set_up(self): def set_up(self):
for inst in self.ptx.deporder(self.ptx.instances.values(), for inst in self.ptx.deporder(self.ptx.instances.values(),
self.ptx.instances, self): self.ptx.instances):
inst.set_up(self) inst.device_init(self)
def run(self): def run(self):
if not self.setup_done: self.set_up() if not self.setup_done: self.set_up()

View File

@ -8,7 +8,7 @@ import time
import pycuda.driver as cuda import pycuda.driver as cuda
import numpy as np import numpy as np
from cuburnlib.ptx import PTXFragment, PTXEntryPoint, PTXTest from cuburnlib.ptx import *
""" """
Here's the current draft of the full algorithm implementation. Here's the current draft of the full algorithm implementation.
@ -113,10 +113,12 @@ class MWCRNG(PTXFragment):
if not os.path.isfile('primes.bin'): if not os.path.isfile('primes.bin'):
raise EnvironmentError('primes.bin not found') raise EnvironmentError('primes.bin not found')
@ptx_func
def module_setup(self): def module_setup(self):
mem.global_.u32('mwc_rng_mults', ctx.threads) mem.global_.u32('mwc_rng_mults', ctx.threads)
mem.global_.u32('mwc_rng_state', ctx.threads) mem.global_.u64('mwc_rng_state', ctx.threads)
@ptx_func
def entry_setup(self): def entry_setup(self):
reg.u32('mwc_st mwc_mult mwc_car') reg.u32('mwc_st mwc_mult mwc_car')
with block('Load MWC multipliers and states'): with block('Load MWC multipliers and states'):
@ -130,6 +132,7 @@ class MWCRNG(PTXFragment):
op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr) op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr)
op.ld.global_.v2.u32(vec(mwc_st, mwc_car), addr(mwc_addr)) op.ld.global_.v2.u32(vec(mwc_st, mwc_car), addr(mwc_addr))
@ptx_func
def entry_teardown(self): def entry_teardown(self):
with block('Save MWC states'): with block('Save MWC states'):
reg.u32('mwc_off mwc_addr') reg.u32('mwc_off mwc_addr')
@ -138,15 +141,19 @@ class MWCRNG(PTXFragment):
op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr) op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr)
op.st.global_.v2.u32(addr(mwc_addr), vec(mwc_st, mwc_car)) op.st.global_.v2.u32(addr(mwc_addr), vec(mwc_st, mwc_car))
@ptx_func
def next_b32(self, dst_reg): def next_b32(self, dst_reg):
with block('Load next random into ' + dst_reg.name): with block('Load next random into ' + dst_reg.name):
reg.u64('mwc_out') reg.u64('mwc_out')
op.cvt.u64.u32(mwc_out, mwc_car) op.cvt.u64.u32(mwc_out, mwc_car)
mad.wide.u32(mwc_out, mwc_st) op.mad.wide.u32(mwc_out, mwc_st, mwc_mult, mwc_out)
mov.b64(vec(mwc_st, mwc_car), mwc_out) op.mov.b64(vec(mwc_st, mwc_car), mwc_out)
mov.u32(dst_reg, mwc_st) op.mov.u32(dst_reg, mwc_st)
def set_up(self, ctx): def to_inject(self):
return dict(mwc_next_b32=self.next_b32)
def device_init(self, ctx):
if self.threads_ready >= ctx.threads: if self.threads_ready >= ctx.threads:
# Already set up enough random states, don't push again # Already set up enough random states, don't push again
return return
@ -168,21 +175,25 @@ class MWCRNG(PTXFragment):
states = np.array(ctx.rand.randint(1, 0xffffffff, size=2*ctx.threads), states = np.array(ctx.rand.randint(1, 0xffffffff, size=2*ctx.threads),
dtype=np.uint32) dtype=np.uint32)
statedp, statel = ctx.mod.get_global('mwc_rng_state') statedp, statel = ctx.mod.get_global('mwc_rng_state')
print states, len(states.tostring())
cuda.memcpy_htod_async(statedp, states.tostring()) cuda.memcpy_htod_async(statedp, states.tostring())
self.threads_ready = ctx.threads self.threads_ready = ctx.threads
def tests(self, ctx): def tests(self):
return [MWCRNGTest] return [MWCRNGTest]
class MWCRNGTest(PTXTest): class MWCRNGTest(PTXTest):
name = "MWC RNG sum-of-threads" name = "MWC RNG sum-of-threads"
deps = [MWCRNG]
rounds = 10000 rounds = 10000
entry_name = 'MWC_RNG_test' entry_name = 'MWC_RNG_test'
entry_params = '' entry_params = ''
def deps(self):
return [MWCRNG]
@ptx_func
def module_setup(self): def module_setup(self):
mem.global_.u64(mwc_rng_test_sums, ctx.threads) mem.global_.u64('mwc_rng_test_sums', ctx.threads)
@ptx_func @ptx_func
def entry(self): def entry(self):
@ -191,7 +202,7 @@ class MWCRNGTest(PTXTest):
op.mov.u64(sum, 0) op.mov.u64(sum, 0)
with block('Sum next %d random numbers' % self.rounds): with block('Sum next %d random numbers' % self.rounds):
reg.u32('loopct') reg.u32('loopct')
pred('p') reg.pred('p')
op.mov.u32(loopct, self.rounds) op.mov.u32(loopct, self.rounds)
label('loopstart') label('loopstart')
mwc_next_b32(addend) mwc_next_b32(addend)
@ -206,7 +217,7 @@ class MWCRNGTest(PTXTest):
get_gtid(offset) get_gtid(offset)
op.mov.u32(adr, mwc_rng_test_sums) op.mov.u32(adr, mwc_rng_test_sums)
op.mad.lo.u32(adr, offset, 8, adr) op.mad.lo.u32(adr, offset, 8, adr)
st.global_.u64(addr(adr), sum) op.st.global_.u64(addr(adr), sum)
def call(self, ctx): def call(self, ctx):
# Get current multipliers and seeds from the device # Get current multipliers and seeds from the device

View File

@ -10,7 +10,8 @@ easier to maintain using this system.
# If you see 'import inspect', you know you're in for a good time # If you see 'import inspect', you know you're in for a good time
import inspect import inspect
import ctypes import types
import traceback
from collections import namedtuple from collections import namedtuple
# Okay, so here's what's going on. # Okay, so here's what's going on.
@ -23,7 +24,7 @@ from collections import namedtuple
# splitting things up at the level of PTX will greatly reduce performance, as # splitting things up at the level of PTX will greatly reduce performance, as
# the cost of accessing the stack, spilling registers, and reloading data from # the cost of accessing the stack, spilling registers, and reloading data from
# system memory is unacceptably high even on Fermi GPUs. So we want to split # system memory is unacceptably high even on Fermi GPUs. So we want to split
# code up into functions within Python, but not within the PTX. # code up into functions within Python, but not within the PTX source.
# #
# The challenge here is variable lifetime. A PTX function might declare a # The challenge here is variable lifetime. A PTX function might declare a
# register at the top of the main block and use it several times throughout the # register at the top of the main block and use it several times throughout the
@ -50,10 +51,10 @@ from collections import namedtuple
# reg.u32('hooray_reg') # reg.u32('hooray_reg')
# load_zero(hooray_reg) # load_zero(hooray_reg)
# #
# But using blocks to track state, it would turn in to this ugliness:: # But using blocks alone to track names, it would turn in to this ugliness::
# #
# def load_zero(block, dest_reg): # def load_zero(block, dest_reg):
# block.op.mov.u32(op.dest_reg, 0) # block.op.mov.u32(block.op.dest_reg, 0)
# def init_module(): # def init_module():
# with Block() as block: # with Block() as block:
# block.regs.hooray_reg = block.reg.u32('hooray_reg') # block.regs.hooray_reg = block.reg.u32('hooray_reg')
@ -70,6 +71,9 @@ from collections import namedtuple
# below give a clear picture of how to use it, but now you know why this # below give a clear picture of how to use it, but now you know why this
# abomination was crafted to begin with. # abomination was crafted to begin with.
def _softjoin(args, sep):
"""Intersperses 'sep' between 'args' without coercing to string."""
return reduce(lambda l, x: l + [x, sep], args, [])[:-1]
BlockCtx = namedtuple('BlockCtx', 'locals code injectors') BlockCtx = namedtuple('BlockCtx', 'locals code injectors')
PTXStmt = namedtuple('PTXStmt', 'prefix op vars semi indent') PTXStmt = namedtuple('PTXStmt', 'prefix op vars semi indent')
@ -100,6 +104,13 @@ class _BlockInjector(object):
else: else:
self.inject_into[k] = v self.inject_into[k] = v
self.injected.add(k) self.injected.add(k)
def pop(self, keys):
"""Remove keys from a dictionary, as long as we added them."""
assert not self.dead
for k in keys:
if k in self.injected:
self.inject_into.pop(k)
self.injected.remove(k)
def __enter__(self): def __enter__(self):
self.dead = False self.dead = False
map(self.inject, self.to_inject.items()) map(self.inject, self.to_inject.items())
@ -115,22 +126,47 @@ class _Block(object):
For important reasons, the instance must be bound locally as "_block". For important reasons, the instance must be bound locally as "_block".
""" """
name = '_block' name = '_block' # For retrieving from parent scope on first call
def __init__(self): def __init__(self):
self.reset()
def reset(self):
self.outer_ctx = BlockCtx({self.name: self}, [], []) self.outer_ctx = BlockCtx({self.name: self}, [], [])
self.stack = [self.outer_ctx] self.stack = [self.outer_ctx]
def clean_injectors(self):
inj = self.stack[-1].injectors
[inj.remove(i) for i in inj if i.dead]
def push_ctx(self): def push_ctx(self):
self.stack.append(BlockCtx(dict(self.stack[-1].locals), [], [])) # Move most recent active injector to new context
self.clean_injectors()
last_inj = self.stack[-1].injectors.pop()
self.stack.append(BlockCtx(dict(self.stack[-1].locals), [],
[last_inj]))
def pop_ctx(self): def pop_ctx(self):
self.clean_injectors()
bs = self.stack.pop() bs = self.stack.pop()
self.stack[-1].code.append(bs.code) self.stack[-1].code.extend(bs.code)
if len(self.stack) == 1:
# We're on outer_ctx, so all injectors should be gone
assert len(bs.injectors) == 0, "Injector/context mismatch"
return
# The only injector should be the one added in push_ctx
assert len(bs.injectors) == 1, "Injector/context mismatch"
# Find out which keys were injected while in this context
diff = set(bs.locals.keys()).difference(
set(self.stack[-1].locals.keys()))
# Pop keys and move current injector back down to last context
last_inj = bs.injectors.pop()
last_inj.pop(diff)
self.stack[-1].injectors.append(last_inj)
def injector(self, func_globals): def injector(self, func_globals):
inj = BlockInjector(self.stack[-1].locals, func_globals) inj = _BlockInjector(dict(self.stack[-1].locals), func_globals)
self.stack[-1].injectors.append(inj) self.stack[-1].injectors.append(inj)
return inj return inj
def inject(self, name, object): def inject(self, name, object):
if name in self.stack[-1].locals: if name in self.stack[-1].locals:
raise KeyError("Duplicate name already exists in this scope.") if self.stack[-1].locals[name] is not object:
raise KeyError("'%s' already exists in this scope." % name)
else:
self.stack[-1].locals[name] = object self.stack[-1].locals[name] = object
[inj.inject(name, object) for inj in self.stack[-1].injectors] [inj.inject(name, object) for inj in self.stack[-1].injectors]
def code(self, prefix='', op='', vars=[], semi=True, indent=0): def code(self, prefix='', op='', vars=[], semi=True, indent=0):
@ -157,7 +193,7 @@ class _Block(object):
yes, the only real difference between `prefix`, `op`, and `vars` is in yes, the only real difference between `prefix`, `op`, and `vars` is in
final appearance, but it is in fact quite helpful for debugging. final appearance, but it is in fact quite helpful for debugging.
""" """
self.stack[-1].append(PTXStmt(prefix, op, vars, indent)) self.stack[-1].code.append(PTXStmt(prefix, op, vars, semi, indent))
class StrVar(object): class StrVar(object):
""" """
@ -168,28 +204,50 @@ class StrVar(object):
def __str__(self): def __str__(self):
return str(val) return str(val)
class _PTXFuncWrapper(object):
"""Enables ptx_func"""
def __init__(self, func):
self.func = func
def __call__(self, *args, **kwargs):
if _Block.name in globals():
block = globals()['block']
else:
# Find the '_block' from the enclosing scope
parent = inspect.stack()[2][0]
if _Block.name in parent.f_locals:
block = parent.f_locals[_Block.name]
elif _Block.name in parent.f_globals:
block = parent.f_globals[_Block.name]
else:
# Couldn't find the _block instance. Fail cryptically to
# encourage users to read the source (for now)
raise SyntaxError("Black magic")
# Create a new function with the modified scope and call it. We could
# do this in __init__, but it would hide any changes to globals from
# the module's original scope. Still an option if performance sucks.
newglobals = dict(self.func.func_globals)
func = types.FunctionType(self.func.func_code, newglobals,
self.func.func_name, self.func.func_defaults,
self.func.func_closure)
# TODO: if we generate a new dict every time, we can kill the
# _BlockInjector and move BI.inject() back to _Block, but I don't want
# to delete working code just yet
with block.injector(func.func_globals):
func(*args, **kwargs)
def ptx_func(func): def ptx_func(func):
""" """
Decorator function for code in the DSL. Any function which accesses the DSL Decorator function for code in the DSL. Any function which accesses the DSL
namespace, including declared device variables and objects such as "reg" namespace, including declared device variables and objects such as "reg"
or "op", should be wrapped with this. See Block for some examples. or "op", should be wrapped with this. See Block for some examples.
Note that writes to global variables will silently fail for now.
""" """
def ptx_eval(*args, **kwargs): # Attach most of the code to the wrapper class
if self.name not in globals(): fw = _PTXFuncWrapper(func)
parent = inspect.stack()[-2][0] def wr(*args, **kwargs):
if self.name in parent.f_locals: fw(*args, **kwargs)
block = parent.f_locals[self.name] return wr
elif self.name in parent.f_globals:
block = parent.f_globals[self.name]
else:
# Couldn't find the _block instance. Fail cryptically to
# encourage users to read the source (for now)
raise SyntaxError("Black magic")
else:
block = globals()['block']
with block.injector(func.func_globals):
func(*args, **kwargs)
return ptx_eval
class Block(object): class Block(object):
""" """
@ -231,14 +289,17 @@ class Block(object):
# `block` is the real _block # `block` is the real _block
self.block = block self.block = block
self.comment = None self.comment = None
def __call__(self, comment=None) def __call__(self, comment=None):
self.comment = comment self.comment = comment
return self return self
def __enter__(self): def __enter__(self):
self.block.push_ctx() self.block.push_ctx()
self.block.code(op='{', indent=4) self.block.code(op='{', indent=1, semi=False)
if self.comment:
self.block.code(op=['// ', self.comment], semi=False)
self.comment = None
def __exit__(self, exc_type, exc_value, tb): def __exit__(self, exc_type, exc_value, tb):
self.block.code(op='}', indent=-4) self.block.code(op='}', indent=-1, semi=False)
self.block.pop_ctx() self.block.pop_ctx()
class _CallChain(object): class _CallChain(object):
@ -248,12 +309,12 @@ class _CallChain(object):
self.__chain = [] self.__chain = []
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
assert(self.__chain) assert(self.__chain)
self._call(chain, *args, **kwargs) self._call(self.__chain, *args, **kwargs)
self.__chain = [] self.__chain = []
def __getattr__(self, name): def __getattr__(self, name):
if name == 'global_': if name == 'global_':
name = 'global' name = 'global'
self.chain.append(name) self.__chain.append(name)
# Another great crime against the universe: # Another great crime against the universe:
return self return self
@ -284,7 +345,7 @@ class _RegFactory(_CallChain):
type = type[0] type = type[0]
names = names.split() names = names.split()
regs = map(lambda n: Reg(type, n), names) regs = map(lambda n: Reg(type, n), names)
self.block.code(op='.reg .' + type, vars=names) self.block.code(op='.reg .' + type, vars=_softjoin(names, ', '))
[self.block.inject(r.name, r) for r in regs] [self.block.inject(r.name, r) for r in regs]
# Pending resolution of the op(regs, guard=x) debate # Pending resolution of the op(regs, guard=x) debate
@ -318,8 +379,6 @@ class _RegFactory(_CallChain):
#self.name = name #self.name = name
#def is_set(self, isnot=False): #def is_set(self, isnot=False):
class Op(_CallChain): class Op(_CallChain):
""" """
Performs an operation. Performs an operation.
@ -340,15 +399,15 @@ class Op(_CallChain):
This constructor is available as 'op' in DSL blocks. This constructor is available as 'op' in DSL blocks.
""" """
def _call(self, op, *args, ifp=None, ifnotp=None): def _call(self, op, *args, **kwargs):
pred = '' pred = ''
if ifp: if 'ifp' in kwargs:
if ifnotp: if 'ifnotp' in kwargs:
raise SyntaxError("can't use both, fool") raise SyntaxError("can't use both, fool")
pred = ['@', ifp] pred = ['@', kwargs['ifp']]
if ifnotp: if 'ifnotp' in kwargs:
pred = ['@!', ifnotp] pred = ['@!', kwargs['ifnotp']]
self.block.append_code(pred, '.'.join(op), map(str, args)) self.block.code(pred, '.'.join(op), _softjoin(args, ', '))
class Mem(object): class Mem(object):
""" """
@ -381,7 +440,7 @@ class Mem(object):
>>> op.ld.global.v2.u32(vec(reg1, reg2), addr(areg)) >>> op.ld.global.v2.u32(vec(reg1, reg2), addr(areg))
""" """
return ['{', [(a, ', ') for a in args][:-1], '}'] return ['{', _softjoin(args, ', '), '}']
@staticmethod @staticmethod
def addr(areg, aoffset=''): def addr(areg, aoffset=''):
@ -397,8 +456,7 @@ class _MemFactory(_CallChain):
"""Actual `mem` object""" """Actual `mem` object"""
def _call(self, type, name, array=False, initializer=None): def _call(self, type, name, array=False, initializer=None):
assert len(type) == 2 assert len(type) == 2
memobj = Mem(type, name, array) memobj = Mem(type, name, array, initializer)
self.dsl.inject(name, memobj)
if array is True: if array is True:
array = ['[]'] array = ['[]']
elif array: elif array:
@ -407,11 +465,12 @@ class _MemFactory(_CallChain):
array = [] array = []
if initializer: if initializer:
array += [' = ', initializer] array += [' = ', initializer]
self.block.code(op=['.%s.%s ' % type, name, array]) self.block.code(op=['.%s.%s ' % (type[0], type[1]), name, array])
self.block.inject(name, memobj)
class Label(object): class Label(object):
""" """
Specifies the target for a branch. Scoped in PTX? TODO: test. Specifies the target for a branch. Scoped in PTX? TODO: test that it is.
>>> label('infinite_loop') >>> label('infinite_loop')
>>> op.bra.uni('label') >>> op.bra.uni('label')
@ -426,25 +485,7 @@ class _LabelFactory(object):
self.block = block self.block = block
def __call__(self, name): def __call__(self, name):
self.block.inject(name, Label(name)) self.block.inject(name, Label(name))
self.block.code(prefix='%s:' % name, semi=False)
class PTXFragment(object):
def module_setup(self):
pass
def entry_setup(self):
pass
def entry_teardown(self):
pass
def globals(self):
pass
def tests(self):
pass
def device_init(self, ctx):
pass
class PTXFragment(object): class PTXFragment(object):
""" """
@ -470,12 +511,14 @@ class PTXFragment(object):
for successful compilation. Circular dependencies are forbidden, for successful compilation. Circular dependencies are forbidden,
but multi-level dependencies should be fine. but multi-level dependencies should be fine.
""" """
return [DeviceHelpers] return [_PTXStdLib]
def inject(self): def to_inject(self):
""" """
Returns a dict of items to add to the DSL namespace. The namespace will Returns a dict of items to add to the DSL namespace. The namespace will
be assembled in dependency order before any ptx_funcs are called. be assembled in dependency order before any ptx_funcs are called.
This is only called once per PTXModule (== once per instance).
""" """
return {} return {}
@ -483,8 +526,8 @@ class PTXFragment(object):
""" """
PTX function to declare things at module scope. It's a PTX syntax error PTX function to declare things at module scope. It's a PTX syntax error
to perform operations at this scope, but we don't yet validate that at to perform operations at this scope, but we don't yet validate that at
the Python level. A module will call this function on all fragments in the Python level. A module will call this function on all fragments
dependency order. used in that module in dependency order.
If implemented, this function should use an @ptx_func decorator. If implemented, this function should use an @ptx_func decorator.
""" """
@ -513,131 +556,47 @@ class PTXFragment(object):
""" """
pass pass
def tests(self, ctx): def finalize_code(self):
""" """
Returns a list of PTXTest classes which will test this fragment. Called after running all PTX DSL functions, but before code generation,
to allow fragments which postponed variable evaluation (e.g. using
`StrVar`) to fill in the resulting values. Most fragments should not
use this.
If implemented, this function *may* use an @ptx_func decorator to
access the global DSL scope, but pretty please don't emit any code
while you're in there.
"""
pass
def tests(self):
"""
Returns a list of PTXTest types which will test this fragment.
""" """
return [] return []
def set_up(self, ctx): def device_init(self, ctx):
""" """
Do start-of-stream initialization, such as copying data to the device. Do stuff on the host to prepare the device for execution. 'ctx' is a
LaunchContext or similar. This will get called (in dependency order, of
course) *either* before any entry point invocation, or before *each*
invocation, I'm not sure which yet. (For now it's "each".)
""" """
pass pass
class PTXModule(object):
"""
Assembles PTX fragments into a module.
"""
def __init__(self, entries, inject={}, build_tests=False):
self._block = b = _Block()
self.initial_inject = dict(inject)
self._safeupdate(self.initial_inject, dict(block=Block(b),
mem=_MemFactory(b), reg=_RegFactory(b), op=Op(b),
label=_LabelFactory(b), _block=b)
self.needs_recompilation = True
self.max_compiles = 10
while self.needs_recompilation:
self.assemble(entries, build_tests)
self.max_compiles -= 1
def deporder(self, unsorted_instances, instance_map, ctx):
"""
Do a DFS on PTXFragment dependencies, and return an ordered list of
instances where no fragment depends on any before it in the list.
`unsorted_instances` is the list of instances to sort.
`instance_map` is a dict of types to instances.
"""
seen = {}
def rec(inst):
if inst in seen: return seen[inst]
deps = filter(lambda d: d is not inst, map(instance_map.get,
callable(inst.deps) and inst.deps(self) or inst.deps))
return seen.setdefault(inst, 1+max([0]+map(rec, deps)))
map(rec, unsorted_instances)
return sorted(unsorted_instances, key=seen.get)
def _safeupdate(self, dst, src):
"""dst.update(src), but no duplicates allowed"""
non_uniq = [k for k in src if k in dst]
if non_uniq: raise KeyError("Duplicate keys: %s" % ','.join(key))
dst.update(src)
def assemble(self, entries, build_tests):
"""
Build the PTX source for the given set of entries.
"""
# Get a property, dealing with the callable-or-data thing. This is
# cumbersome, but flexible; when finished, it may be simplified.
def pget(prop):
if callable(prop): return prop(ctx)
return prop
instances = {}
unvisited_entries = list(entries)
entry_names = {}
tests = []
parsed_entries = []
while unvisited_entries:
ent = unvisited_entries.pop(0)
seen, unvisited = set(), [ent]
while unvisited:
frag = unvisited.pop(0)
seen.add(frag)
inst = instances.setdefault(frag, frag())
for dep in pget(inst.deps):
if dep not in seen:
unvisited.append(dep)
if build_tests:
for test in pget(inst.tests):
if test not in tests:
if test not in instances:
unvisited_entries.append(test)
tests.append(test)
tmpl_namespace = {'ctx': ctx}
entry_start, entry_end = [], []
for inst in self.deporder(map(instances.get, seen), instances, ctx):
self._safeupdate(tmpl_namespace, pget(inst.subs))
entry_start.append(pget(inst.entry_start))
entry_end.append(pget(inst.entry_end))
entry_start_tmpl = '\n'.join(filter(None, entry_start))
entry_end_tmpl = '\n'.join(filter(None, reversed(entry_end)))
name, args, body = pget(instances[ent].entry)
tmpl_namespace.update({'_entry_name_': name, '_entry_args_': args,
'_entry_body_': body, '_entry_start_': entry_start_tmpl,
'_entry_end_': entry_end_tmpl})
entry_tmpl = (".entry {{ _entry_name_ }} ({{ _entry_args_ }})\n"
"{\n{{_entry_start_}}\n{{_entry_body_}}\n{{_entry_end_}}\n}\n")
parsed_entries.append(multisub(entry_tmpl, tmpl_namespace))
entry_names[ent] = name
prelude = []
tmpl_namespace = {'ctx': ctx}
for inst in self.deporder(instances.values(), instances, ctx):
prelude.append(pget(inst.prelude))
self._safeupdate(tmpl_namespace, pget(inst.subs))
tmpl_namespace['_prelude_'] = '\n'.join(filter(None, prelude))
tmpl_namespace['_entries_'] = '\n\n'.join(parsed_entries)
tmpl = "{{ _prelude_ }}\n{{ _entries_ }}"
self.entry_names = entry_names
self.source = ppr_ptx(multisub(tmpl, tmpl_namespace))
self.instances = instances
self.tests = tests
class PTXEntryPoint(PTXFragment): class PTXEntryPoint(PTXFragment):
# Human-readable entry point name # Human-readable entry point name
name = "" name = ""
# Device code entry name
entry_name = ""
# List of (type, name) pairs for entry params, e.g. [('u32', 'thing')]
entry_params = []
def entry(self, ctx): def entry(self, ctx):
""" """
Returns a 3-tuple of (name, args, body), which will be assembled into PTX DSL function that comprises the body of the PTX statement.
a function.
Must be implemented and decorated with ptx_func.
""" """
raise NotImplementedError raise NotImplementedError
@ -660,32 +619,216 @@ class PTXTest(PTXEntryPoint):
""" """
pass pass
class DeviceHelpers(PTXFragment): class _PTXStdLib(PTXFragment):
def __init__(self): def __init__(self, block):
self._forstack = [] # Only module that gets the privilege of seeing 'block' directly.
self.block = block
prelude = ".version 2.1\n.target sm_20\n\n" def deps(self):
return []
@ptx_func
def module_setup(self):
# TODO: make this modular, maybe? of course, we'd have to support
# multiple devices first, which we definitely do not yet do
self.block.code(prefix='.version 2.1', semi=False)
self.block.code(prefix='.target sm_20', semi=False)
@ptx_func
def _get_gtid(self, dst): def _get_gtid(self, dst):
return "{\n// Load GTID into " + dst + """ with block("Load GTID into %s" % str(dst)):
.reg .u16 tmp; reg.u16('tmp')
.reg .u32 cta, ncta, tid, gtid; reg.u32('cta ncta tid gtid')
mov.u16 tmp, %ctaid.x; op.mov.u16(tmp, '%ctaid.x')
cvt.u32.u16 cta, tmp; op.cvt.u32.u16(cta, tmp)
mov.u16 tmp, %ntid.x; op.mov.u16(tmp, '%ntid.x')
cvt.u32.u16 ncta, tmp; op.cvt.u32.u16(ncta, tmp)
mul.lo.u32 gtid, cta, ncta; op.mul.lo.u32(gtid, cta, ncta)
mov.u16 tmp, %tid.x; op.mov.u16(tmp, '%tid.x')
cvt.u32.u16 tid, tmp; op.cvt.u32.u16(tid, tmp)
add.u32 gtid, gtid, tid; op.add.u32(gtid, gtid, tid)
mov.b32 """ + dst + ", gtid;\n}" op.mov.b32(dst, gtid)
def subs(self, ctx): def to_inject(self):
return { return dict(
'PTRT': ctypes.sizeof(ctypes.c_void_p) == 8 and '.u64' or '.u32', _block=self.block,
'get_gtid': self._get_gtid block=Block(self.block),
} op=Op(self.block),
reg=_RegFactory(self.block),
mem=_MemFactory(self.block),
addr=Mem.addr,
vec=Mem.vec,
label=_LabelFactory(self.block),
get_gtid=self._get_gtid)
class PTXModule(object):
"""
Assembles PTX fragments into a module. The following properties are
available:
`instances`: Mapping of type to instance for the PTXFragments used in
the creation of this PTXModule.
`entries`: List of PTXEntry types in this module, including any tests.
`tests`: List of PTXTest types in this module.
`source`: PTX source code for this module.
"""
max_compiles = 10
def __init__(self, entries, inject={}, build_tests=False, formatter=None):
"""
Construct a PTXModule.
`entries`: List of PTXEntry types to include in this module.
`inject`: Dict of items to inject into the DSL namespace.
`build_tests`: If true, build tests into the module.
`formatter`: PTXFormatter instance, or None to use defaults.
"""
block = _Block()
insts, tests, all_deps, entry_deps = (
self.deptrace(block, entries, build_tests))
self.instances = insts
self.tests = tests
inject = dict(inject)
self._safeupdate(inject, {'module': self})
for inst in all_deps:
self._safeupdate(inject, inst.to_inject())
[block.inject(k, v) for k, v in inject.items()]
self.__needs_recompilation = True
self.compiles = 0
while self.__needs_recompilation:
self.compiles += 1
self.__needs_recompilation = False
self.assemble(block, all_deps, entry_deps)
self.instances.pop(_PTXStdLib)
print self.instances
if not formatter:
formatter = PTXFormatter()
self.source = formatter.format(block.outer_ctx.code)
self.entries = list(set(entries + tests))
def deporder(self, unsorted_instances, instance_map):
"""
Do a DFS on PTXFragment dependencies, and return an ordered list of
instances where no fragment depends on any before it in the list.
`unsorted_instances` is the list of instances to sort.
`instance_map` is a dict of types to instances.
"""
seen = {}
def rec(inst):
if inst in seen: return seen[inst]
if inst is None: return 0
deps = filter(lambda d: d is not inst,
map(instance_map.get, inst.deps()))
return seen.setdefault(inst, 1+max([0]+map(rec, deps)))
map(rec, unsorted_instances)
return sorted(unsorted_instances, key=seen.get)
def _safeupdate(self, dst, src):
"""dst.update(src), but no duplicates allowed"""
non_uniq = [k for k in src if k in dst]
if non_uniq: raise KeyError("Duplicate keys: %s" % ','.join(key))
dst.update(src)
def deptrace(self, block, entries, build_tests):
instances = {_PTXStdLib: _PTXStdLib(block)}
unvisited_entries = list(entries)
tests = set()
entry_deps = {}
# For each PTXEntry or PTXTest, use a BFS to recursively find and
# instantiate all fragments that are dependencies. If tests are
# discovered, add those to the list of entries.
while unvisited_entries:
ent = unvisited_entries.pop(0)
seen, unvisited = set(), [ent]
while unvisited:
frag = unvisited.pop(0)
seen.add(frag)
# setdefault doesn't work because of _PTXStdLib
if frag not in instances:
inst = frag()
instances[frag] = inst
else:
inst = instances[frag]
for dep in inst.deps():
if dep not in seen:
unvisited.append(dep)
if build_tests:
for test in inst.tests():
if test not in tests:
tests.add(test)
if test not in instances:
unvisisted_entries.append(tests)
# For this entry, store insts of all dependencies in order.
entry_deps[ent] = self.deporder(map(instances.get, seen),
instances)
# Find the order for all dependencies in the program.
all_deps = self.deporder(instances.values(), instances)
return instances, sorted(tests, key=str), all_deps, entry_deps
def assemble(self, block, all_deps, entry_deps):
# Rebind to local namespace to allow proper retrieval
_block = block
for inst in all_deps:
inst.module_setup()
for ent, insts in entry_deps.items():
# This is kind of hackish compared to everything else
params = [Reg('.param.' + str(type), name)
for (type, name) in ent.entry_params]
_block.code(op='.entry %s ' % ent.entry_name, semi=False,
vars=['(', ['%s %s' % (r.type, r.name) for r in params], ')'])
with Block(_block):
[_block.inject(r.name, r) for r in params]
for dep in insts:
dep.entry_setup()
self.instances[ent].entry()
for dep in reversed(insts):
dep.entry_teardown()
for inst in all_deps:
inst.finalize_code()
def set_needs_recompilation(self):
if not self.__needs_recompilation:
if self.compiles >= self.max_compiles:
raise ValueError("Too many recompiles scheduled!")
self.__needs_recompilation = True
class PTXFormatter(object):
"""
Formats PTXStmt items into beautiful code. Well, the beautiful part is
postponed for now.
"""
def __init__(self, indent=4):
self.indent_amt = 4
def _flatten(self, val):
if isinstance(val, (list, tuple)):
return ''.join(map(self._flatten, val))
return str(val)
def format(self, code):
out = []
indent = 0
for (pfx, op, vars, semi, indent_change) in code:
pfx = self._flatten(pfx)
op = self._flatten(op)
vars = map(self._flatten, vars)
if indent_change < 0:
indent = max(0, indent + self.indent_amt * indent_change)
# TODO: make this a lot prettier
line = ((('%%-%ds' % indent) % pfx) + op + ' ' + ''.join(vars))
if semi:
line = line.rstrip() + ';'
out.append(line)
if indent_change > 0:
indent += self.indent_amt * indent_change
return '\n'.join(out)

10
main.py
View File

@ -15,16 +15,16 @@ from ctypes import *
import numpy as np import numpy as np
#from cuburnlib.device_code import MWCRNGTest from cuburnlib.device_code import MWCRNGTest
#from cuburnlib.cuda import LaunchContext from cuburnlib.cuda import LaunchContext
from fr0stlib.pyflam3 import * from fr0stlib.pyflam3 import *
from fr0stlib.pyflam3._flam3 import * from fr0stlib.pyflam3._flam3 import *
from cuburnlib.render import * from cuburnlib.render import *
def main(genome_path): def main(genome_path):
#ctx = LaunchContext([MWCRNGTest], block=(256,1,1), grid=(64,1), tests=True) ctx = LaunchContext([MWCRNGTest], block=(256,1,1), grid=(64,1), tests=True)
#ctx.compile(True) ctx.compile(verbose=True)
#ctx.run_tests() ctx.run_tests()
with open(genome_path) as fp: with open(genome_path) as fp:
genomes = Genome.from_string(fp.read()) genomes = Genome.from_string(fp.read())