From c0e3c1d599afd53eb592b18bdcda76e650a09c2c Mon Sep 17 00:00:00 2001 From: Steven Robertson Date: Fri, 1 Oct 2010 01:20:20 -0400 Subject: [PATCH] Known broken checkin because I'm nervous. --- cuburn/__init__.py | 32 + cuburn/cuda.py | 73 +- cuburn/device_code.py | 249 +++---- cuburn/ptx.py | 1594 ++++++++++++++++------------------------- main.py | 2 + 5 files changed, 786 insertions(+), 1164 deletions(-) diff --git a/cuburn/__init__.py b/cuburn/__init__.py index e69de29..cb8e8ca 100644 --- a/cuburn/__init__.py +++ b/cuburn/__init__.py @@ -0,0 +1,32 @@ + +from collections import namedtuple + +Flag = namedtuple('Flag', 'level desc') + +class DebugSettings(object): + """ + Container for default debug settings. + """ + def __init__(self, items): + self.items = items + self.values = {} + self.level = 1 + def __getattr__(self, name): + if name not in self.items: + raise KeyError("Unknown debug flag name!") + if name in self.values: + return self.values[name] + return (self.items[name].level <= self.level) + def format_help(self): + name_len = min(30, max(map(len, self.items.keys()))) + fmt = '%-' + name_len + 's %d %s' + return '\n'.join([fmt % (k, v.level, v.desc) + for k, v in self.items.items()]) + +debug_flags = dict( + count_writes = Flag(3, "Count the number of points written per thread " + "when doing iterations."), + count_rounds = Flag(3, "Count the number of times the iteration loop " + "runs per thread when doing iterations.") + ) + diff --git a/cuburn/cuda.py b/cuburn/cuda.py index 9be816e..275394b 100644 --- a/cuburn/cuda.py +++ b/cuburn/cuda.py @@ -1,39 +1,48 @@ # These imports are order-sensitive! -import pyglet -import pyglet.gl as gl -gl.get_current_context() +#import pyglet +#import pyglet.gl as gl +#gl.get_current_context() import pycuda.driver as cuda from pycuda.compiler import SourceModule import pycuda.tools -import pycuda.gl as cudagl -import pycuda.gl.autoinit +#import pycuda.gl as cudagl +#import pycuda.gl.autoinit +import pycuda.autoinit import numpy as np -from cuburn.ptx import PTXModule, PTXTest, PTXTestFailure +from cuburn.ptx import PTXFormatter + +class Module(object): + def __init__(self, entries): + self.entries = entries + self.source = self.compile(entries) + self.mod = self.assemble(self.source) + + @staticmethod + def compile(entries): + formatter = PTXFormatter() + for entry in entries: + entry.format_source(formatter) + return formatter.get_source() + + def assemble(self, src): + # TODO: make this a debugging option + with open('/tmp/cuburn.ptx', 'w') as f: f.write(src) + try: + mod = cuda.module_from_buffer(src, + [(cuda.jit_option.OPTIMIZATION_LEVEL, 0), + (cuda.jit_option.TARGET_FROM_CUCONTEXT, 1)]) + except (cuda.CompileError, cuda.RuntimeError), e: + # TODO: if output not written above, print different message + # TODO: read assembler output and recover Python source lines + print "Compile error. Source is at /tmp/cuburn.ptx" + print e + raise e + return mod class LaunchContext(object): - """ - Context collecting the information needed to create, run, and gather the - results of a device computation. This may eventually also include an actual - CUDA context, but for now it just uses the global one. - - To create the fastest device code across multiple device families, this - context may decide to iteratively refine the final PTX by regenerating - and recompiling it several times to optimize certain parameters of the - launch, such as the distribution of threads throughout the device. - The properties of this device which are tuned are listed below. 
Any PTX - fragments which use this information must emit valid PTX for any state - given below, but the PTX is only required to actually run with the final, - fixed values of all tuned parameters below. - - `block`: 3-tuple of (x,y,z); dimensions of each CTA. - `grid`: 2-tuple of (x,y); dimensions of the grid of CTAs. - `nthreads`: Number of active threads on device as a whole. - `mod`: Final compiled module. Unavailable during assembly. - - """ def __init__(self, entries, block=(1,1,1), grid=(1,1), tests=False): self.entry_types = entries self.block, self.grid, self.build_tests = block, grid, tests @@ -60,18 +69,6 @@ class LaunchContext(object): kwargs['ctx'] = self self.ptx = PTXModule(self.entry_types, kwargs, self.build_tests) # TODO: make this optional and let user choose path - with open('/tmp/cuburn.ptx', 'w') as f: f.write(self.ptx.source) - try: - # TODO: detect/customize arch, code; verbose setting; - # keep directory enable/disable via debug - self.mod = cuda.module_from_buffer(self.ptx.source, - [(cuda.jit_option.OPTIMIZATION_LEVEL, 0), - (cuda.jit_option.TARGET_FROM_CUCONTEXT, 1)]) - except (cuda.CompileError, cuda.RuntimeError), e: - # TODO: if output not written above, print different message - print "Compile error. Source is at /tmp/cuburn.ptx" - print e - raise e if verbose: for entry in self.ptx.entries: func = self.mod.get_function(entry.entry_name) diff --git a/cuburn/device_code.py b/cuburn/device_code.py index b9e590e..8523d87 100644 --- a/cuburn/device_code.py +++ b/cuburn/device_code.py @@ -523,175 +523,130 @@ class ShufflePoints(PTXFragment): op.bar.sync(bar) op.ld.volatile.shared.b32(var, addr(shuf_read)) -class MWCRNG(PTXFragment): - shortname = "mwc" - - def __init__(self): - self.threads_ready = 0 +class MWCRNG(object): + def __init__(self, entry, seed=None): + # TODO: install this in data directory or something if not os.path.isfile('primes.bin'): raise EnvironmentError('primes.bin not found') + self.threads_ready = 0 + self.mults, self.state = None, None - @ptx_func - def module_setup(self): - mem.global_.u32('mwc_rng_mults', ctx.nthreads) - mem.global_.u64('mwc_rng_state', ctx.nthreads) + self.entry = entry + entry.add_param('mwc_mults', entry.types.u32) + entry.add_param('mwc_states', entry.types.u32) + r, o = entry.regs, entry.ops + with entry.head as e: + #mwc_mult_addr = gtid * 4 + e.params.mwc_mults + gtid = o.mad.lo(e.special.ctaid_x, ctx.threads_per_cta, + e.special.tid_x) + mwc_mult_addr = o.mad.lo.u32(gtid, 4, e.params.mwc_mults) + r.mwc_mult = o.load.u32(mwc_mult_addr) + mwc_state_addr = o.mad.lo.u32(gtid, 8, e.params.mwc_states) + r.mwc_state, r.mwc_carry = o.load.u64(mwc_state_addr) + with entry.tail as e: + #gtid = e.special.ctaid_x * ctx.threads_per_cta + e.special.tid_x + gtid = o.mad.lo(e.special.ctaid_x, ctx.threads_per_cta, + e.special.tid_x) + mwc_state_addr = o.mad.lo.u32(gtid, 8, e.params.mwc_states) + o.store.v2(mwc_state_addr, (r.mwc_state, r.mwc_carry)) - @ptx_func - def entry_setup(self): - reg.u32('mwc_st mwc_mult mwc_car') - with block('Load MWC multipliers and states'): - reg.u32('mwc_off mwc_addr') - std.get_gtid(mwc_off) - op.mov.u32(mwc_addr, mwc_rng_mults) - op.mad.lo.u32(mwc_addr, mwc_off, 4, mwc_addr) - op.ld.global_.u32(mwc_mult, addr(mwc_addr)) + def next_b32(self): + e, r, o = self.entry, self.entry.regs, self.entry.ops + mwc_out = o.cvt.u64(r.mwc_carry) + mwc_out = o.mad.wide.u32(r.mwc_mult, r.mwc_state, mwc_out) + r.mwc_state, r.mwc_carry = o.mov(mwc_out) + return r.mwc_state - op.mov.u32(mwc_addr, mwc_rng_state) - 
op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr) - op.ld.global_.v2.u32(vec(mwc_st, mwc_car), addr(mwc_addr)) + def next_f32_01(self): + e, r, o = self.entry, self.entry.regs, self.entry.ops + mwc_float = o.cvt.rn.f32.u32(self.next_b32()) + # TODO: check the precision on the uploaded types here + return o.mul.f32(mwc_float, 1./(1<<32)) - @ptx_func - def entry_teardown(self): - with block('Save MWC states'): - reg.u32('mwc_off mwc_addr') - std.get_gtid(mwc_off) - op.mov.u32(mwc_addr, mwc_rng_state) - op.mad.lo.u32(mwc_addr, mwc_off, 8, mwc_addr) - op.st.global_.v2.u32(addr(mwc_addr), vec(mwc_st, mwc_car)) + def next_f32_11(self): + e, r, o = self.entry, self.entry.regs, self.entry.ops + mwc_float = o.cvt.rn.f32.s32(self.next_b32()) + return o.mul.f32(mwc_float, 1./(1<<31)) - @ptx_func - def _next(self): - # Call from inside a block! - reg.u64('mwc_out') - op.cvt.u64.u32(mwc_out, mwc_car) - op.mad.wide.u32(mwc_out, mwc_st, mwc_mult, mwc_out) - op.mov.b64(vec(mwc_st, mwc_car), mwc_out) - - @ptx_func - def next_b32(self, dst_reg): - with block('Load next random u32 into ' + dst_reg.name): - self._next() - op.mov.u32(dst_reg, mwc_st) - - @ptx_func - def next_f32_01(self, dst_reg): - # TODO: verify that this is the fastest-performance method - # TODO: verify that this actually does what I think it does - with block('Load random float [0,1] into ' + dst_reg.name): - self._next() - op.cvt.rn.f32.u32(dst_reg, mwc_st) - op.mul.f32(dst_reg, dst_reg, '0f2F800000') # 1./(1<<32) - - @ptx_func - def next_f32_11(self, dst_reg): - with block('Load random float [-1,1) into ' + dst_reg.name): - reg.u32('mwc_to_float') - self._next() - op.cvt.rn.f32.s32(dst_reg, mwc_st) - op.mul.f32(dst_reg, dst_reg, '0f30000000') # 1./(1<<31) - - @instmethod - def seed(self, ctx, rand=np.random): + def call_setup(self, ctx, force=False): """ Seed the random number generators with values taken from a ``np.random`` instance. """ - # Load raw big-endian u32 multipliers from primes.bin. - with open('primes.bin') as primefp: - dt = np.dtype(np.uint32).newbyteorder('B') - mults = np.frombuffer(primefp.read(), dtype=dt) - stream = cuda.Stream() - # Randomness in choosing multipliers is good, but larger multipliers - # have longer periods, which is also good. This is a compromise. - mults = np.array(mults[:ctx.nthreads*4]) - rand.shuffle(mults) - # Copy multipliers and seeds to the device - multdp, multl = ctx.mod.get_global('mwc_rng_mults') - cuda.memcpy_htod(multdp, mults.tostring()[:multl]) - # Intentionally excludes both 0 and (2^32-1), as they can lead to - # degenerate sequences of period 0 - states = np.array(rand.randint(1, 0xffffffff, size=2*ctx.nthreads), - dtype=np.uint32) - statedp, statel = ctx.mod.get_global('mwc_rng_state') - cuda.memcpy_htod(statedp, states.tostring()) - self.threads_ready = ctx.nthreads + if force or self.nthreads_ready < ctx.nthreads: + # Load raw big-endian u32 multipliers from primes.bin. + with open('primes.bin') as primefp: + dt = np.dtype(np.uint32).newbyteorder('B') + mults = np.frombuffer(primefp.read(), dtype=dt) + # Randomness in choosing multipliers is good, but larger multipliers + # have longer periods, which is also good. This is a compromise. 
+ mults = np.array(mults[:ctx.nthreads*4]) + rand.shuffle(mults) + locked_mults = ctx.hostpool.allocate(ctx.nthreads, np.uint32) + locked_mults[:] = mults[ctx.nthreads] + self.mults = ctx.pool.allocate(4*ctx.nthreads) + cuda.memcpy_htod_async(self.mults, locked_mults.base, ctx.stream) + # Intentionally excludes both 0 and (2^32-1), as they can lead to + # degenerate sequences of period 0 + states = np.array(rand.randint(1, 0xffffffff, size=2*ctx.nthreads), + dtype=np.uint32) + locked_states = ctx.hostpool.allocate(2*ctx.nthreads, np.uint32) + locked_states[:] = states + self.states = ctx.pool.allocate(8*ctx.nthreads) + cuda.memcpy_htod_async(self.states, locked_states, ctx.stream) + self.nthreads_ready = ctx.nthreads + ctx.set_param('mwc_mults', self.mults) + ctx.set_param('mwc_states', self.states) - def call_setup(self, ctx): - if self.threads_ready < ctx.nthreads: - self.seed(ctx) - - def tests(self): - return [MWCRNGTest, MWCRNGFloatsTest] - -class MWCRNGTest(PTXTest): - name = "MWC RNG sum-of-threads" +class MWCRNGTest(PTXEntry): rounds = 5000 - entry_name = 'MWC_RNG_test' - entry_params = '' - def deps(self): - return [MWCRNG] + def __init__(self, entry): + self.entry = entry + self.mwc = MWCRNG(entry) - @ptx_func - def module_setup(self): - mem.global_.u64('mwc_rng_test_sums', ctx.nthreads) + entry.add_param('mwc_test_sums', entry.types.u32) + with entry.body(): + self.entry_body() - @ptx_func - def entry(self): - reg.u64('sum addl') - reg.u32('addend') - op.mov.u64(sum, 0) - with block('Sum next %d random numbers' % self.rounds): - reg.u32('loopct') - reg.pred('p') - op.mov.u32(loopct, self.rounds) - label('loopstart') - mwc.next_b32(addend) - op.cvt.u64.u32(addl, addend) - op.add.u64(sum, sum, addl) - op.sub.u32(loopct, loopct, 1) - op.setp.gt.u32(p, loopct, 0) - op.bra.uni(loopstart, ifp=p) + def entry_body(self): + e, r, o = self.entry, self.entry.regs, self.entry.ops - with block('Store sum and state'): - reg.u32('adr offset') - std.get_gtid(offset) - op.mov.u32(adr, mwc_rng_test_sums) - op.mad.lo.u32(adr, offset, 8, adr) - op.st.global_.u64(addr(adr), sum) + r.sum = 0 + with e.std.loop(self.rounds) as mwc_rng_sum: + addend = o.cvt.u64.u32(self.mwc.next_b32()) + r.sum = o.add.u64(r.sum, addend) - def call_setup(self, ctx): - # Get current multipliers and seeds from the device - self.mults = ctx.get_per_thread('mwc_rng_mults', np.uint32) - self.fullstates = ctx.get_per_thread('mwc_rng_state', np.uint64) - self.sums = np.zeros(ctx.nthreads, np.uint64) + e.std.store_per_thread(e.params.mwc_test_sums, r.sum) - print "Running %d states forward %d rounds" % \ - (len(self.mults), self.rounds) - ctime = time.time() - for i in range(self.rounds): - states = self.fullstates & 0xffffffff - carries = self.fullstates >> 32 - self.fullstates = self.mults * states + carries - self.sums += self.fullstates & 0xffffffff - ctime = time.time() - ctime - print "Done on host, took %g seconds" % ctime + def call(self, ctx): + # Generate current state, upload it to GPU + self.mwc.call_setup(ctx, force=True) + mults, fullstates = self.mwc.mults, self.mwc.fullstates + sums = np.zeros_like(fullstates) - def call_teardown(self, ctx): - dfullstates = ctx.get_per_thread('mwc_rng_state', np.uint64) - if not (dfullstates == self.fullstates).all(): - print "State discrepancy" - print dfullstates - print self.fullstates - raise PTXTestFailure("MWC RNG state discrepancy") + # Run two trials, to ensure device state is getting saved properly + for trial in range(2): + print "Trial %d, on CPU: " % trial, + ctime = 
time.time() + for i in range(self.rounds): + states = fullstates & 0xffffffff + carries = fullstates >> 32 + fullstates = self.mults * states + carries + sums += fullstates & 0xffffffff + ctime = time.time() - ctime + print "Took %g seconds." % ctime + print "Trial %d, on device: " % trial, + dsums = np.empty_like(sums) + ctx.set_param('mwc_test_sums', cuda.Out(dsums)) + print "Took %g seconds." % ctx.call() - dsums = ctx.get_per_thread('mwc_rng_test_sums', np.uint64) - if not (dsums == self.sums).all(): - print "Sum discrepancy" - print dsums - print self.sums - raise PTXTestFailure("MWC RNG sum discrepancy") + if not np.all(np.equal(sums, dsums)): + print "Sum discrepancy!" + print sums + print dsums + raise TODOSomeKindOfException() class MWCRNGFloatsTest(PTXTest): """ diff --git a/cuburn/ptx.py b/cuburn/ptx.py index b099554..6d81856 100644 --- a/cuburn/ptx.py +++ b/cuburn/ptx.py @@ -1,16 +1,9 @@ """ PTX DSL, a domain-specific language for NVIDIA's PTX. - -The DSL doesn't really provide any benefits over raw PTX in terms of type -safety or error checking. Where it shines is in enabling code reuse, -modularization, and dynamic data structures. In particular, the "data stream" -that controls the iterations and xforms in cuflame's device code are much -easier to maintain using this system. """ # If you see 'import inspect', you know you're in for a good time import inspect -import types import struct from cStringIO import StringIO from collections import namedtuple @@ -19,1028 +12,682 @@ from math import * import numpy as np import pycuda.driver as cuda -# Okay, so here's what's going on. -# -# We're using Python to create PTX. If we just use Python to make one giant PTX -# module, there's no real reason of going to the trouble of using Python to -# begin with, as the things that this system is good for - modularization, unit -# testing, automated analysis, and data structure generation and optimization - -# pretty much require splitting code up into manageable units. However, -# splitting things up at the level of PTX will greatly reduce performance, as -# the cost of accessing the stack, spilling registers, and reloading data from -# system memory is unacceptably high even on Fermi GPUs. So we want to split -# code up into functions within Python, but not within the PTX source. -# -# The challenge here is variable lifetime. A PTX function might declare a -# register at the top of the main block and use it several times throughout the -# function. In Python, we split that up into multiple functions, one to declare -# the registers at the start of the scope and another to make use of them later -# on. This makes it very easy to reuse a class of related PTX functions in -# different device entry points, do unit tests, and so on. -# -# The scope of the class instance is unrelated to the normal scope of names in -# Python. In fact, a function call frequently declares a register that may be -# later siblings in the call stack. So where to store the information regarding -# the register that was declared at the top of the file (name, type, etc)? -# Well, once declared, a variable remains in scope in PTX until the closing -# brace of the block (curly-braces segment) it was declared in. The natural -# place to store it would be in a Pythonic representation of the block: a block -# object that implements the context manager. -# -# This works well in terms of tracking object lifetime, but it adds a great -# deal of ugliness to the code. 
What I originally sought was this:: -# -# def load_zero(dest_reg): -# op.mov.u32(dest_reg, 0) -# def init_module(): -# reg.u32('hooray_reg') -# load_zero(hooray_reg) -# -# But using blocks alone to track names, it would turn in to this mess:: -# -# def load_zero(block, dest_reg): -# block.op.mov.u32(block.op.dest_reg, 0) -# def init_module(): -# with Block() as block: -# block.regs.hooray_reg = block.reg.u32('hooray_reg') -# load_zero(block, block.regs.hooray_reg) -# -# Eeugh. -# -# Anyway, never one to use an acceptable solution when an ill-conceived hack -# was available, I poked and prodded until I found a way to attain my ideal. -# In short, a function with a 'ptx_func' decorator will be wrapped in a -# _BlockInjector context manager, which will temporarily add values to the -# function's global dictionary in such a way as to mimic the desired behavior. -# The decorator is kind enough to pop the values when exiting. The examples -# below give a clear picture of how to use it, but now you know why this -# abomination was crafted to begin with. +from pprint import pprint -def _softjoin(args, sep): - """Intersperses 'sep' between 'args' without coercing to string.""" - return [[arg, sep] for arg in args[:-1]] + list(args[-1:]) +PTX_VERSION=(2, 1) -BlockCtx = namedtuple('BlockCtx', 'locals code injectors') -PTXStmt = namedtuple('PTXStmt', 'prefix op vars semi indent') +Type = namedtuple('Type', 'name kind bits bytes') +TYPES = {} +for kind in 'busf': + for width in [8, 16, 32, 64]: + TYPES[kind+str(width)] = Type(kind+str(width), kind, width, width / 8) +del TYPES['f8'] +TYPES['pred'] = Type('pred', 'p', 0, 0) -class _BlockInjector(object): +class Statement(object): """ - A ContextManager that, upon entering a context, loads some keys into a - dictionary, and upon leaving it, removes those keys. If any keys are - already in the destination dictionary with a different value, an exception - is raised. - - Useful if the destination dictionary is a func's __globals__. + Representation of a PTX statement. """ - def __init__(self, to_inject, inject_into): - self.to_inject, self.inject_into = to_inject, inject_into - self.injected = set() - self.dead = True - def inject(self, kv, v=None): - """Inject a key-value pair (passed either as a tuple or separately.)""" - k, v = v and (kv, v) or kv - if k not in self.to_inject: - self.to_inject[k] = v - if self.dead: - return - if k in self.inject_into: - if self.inject_into[k] is not v: - raise KeyError("Key with different value already in dest") - else: - self.inject_into[k] = v - self.injected.add(k) - def pop(self, keys): - """Remove keys from a dictionary, as long as we added them.""" - assert not self.dead - for k in keys: - if k in self.injected: - self.inject_into.pop(k) - self.injected.remove(k) - def __enter__(self): - self.dead = False - map(self.inject, self.to_inject.items()) - def __exit__(self, exc_type, exc_val, tb): - # Do some real exceptorin' - if exc_type is not None: return - for k in self.injected: - del self.inject_into[k] - self.dead = True + known_opnames = ('add addc sub subc mul mad mul24 mad24 sad div rem abs ' + 'neg min max popc clz bfind brev bfe bfi prmt testp copysign rcp ' + 'sqrt rsqrt sin cos lg2 ex2 set setp selp slct and or xor not ' + 'cnot shl shr mov ld ldu st prefetch prefetchu isspacep cvta cvt ' + 'tex txq suld sust sured suq bra call ret exit bar membar atom red ' + 'vote vadd vsub vabsdiff vmin vmax vshl vshr vmad vset').split() -class _Block(object): - """ - State-tracker for PTX fragments. 
You should really look at Block and - PTXModule instead of here. + def __init__(self, name, args, line_info = None): + self.opname = name + self.fullname, self.operands, self.rtype = self.parse(name, args) + self.result = None + self.python_line = line_info + self.ptx_line = None - For important reasons, the instance must be bound locally as "_block". - """ - name = '_block' # For retrieving from parent scope on first call - def __init__(self): - self.reset() - def reset(self): - self.outer_ctx = BlockCtx({self.name: self}, [], []) - self.stack = [self.outer_ctx] - def clean_injectors(self): - inj = self.stack[-1].injectors - [inj.remove(i) for i in list(inj) if i.dead] - def push_ctx(self): - self.clean_injectors() - self.stack.append(BlockCtx(dict(self.stack[-1].locals), [], [])) - # The only reason we should have no injectors in the previous block is - # if we are hitting a new ptx_func entry point or global declaration at - # PTX module scope, which means the stack only contains the outer - # context and the current one (i.e. len(stack) == 2) - if len(self.stack[-2].injectors) == 0: - assert len(self.stack) == 2, "Empty injector list too early!" - # Otherwise, the active injector in the previous block is the one for - # the Python function which is currently creating a new PTX block, and - # and it needs to be promoted to the current block - else: - self.stack[-1].injectors.append(self.stack[-2].injectors.pop()) - def pop_ctx(self): - self.clean_injectors() - bs = self.stack.pop() - self.stack[-1].code.extend(bs.code) - if len(self.stack) == 1: - # We're on outer_ctx, so all injectors should be gone. - assert len(bs.injectors) == 0, "Injector/context mismatch" - return - # The only injector should be the one added in push_ctx - assert len(bs.injectors) == 1, "Injector/context mismatch" - # Find out which keys were injected while in this context - diff = set(bs.locals.keys()).difference( - set(self.stack[-1].locals.keys())) - # Pop keys and move current injector back down to last context - last_inj = bs.injectors.pop() - last_inj.pop(diff) - self.stack[-1].injectors.append(last_inj) - def injector(self, func_globals): - inj = _BlockInjector(dict(self.stack[-1].locals), func_globals) - self.stack[-1].injectors.append(inj) - return inj - def inject(self, name, object): - if name in self.stack[-1].locals: - if self.stack[-1].locals[name] is not object: - raise KeyError("'%s' already exists in this scope." % name) - else: - self.stack[-1].locals[name] = object - [inj.inject(name, object) for inj in self.stack[-1].injectors] - def code(self, prefix='', op='', vars=[], semi=True, indent=0): + @staticmethod + def parse(name, args): """ - Append a PTX statement (or thereabouts) to the current block. + Parses and expands a (possibly incomplete) PTX statement, returning the + complete operation name and destination register type. - - `prefix`: a string which will not be indented, regardless of the - current indent level, for labels and predicates. - - `op`: a string, aligned to current indent level. - - `vars`: a list of strings, with best-effort alignment. - - `semi`: whether to terminate the current line with a semicolon. - - `indent`: integer adjustment to the current indent level. + ``name`` is a list of the parts of the operation name (as would be + given by ``'add.u32'.split()``, for example). + ``args`` is a list of the arguments to the operation, excluding the + destination register. 
- For `prefix`, `op`, and `vars`, a "string" can also mean a sequence of - objects that can be coerced to strings, which will be joined without - spacing. To keep things simple, nested lists and tuples will be reduced - in this manner (but not other iterable types). Coercion will not happen - until after the entire DSL call tree has been walked. This allows a - class to submit a mutable type (e.g. ``DelayVar``) when first - walked with an undefined value, then substitute the correct value on - being finalized. - - Details about alignment are available in the `PTXFormatter` class. And - yes, the only real difference between `prefix`, `op`, and `vars` is in - final appearance, but it is in fact quite helpful for debugging. + Returns a 3-tuple of ``(fullname, args, rtype)``, where ``fullname`` is + the fully-expanded name of the operation, ``args`` is the list of + arguments with all untyped values converted to ``Immediate`` values of + the appropriate type, and ``type`` is the expected result type of the + statement. If the statement does not have a destination register, + ``type`` will be None. """ - self.stack[-1].code.append(PTXStmt(prefix, op, vars, semi, indent)) + # TODO: .ftz insertion -class DelayVar(object): - """ - Trivial wrapper to allow deferred variable substitution. - """ - def __init__(self, val=None): - self.val = val - def __str__(self): - return str(self.val) - def __mul__(self, other): - # Oh this is truly egregious - return DelayVarProxy(self, "self.other.val*" + str(other)) + if name[0] in 'tex txq suld sust sured suq call'.split(): + raise NotImplementedError("No support for %s yet" % name[0]) -class DelayVarProxy(object): - def __init__(self, other, expr): - self.other, self.expr = other, expr - def __str__(self): - return str(eval(self.expr)) + # Make sure we don't modify the caller's list/tuple + name, args = list(name), list(args) -class _PTXFuncWrapper(object): - """Enables ptx_func""" - def __init__(self, func): - self.func = func - def __call__(self, *args, **kwargs): - if _Block.name in globals(): - block = globals()['block'] - else: - # Find the '_block' from the enclosing scope - parent = inspect.stack()[2][0] - if _Block.name in parent.f_locals: - block = parent.f_locals[_Block.name] - elif _Block.name in parent.f_globals: - block = parent.f_globals[_Block.name] + # Six constants that just have to be unique from each other + # 'stype', 'dtype', 'ignore', 'u32', 'pred', 'memory' + ST, DT, IG, U3, PR, ME = range(6) + + if name[0] in ('add addc sub subc mul mul24 div rem min max and or ' + 'xor not cnot copysign').split(): + atypes = [ST, ST] + elif name[0] in ('abs neg popc clz bfind brev testp rcp sqrt rsqrt sin ' + 'cos lg2 ex2 mov cvt cvta isspacep split').split(): + atypes = [ST] + elif name[0] == 'mad' and name[1] == 'wide': + atypes = [ST, ST, DT] + elif name[0] in 'mad mad24 sad'.split(): + atypes = [ST, ST, ST] + elif name[0] == 'bfe': + atypes = [ST, U3, U3] + elif name[0] == 'bfi': + atypes = [ST, ST, U3, U3] + elif name[0] == 'prmt': + atypes = [U3, U3, U3] + elif name[0] in 'ld ldu prefetch prefetchu': + atypes = [ME] + elif name[0] == 'st': + atypes = [ME, ST] + elif name[0] in 'set setp selp'.split(): + atypes = [ST, ST, IG] + elif name[0] == 'slct': + atypes = [DT, DT, ST] + elif name[0] in ('shl', 'shr'): + atypes = [ST, U3] + elif name[0] in ('atom', 'red'): + if name[1] == 'cas': + atypes = [ME, ST, ST] else: - # Couldn't find the _block instance. 
Fail cryptically to - # encourage users to read the source (for now) - raise SyntaxError("Black magic") - # Create a new function with the modified scope and call it. We could - # do this in __init__, but it would hide any changes to globals from - # the module's original scope. Still an option if performance sucks. - newglobals = dict(self.func.func_globals) - func = types.FunctionType(self.func.func_code, newglobals, - self.func.func_name, self.func.func_defaults, - self.func.func_closure) - with block.injector(func.func_globals): - func(*args, **kwargs) - -def ptx_func(func): - """ - Decorator function for code in the DSL. Any function which accesses the DSL - namespace, including declared device variables and objects such as "reg" - or "op", should be wrapped with this. See Block for some examples. - - Note that writes to global variables will silently fail for now. - """ - # Attach most of the code to the wrapper class - fw = _PTXFuncWrapper(func) - def wr(*args, **kwargs): - fw(*args, **kwargs) - return wr - -class Block(object): - """ - Limits the lifetime of variables in both PTX (using curly-braces) and in - the Python DSL (via black magic). This is semantically useful, but should - not otherwise affect device code (the lifetime of a register is - aggressively minimized by the compiler). - - >>> with block('This comment will appear at the top of the block'): - >>> reg.u32('same_name') - >>> with block(): - >>> reg.u64('same_name') # OK, because 'same_name' went out of scope - - PTX variables declared inside a block will be available in any other - ptx_func called within that block. Note that this flies in the face of - normal Python behavior! That's why it's a DSL. (This doesn't apply to - non-PTX variables.) - - >>> @ptx_func - >>> def fn1(): - >>> op.mov.u32(reg1, 0) - >>> - >>> @ptx_func - >>> def fn2(): - >>> print x - >>> - >>> @ptx_func - >>> def fn3(): - >>> with block(): - >>> reg.u32('reg1') - >>> x = 4 - >>> fn1() # OK: DSL magic propagates 'reg1' to fn1's namespace - >>> fn2() # FAIL: DSL magic doesn't touch regular variables - >>> fn1() # FAIL: 'reg1' went out of scope along with the block - - This constructor is available as 'block' in the DSL namespace. - """ - def __init__(self, block): - # `block` is the real _block - self.block = block - self.comment = None - def __call__(self, comment=None): - self.comment = comment - return self - def __enter__(self): - self.block.push_ctx() - self.block.code(op='{', semi=False) - self.block.code(indent=1) - if self.comment: - self.block.code(op=['// ', self.comment], semi=False) - self.comment = None - def __exit__(self, exc_type, exc_value, tb): - # Allow exceptions to be propagated; things get really messy if we try - # to pop the stack if things aren't ordered correctly - if exc_type is not None: return - self.block.code(indent=-1) - self.block.code(op='}', semi=False) - self.block.pop_ctx() - -class _CallChain(object): - """Handles the syntax for the operator chaining in PTX, like op.mul.u32.""" - def __init__(self, block): - self.block = block - self.__chain = [] - def __call__(self, *args, **kwargs): - assert(self.__chain) - r = self._call(self.__chain, *args, **kwargs) - self.__chain = [] - return r - def __getattr__(self, name): - # Work around keword conflicts between python and ptx - name = name.strip('_') - self.__chain.append(name) - # Another great crime against the universe: - return self - -class Reg(object): - """ - Creates one or more registers. 
The argument should be a string containing - one or more register names, separated by whitespace; the registers will be - injected into the DSL namespace on creation, so you do not need to - rebind them to the same name before use. - - >>> with block(): - >>> reg.u32('addend product') - >>> op.mov.u32(addend, 0) - >>> op.mov.u32(product, 0) - >>> op.mov.u32(addend, 1) # Fails, block unbinds globals on leaving scope - - This constructor is available as 'reg' in the DSL namespace. - """ - def __init__(self, type, name): - self.type, self.name = type, name - def __str__(self): - return self.name - -class _RegFactory(_CallChain): - """The actual 'reg' object in the DSL namespace.""" - def _call(self, type, names): - assert len(type) == 1 - type = type[0] - names = names.split() - regs = map(lambda n: Reg(type, n), names) - self.block.code(op='.reg .' + type, vars=_softjoin(names, ',')) - [self.block.inject(r.name, r) for r in regs] - if len(regs) == 1: - return regs[0] - return regs - -class Op(_CallChain): - """ - Performs an operation. - - >>> op.mov.u32(address, mwc_rng_test_sums) - >>> op.mad.lo.u32(address, offset, 8, address) - >>> op.st.global_.v2.u32(addr(address), vec(mwc_a, mwc_b)) - - To make an operation conditional on a predicate, use 'ifp' or 'ifnotp': - - >>> reg.pred('p1') - >>> op.setp.eq.u32(p1, reg1, reg2) - >>> op.mul.lo.u32(reg1, reg1, reg2, ifp=p1) - >>> op.add.u32(reg2, reg1, reg2, ifnotp=p1) - - Note that the global state-space should be written 'global_' to avoid - conflict with the Python keyword. `addr` and `vec` are defined in Mem. - - This constructor is available as 'op' in DSL blocks. - """ - def _call(self, op, *args, **kwargs): - pred = '' - ifp = kwargs.get('ifp') - ifnotp = kwargs.get('ifnotp') - if ifp: - if ifnotp: - raise SyntaxError("can't use both, fool") - pred = ['@', ifp] - if ifnotp: - pred = ['@!', ifnotp] - self.block.code(pred, '.'.join(op), _softjoin(args, ',')) - -class Mem(object): - """ - Reserve memory, optionally with an array size attached. - - >>> mem.global_.u32('global_scalar') - >>> mem.local.u32('context_sized_local_array', ctx.nthreads*4) - >>> mem.shared.u32('shared_array', 12) - >>> mem.const.u32('const_array_of_unknown_length', True) - - Like registers, memory allocations are injected into the global namespace - for use by any functions inside the scope without extra effort. - - >>> with block('move address into memory'): - >>> reg.u32('mem_address') - >>> op.mov.u32(mem_address, global_scalar) - - This constructor is available as 'mem' in DSL blocks. - """ - # Pretty much the same as 'Reg', duplicated only for clarity - def __init__(self, type, name, array, init): - self.type, self.name, self.array, self.init = type, name, array, init - def __str__(self): - return self.name - - @staticmethod - def vec(*args): - """ - Prepare vector arguments to a memory operation. - - >>> op.ld.global.v2.u32(vec(reg1, reg2), addr(areg)) - """ - assert len(args) >= 2, "vector loads make no sense with < 2 args" - # TODO: fix the way this looks (not so easy) - return ['{', _softjoin(args, ','), '}'] - - @staticmethod - def addr(areg, aoffset=''): - """ - Prepare an address to a memory operation, optionally specifying offset. 
- - >>> op.st.global.v2.u32(addr(areg), vec(reg1, reg2)) - >>> op.ld.global.v2.u32(vec(reg1, reg2), addr(areg, 8)) - """ - return ['[', areg, aoffset is not '' and '+' or '', aoffset, ']'] - -class _MemFactory(_CallChain): - """Actual `mem` object""" - def _call(self, type, name, array=False, init=None): - memobj = Mem(type, name, array, init) - if array is True: - array = ['[]'] - elif array: - array = ['[', array, ']'] + atypes = [ME, ST] + elif name[0] in 'ret exit membar'.split(): + atypes = [] + elif name[0] == 'vote': + atypes = [PR] + elif name[0] in 'bar': + atypes = [U3, IG, IG] else: - array = [] - if init: - array += [' = ', init] - self.block.code(op='.%s ' % ' .'.join(type), vars=[name, array]) - self.block.inject(name, memobj) + raise NotImplementedError("Don't recognize the %s statement. " + "If you think this is a bug, and it may well be, please " + "report it!" % name[0]) - # TODO: move vec, addr here, or make this public + if (len(args) < len(filter(lambda t: t != IG, atypes)) or + len(args) > len(atypes)): + print args + print atypes + raise ValueError("Incorrect number of args for '%s'" % name[0]) + + stype, dtype = None, None + did_inference = False + + if isinstance(args[0], Pointer): + # Get stype from pointer (explicit stype overrides this) + if name[0] in 'ld ldu st'.split(): + stype = args[0].dtype + did_inference = True + # Get sspace from pointer if missing + if name[0] in 'ld ldu st prefetch atom red'.split(): + sspos = 2 if len(name) > 1 and name[1] == 'volatile' else 1 + if (len(name) <= sspos or name[sspos] not in + 'global local shared param const'.split()): + name.insert(sspos, args[0].sspace) + + # These instructions lack an stype suffix + if name[0] in ('prmt prefetch prefetchu isspacep bra ret exit membar ' + 'bar vote'.split()): + # False (as opposed to None) prevents stype inference attempt + stype = False + else: + # These instructions require a dtype + if name[0] in 'set slct cvt': + if name[-1] not in TYPES: + raise SyntaxError("'%s' requires a dtype." % name[0]) + if name[-2] in TYPES: + dtype, stype = TYPES[name[-2]], TYPES[name[-1]] + else: + dtype = TYPES[name[-1]] + else: + if name[-1] in TYPES: + stype = TYPES[name[-1]] + did_inference = False + + # stype wasn't explicitly set, try to infer it from the arguments + if stype is None: + maybe_typed = [a for a, t in zip(args, atypes) if t == ST] + types = [a.type for a in maybe_typed if isinstance(a, Register)] + if not types: + raise TypeError("Not enough information to infer type. 
" + "Explicitly specify the source argument type.") + stype = types[0] + did_inference = True + + if did_inference: + name.append(stype.name) + + # These instructions require a 'b32'-type argument, despite working + # on u32 and s32 types just fine, so change the name but not stype + if name[0] in 'popc clz bfind brev bfi and or xor not cnot shl'.split(): + name[-1] = 'b' + name[1:] + + # Calculate destination type (may influence some args too) + if (name[0] in 'popc clz bfind prmt'.split() or + name[:3] == ['bar', 'red', 'popc'] or + name[:2] == ['vote', 'ballot']): + dtype = TYPES['u32'] + elif (name[0] in 'testp setp isspacep vote'.split() or + name[:2] == ['bar', 'red']): + dtype = TYPES['pred'] + elif (name[0] in 'st prefetch prefetchu bra ret exit bar membar ' + 'red'.split()): + dtype = None + elif name[0] in ('mul', 'mad') and name[1] == 'wide': + dtype = TYPES[stype.kind + str(2*stype.bits)] + elif dtype is None: + dtype = stype + + atype_dict = {ST: stype, DT: dtype, U3: TYPES['u32']} + + # Wrap any untyped immediates + for idx, arg in enumerate(args): + if not isinstance(arg, Register): + t = atype_dict.get(atypes[idx]) + args[idx] = Immediate(None, t, arg) + + if did_inference: + for i, (arg, atype) in enumerate(zip(args, atypes)): + if atype in atype_dict and arg.type != atype_dict[atype]: + raise TypeError("Arg %d differs from expected type %s. " + "If this is intentional, explicitly specify the " + "source argument type." % (i, atype.name)) + if name[0] in 'ld ldu st red atom'.split(): + if (isinstance(args[0], Pointer) and + args[0].dtype.bits != stype.bits): + raise TypeError("The inferred type %s differs in size " + "from the referent's type %s. If this is intentional, " + "explicitly specify the source argument type." % + (stype.name, args[0].dtype.name)) + + return name, tuple(args), dtype + +class Register(object): + """ + The workhorse. + """ + def __init__(self, entry, type): + self.entry, self.type = entry, type + # Ordinary register naming / lifetime tracking + self.name, self.inferred_name, self.rebound_to = None, None, None + # Immediate value binding and other non-user-exposed hackery + self._ptx = None + + def _set_val(self, val): + if not isinstance(val, Register): + val = Immediate(self.entry, self.type, val) + self.entry.add_rebinding(self, val) + val = property(lambda s: s, _set_val) + def __repr__(self): + s = super(Register, self).__repr__()[:-1] + return s + ': type=%s, name=%s, inferred_name=%s>' % ( + self.type.name, self.name, self.inferred_name) + def get_name(self): + if self._ptx is not None: + return str(self._ptx) + if self.rebound_to: + return self.rebound_to.get_name() + return self.name or self.inferred_name + + def _infer_name(self, depth=2): + """ + To produce more readable code, this method reaches in to the stack and + tries to find the name of this register in the calling method's locals. + If a register is still unbound at code generation time, this name will + be preferred over a meaningless ``rXX``-style identifier. + + This best-guess effort should have absolutely no semantic impact on the + generated PTX, and is only here for readability, so we don't sweat the + potential corner cases associated with it. + + ``depth`` is the index of the relevant frame in this function's stack. 
+ """ + if self.inferred_name is None: + frame = inspect.stack()[depth][0] + for key, val in frame.f_locals.items(): + if self is val: + self.inferred_name = key + break + +class Pointer(Register): + """ + A register which knows (in Python, at least) the type, state space, and + address of a datum in memory. + """ + # TODO: use u64 as type if device has >=4GB of memory + ptr_type = TYPES['u32'] + def __init__(self, entry, sspace, dtype): + super(Pointer, self).__init__(entry, self.ptr_type) + self.sspace, self.dtype = sspace, dtype + +class Immediate(Register): + """ + An Immediate is the DSL's way of storing PTX immediate values. It differs + from a Register in two respects: + + - A non-Register value can be assigned to the ``val`` property (or passed + to ``__init__``). If the value is an int or float, it will be coerced to + follow PTX's strict parsing rules for the type of the ``Immediate``; + otherwise, it'll simply be coerced to ``str`` and pasted in the PTX. + + - The ``type`` can be None, which disables all coercion and introspection. + This is practical for labels and the like. + """ + def __init__(self, entry, type, val=None): + super(Immediate, self).__init__(entry, type) + self.val = val + def _set_val(self, val): + self._ptx = self.coerce(self.type, val) + val = property(lambda s: s._ptx, _set_val) + def __repr__(self): + return object.__repr__(self)[:-1] + ': type=%s, value=%s>' % ( + self.type.name, self._ptx) @staticmethod - def initializer(*args, **kwargs): - if args and kwargs: - raise ValueError("Cannot initialize in both list and struct style") - if args: - return ['{', _softjoin(args, ','), '}'] - jkws = _softjoin([[k, ' = ', v] for k, v in kwargs.items()], ', ') - return ['{', jkws, '}'] + def coerce(type, val): + if type is None or not isinstance(val, (int, long, float)): + return val + if type.kind == 'u' and val < 0: + raise ValueError("Can't convert (< 0) val to unsigned") + # Maybe more later? + if type.kind in 'us': + return int(val) + if type.kind in 'f': + return float(val) + raise TypeError("Immediates not supported for type %s" % type.name) -class Label(object): +class Regs(object): """ - Specifies the target for a branch. + The ``entry.regs`` object to which Registers are bound. + """ + def __init__(self, entry): + self.__dict__['_entry'] = entry + self.__dict__['_named_regs'] = dict() + def __create_register_func(self, type): + def f(*args, **kwargs): + return self._entry.create_register(type, *args, **kwargs) + return f + def __getattr__(self, name): + if name in TYPES: + return self.__create_register_func(TYPES[name]) + if name in self._named_regs: + return self._named_regs[name] + raise KeyError("Unrecognized register name %s" % name) + def __setattr__(self, name, val): + if name in self._named_regs: + self._named_regs[name].val = val + else: + if isinstance(val, Register): + assert val in self._entry._regs, "Reg from nowhere!" + val.name = name + self._named_regs[name] = val + else: + raise TypeError("What Is This %s You Have Given Me" % val) - >>> label('infinite_loop') - >>> op.bra.uni('label') + +class Memory(object): """ - def __init__(self, name): + Memory objects reference device memory and and provide a convenient + shorthand for address calculations. + + The base address of a memory location may be retreived from the ``addr`` + property as a ``Pointer`` for manual address calculations. 
+ + Somewhat more automatic address calculations can be performed using Python + bracket notation:: + + >>> r1 = o.ld(m.something[r2]) + >>> o.st(m.something[2*r2], r1) + + If the value passed in the brackets is u32, it will *not* be coerced to + u64 until being added to the base pointer. To access arrays that are more + than 4GB in size, you must coerce the input type to u64 yourself. + + Currently, all steps in an address calculation are performed for each + access, and so for inner loops manual address calculation (or simply saving + the resulting register for reuse in the next memory operation) may be more + efficient. Once the register lifetime profiler is complete, that behavior + may change. + """ + def __init__(self, entry, space, type, name): + self.entry, self.space, self.type, self.name = entry, space, type, name + @property + def addr(self): + ptr = Pointer(self.entry, self.space, self.type) + ptr._ptx = self.name + def __getitem__(self, key): + # TODO: make this multi-type-safe, perform strength reduction/precalc + ptr = Pointer(self.entry, self.space, self.type) + self.entry.add_stmt(['mad', 'lo', 'u32'], key, self.type.bytes, + self.addr, result=ptr) + return ptr + +class PtrParam(Memory): + """ + Entry parameters which contain pointers to memory locations, as created + through ``entry.add_ptr_param()``, use this type to hide the address load + from parameter space. + """ + # TODO: this assumes u32 addresses, which won't be true for long + @property + def addr(self): + ptr = Pointer(self.entry, self.space, self.type) + self.entry.add_stmt(['ld', 'param', ptr.type.name], + self.name, result=ptr) + return ptr + +class Params(object): + """ + The ``entry.params`` object to which parameters are bound. + """ + def __init__(self, entry): + # Boy this 'everything references entry` thing has gotten old + self.entry = entry + def __getattr__(self, name): + if name not in self.entry._params: + raise KeyError("Did not recognize parameter name.") + param = self.entry._params[name] + if isinstance(param, PtrParam): + return param + return self.entry.ops.ld(param.addr) + +class _DotNameHelper(object): + def __init__(self, callback, name = ()): + self.__callback = callback + self.__name = name + def __getattr__(self, name): + return _DotNameHelper(self.__callback, self.__name + (name,)) + def __call__(self, *args, **kwargs): + return self.__callback(self.__name, *args, **kwargs) + +RegUse = namedtuple('RegUse', 'src dst') +Rebinding = namedtuple('Rebinding', 'dst src') + +class Entry(object): + """ + Manager extraordinaire. + + TODO: document this. 
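+    A rough usage sketch; the kernel and parameter names below are made up,
+    and the exact call surface is still settling in this checkin (see the new
+    ``MWCRNG``/``MWCRNGTest`` classes in ``device_code.py`` for a fuller
+    caller)::
+
+    >>> e = Entry('double_it', block_width=256)
+    >>> e.add_ptr_param('data', 'f32')
+    >>> r, o = e.regs, e.ops
+    >>> with e.body():
+    >>>     r.x = o.ld(e.params.data[0])
+    >>>     o.st(e.params.data[0], o.mul.f32(r.x, 2.0))
+
+    Parameters are declared with ``add_param``/``add_ptr_param``, registers
+    are bound by assignment on ``entry.regs``, statements are emitted through
+    ``entry.ops``, and ``head``/``body``/``tail_callback`` control whether the
+    generated code lands in the prologue, body, or epilogue of the entry.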
+ """ + + def __init__(self, name, block_width, block_height=1, block_depth=1): self.name = name - def __str__(self): - return self.name + self.block = (block_width, block_height, block_depth) + self.threads_per_cta = block_width * block_height + self.body_seen = False + self.tail_cbs = [] + self.identifiers = set() -class _LabelFactory(object): - def __init__(self, block): - self.block = block - def __call__(self, name): - self.block.inject(name, Label(name)) - self.block.code(prefix='%s:' % name, semi=False) + self.ops = _DotNameHelper(self.add_stmt) + self._stmts = [] + self._labels = [] + self.regs = Regs(self) + self._regs = {} -class Comment(object): - """Add a single-line comment to the PTX output.""" - def __init__(self, block): - self.block = block - def __call__(self, comment): - self.block.code(op=['// ', comment], semi=False) + # Intended to be read by the ``params`` object below + self._params = {} + self.params = Params(self) -class PTXFragment(object): - """ - An object containing PTX DSL functions. The object, and all its - dependencies, will be instantiated by a PTX module. Each object will be - bound to the name given by ``shortname`` in the DSL namespace. - - Because of the instantiation weirdness, use the instmethod decorator on - instance methods that will be called from regular Python code. - """ - - # Name under which to make this code available in ptx_funcs - shortname = None - - def deps(self): - """ - Returns a list of PTXFragment types on which this object depends - for successful compilation. Circular dependencies are forbidden, - but multi-level dependencies should be fine. - """ - return [_PTXStdLib] - - def module_setup(self): - """ - PTX function to declare things at module scope. It's a PTX syntax error - to perform operations at this scope, but we don't yet validate that at - the Python level. A module will call this function on all fragments - used in that module in dependency order. - - If implemented, this function should use an @ptx_func decorator. - """ + def __enter__(self): + # May do more later pass - def entry_setup(self): - """ - PTX DSL function which will insert code at the start of an entry, for - initializing variables and stuff like that. An entry point will call - this function on all fragments used in that entry point in dependency - order. - - If implemented, this function should use an @ptx_func decorator. 
- """ + def __exit__(self, etype, eval, tb): + # May do more later pass - def entry_teardown(self): + def add_stmt(self, name, *operands, **kwargs): + stmt = Statement(name, operands) + idx = len(self._stmts) + for operand in stmt.operands: + operand._infer_name(2) + use = self._regs.setdefault(operand, RegUse([], [])) + use.src.append(idx) + if stmt.rtype is not None: + result = kwargs.pop('result', None) + if result: + assert result.type == stmt.rtype, "Internal type error" + else: + result = Register(self, stmt.rtype) + stmt.result = result + self._regs[result] = RegUse(src=[], dst=[idx]) + if kwargs: + raise KeyError("Unrecognized keyword arguments: %s" % kwargs) + self._stmts.append(stmt) + return stmt.result + + def add_rebinding(self, dst, src): + idx = len(self._stmts) + self._regs[dst].dst.append(idx) + if not isinstance(src, Immediate): + self._regs[src].src.append(idx) + self._stmts.append(Rebinding(dst, src)) + + def create_register(self, type, initial=None): + r = Register(self, type) + self._regs[r] = RegUse([], []) + if initial: + r.val = initial + return r + + def head(self): """ - PTX DSL function which will insert code at the end of an entry, for any - clean-up that needs to be performed. An entry point will call this - function on all fragments used in the entry point in *reverse* - dependency order (i.e. fragments which this fragment depends on will be - cleaned up after this one). - - If implemented, this function should use an @ptx_func decorator. + Top-level code segment that will be placed at the start of the entry. + Useful for initialization of memory or registers by types that do + not implement an entry point themselves. """ - pass + # This may do more later + return self - def finalize_code(self): + def body(self): """ - Called after running all PTX DSL functions, but before code generation, - to allow fragments which postponed variable evaluation (e.g. using - ``DelayVar``) to fill in the resulting values. Most fragments should - not use this. - - If implemented, this function *may* use an @ptx_func decorator to - access the global DSL scope, but pretty please don't emit any code - while you're in there. + Top-level code segment representing the body of the entry point. """ - pass + # This may do more later + assert not self.body_seen, "Only one body per entry allowed." + self.body_seen = True + return self - def tests(self): + def tail_callback(self, cb, *args, **kwargs): """ - Returns a list of PTXTest types which will test this fragment. + Registers a tail callback function. After the body segment is complete, + the tail callbacks will be called in reverse, such that each head/tail + pair nests in dependency order. + + Any arguments to this function will be passed to the callback. """ - return [] + self.tail_cbs.append((cb, args, kwargs)) - def call_setup(self, ctx): + def add_param(self, ptype, name): """ - Do stuff on the host to prepare the device for execution. 'ctx' is a - LaunchContext or similar. This will get called (in dependency order, of - course) before each function invocation. + Adds a parameter to this entry. ``type`` and ``name`` are strings. """ - # I haven't found a good way to get outside context in for this method. - # As a result, this is usually just a check to see if some other - # necessary method has been called before trying to launch. 
- pass + if ptype not in TYPES: + raise TypeError("Unrecognized PTX type name.") + self._params[name] = Memory(self, 'param', TYPES[ptype], name) - def call_teardown(self, ctx): + def add_ptr_param(self, name, mtype): """ - As with ``call_setup``, but after a call and in reverse order. + Adds a parameter to this entry which points to a location in global + memory. The resulting property of ``entry.params`` will be a + ``PtrParam`` for convenient access. + + ``name`` is the param name, and ``mtype`` is the base type of the + memory location being pointed to. The actual type of the pointer will + be chosen based on the amount of addressable memory on the device. """ - # Exceptions raised here will propagate from the invocation in Python, - # so this is a good place to do error checking. - pass + if mtype not in TYPES: + raise TypeError("Unrecognized PTX type name.") + # TODO: add pointer size heuristic + self._params[name] = PtrParam(self, 'global', TYPES[mtype], name) -def instmethod(func): - """ - Wrapper to allow instances to be retrieved from an active context. Use it - on methods which depend on state created during a compilation phase, but - are intended to be called from normal Python code. - """ - def wrap(cls, ctx, *args, **kwargs): - inst = ctx.ptx.instances[cls] - return func(inst, ctx, *args, **kwargs) - return classmethod(wrap) - -class PTXEntryPoint(PTXFragment): - # Device code entry name - entry_name = "" - # List of (type, name) pairs for entry params, e.g. [('u32', 'thing')] - entry_params = [] - maxnreg = None - - def entry(self): + def finalize(self): """ - PTX DSL function that comprises the body of the PTX statement. - - Must be implemented and decorated with ptx_func. + This method runs the tail callbacks and performs any introspection + necessary prior to emitting PTX. """ - raise NotImplementedError + assert self.tail_cbs is not None, "Cannot finalize more than once!" + for cb, args, kwargs in reversed(self.tail_cbs): + cb(*args, **kwargs) + self.tail_cbs = None - def _call(self, ctx, func, *args, **kwargs): - """ - Override this if you need to change how a function is called. - """ - # TODO: global debugging / verbosity - print "\nInvoking PTX function '%s' on device" % self.entry_name - kwargs.setdefault('block', ctx.block) - kwargs.setdefault('grid', ctx.grid) - dtime = func(time_kernel=True, *args, **kwargs) - print "'%s' completed in %gs" % (self.entry_name, dtime) + # This loop verifies rebinding of floating registers to named ones. 
+ # If all of the conditions below are met, the src register's name will + # be allowed to match the dst register; otherwise, the src's value + # will be copied to the dst's with a ``mov`` instruction + for idx, stmt in enumerate(self._stmts): + if not isinstance(stmt, Rebinding): continue + dst, src = stmt + # src must be floating reg, not immediate or bound reg + # Examples: + # r.a = r.u32(4) + # b = r.u32(r.a) + move = isinstance(src, Immediate) or src.name is not None + # dst cannot be used between src's originating expression and + # the rebinding itself + # Example 1: + # r.a, r.b = r.u32(1), r.u32(1) + # x = o.add(r.a, r.b) + # r.b = o.add(r.a, x) + # r.a = x + # Example 2: + # r.a, r.b = r.u32(1), r.u32(1) + # label('start') + # x = o.add(r.a, r.b) + # y = o.add(r.a, x) + # r.a = x + # r.b = y + # bra.uni('start') + # TODO: incorporate branch tracking + if not move: + for oidx in (self._regs[dst].src + self._regs[dst].dst): + if oidx > self._regs[src].dst[0] and oidx < idx: + move = True + if move: + src.rebound_to = None + stmt = Statement(('mov',), (src,)) + stmt.result = dst + self._stmts[idx] = stmt - @instmethod - def call(self, ctx, *args, **kwargs): - """ - Calls the entry point on the device, performing any setup and teardown - needed. - """ - ctx.call_setup(self) - func = ctx.mod.get_function(self.entry_name) - try: - self._call(ctx, func, *args, **kwargs) - finally: - res = ctx.call_teardown(self) - return res - -class PTXTestFailure(Exception): pass - -class PTXTest(PTXEntryPoint): - """PTXTests are semantically equivalent to PTXEntryPoints, but they differ - slightly in the way they are invoked: - - * The active context will be synchronized before each call, - * call_teardown() should raise ``PTXTestFailure`` if a test failed. - This exception will be caught and cleanup will be completed. - """ - pass - -class _PTXStdLib(PTXFragment): - shortname = "std" - - def __init__(self, block): - # Only module that gets the privilege of seeing 'block' directly. - self.block = block - self.asserts = ["Success"] - - def deps(self): - return [] - - @ptx_func - def module_setup(self): - # TODO: make this modular, maybe? of course, we'd have to support - # multiple devices first, which we definitely do not yet do - self.block.code(prefix='.version 2.1', semi=False) - self.block.code(prefix='.target sm_21', semi=False) - mem.global_.u32('g_std_exit_err', ctx.nthreads) - - @ptx_func - def get_gtid(self, dst): - """ - Get the global thread ID (the position of this thread in a grid of - blocks of threads). This assumes that both grid and block are - one-dimensional! (This is always true for cuburn.) - """ - with block("Load GTID into %s" % str(dst)): - reg.u32('cta ncta tid') - op.mov.u32(cta, '%ctaid.x') - op.mov.u32(ncta, '%ntid.x') - op.mov.u32(tid, '%tid.x') - op.mad.lo.u32(dst, cta, ncta, tid) - - @ptx_func - def store_per_thread(self, *args): - """For each pair of arguments ``addr, val``, write ``val`` to the - address given by ``addr+sizeof(val)*gtid``. 
If ``val`` is not a - register, size will be taken from ``addr``; if ``addr`` is not a Mem - instance, size defaults to 4.""" - with block("Per-thread storing values"): - reg.u32('spt_base spt_offset') - self.get_gtid(spt_offset) - for i in range(0, len(args), 2): - base, val = args[i], args[i+1] - width = 4 - if isinstance(base, Mem): - width = int(base.type[-1][-2:])/8 - if isinstance(val, Reg): - width = int(val.type[-2:])/8 - op.mov.u32(spt_base, base) - op.mad.lo.u32(spt_base, spt_offset, width, spt_base) - if isinstance(val, float): - # Turn a constant float into the big-endian PTX binary f32 - # representation, 0fXXXXXXXX (where XX is hex byte) - val = '0f%x%x%x%x' % reversed(map(ord, - struct.pack('f', val))) - op._call(['st', 'b%d' % (width*8)], addr(spt_base), val) - - @ptx_func - def set_is_first_thread(self, p_dst): - with block("Set %s if this is thread 0 in the CTA" % p_dst.name): - reg.u32('tid') - op.mov.u32(tid, '%tid.x') - op.setp.eq.u32(p_dst, tid, 0) - - def not_(self, pred): - return ['!', pred] - - @ptx_func - def asrt(self, msg, o=None, a=None, b=None, p=None, notp=None, - ret=False, ign=False, lvl=1): - """ - Device assertion. - - Without arguments, a thread will log the error code associated with - ``msg`` and issue a trap instruction, which will cause the device to - terminate execution in all threads immediately. Any of the options - below modify that behavior, as described. - - ``o``, ``a`` and ``b``, when set together, will be used to create a - ``setp`` instruction to test a condition. They're the first three - arguments, to make usage a bit more natural: - - >>> std.asrt('lt.u32', val, 0) - - This would generate the instruction ``setp.lt.u32
p_asrt_fail, val, 0;``
-        (``p_asrt_fail``
is created by this function). The thread would only store the - error code and exit if the condition were *false*. - - ``p`` is a predicate value; the store and trap will happen if it is - *not* set (same sense as ``o`` and Python's assert). ``notp`` is the - reverse. - - Only one of ``o``, ``ifp``, or ``ifnotp`` can be set per call. - - ``ret`` causes the assert to issue a ``ret;`` instruction in place of - the trap. This causes the current thread to terminate, but does not - cause the other threads to do so. Be cautious, as barriers can cause a - kernel to hang using this instruction. - - ``ign`` causes the error code to be stored, but does not terminate - thread execution ("ignores" the error). This is useful to identify the - location of all threads in case of an abnormal termination caused by - another thread, and is used to set up the entry-wide "early - termination" error. ``ign`` overrides ``ret``. - - This code calculates the gtid unconditionally, and so can be relatively - expensive to insert into a tight loop. As a result, assert - statements will only be added if the debug value ``assert_level`` is - at least as large as the ``lvl`` argument. - """ - # TODO: debug level checking - if np.sum(map(bool, (o, p, notp))) > 1: - raise ValueError("Can only use one of o, ifp, ifnotp.") - if msg not in self.asserts: - self.asserts.append(msg) - err_code = self.asserts.index(msg) - with block("Assertion: " + msg): - reg.u32('asrt_base asrt_off') - op.mov.u32(asrt_base, g_std_exit_err) - self.get_gtid(asrt_off) - op.mad.lo.u32(asrt_base, asrt_off, 4, asrt_base) - realp = None - if o: - realp = self.not_(reg.pred('p_asrt_fail')) - if a is None or b is None: - raise ValueError("Must specify ``a`` and ``b`` with ``o``.") - op._call(['setp.'+o], p_asrt_fail, a, b) - if p: - realp = self.not_(p) - if notp: - realp = notp - op.st.global_.u32(addr(asrt_base), err_code, ifp=realp) - if not ign: - if ret: - op.ret(ifp=realp) + # Identify all uses of registers by name in the program + bound = dict([(t, set()) for t in TYPES.values()]) + free = dict([(t, {}) for t in TYPES.values()]) + for stmt in self._stmts: + if isinstance(stmt, Rebinding): + regs = [stmt.src, stmt.dst] + else: + regs = filter(lambda r: r and not isinstance(r, Immediate), + (stmt.result,) + stmt.operands) + for reg in regs: + if reg.name: + bound[reg.type].add(reg.name) else: - op.trap(ifp=realp) + rl = free[reg.type].setdefault(reg.inferred_name, []) + if reg not in rl: + rl.append(reg) - @ptx_func - def entry_setup(self): - self.asrt("Unexpected thread exit", ign=True, lvl=0) + # Store the data required for register declarations + self.bound = bound + self.temporary = {} - @ptx_func - def entry_teardown(self): - self.asrt(self.asserts[0], ret=True, lvl=0) + # Generate names for all unbound registers + # TODO: include memory, label, instr identifiers in this list + identifiers = set() + map(identifiers.update, bound.values()) + used_bases = set([i.rstrip('1234567890') for i in identifiers]) + for t, inames in free.items(): + for ibase, regs in inames.items(): + if ibase is None: + ibase = t.name + '_' + while ibase in used_bases: + ibase = ibase + '_' + trl = self.temporary.setdefault(t, []) + trl.append('%s<%d>' % (ibase, len(regs))) + for i, reg in enumerate(regs): + reg.name = ibase + str(i) - def call_teardown(self, ctx): - """ - This function raises an exception if all cleanup code wasn't called on - the device. 
To suppress this - for instance, to inspect data from a - partially-executed thread - do + def format_source(self, formatter): + assert self.tail_cbs is None, "Must finalize entry before formatting" + params = [v for k, v in sorted(self._params.items())] + formatter.entry_start(self.name, params, reqntid=self.block) + [formatter.regs(t, r) for t, r in sorted(self.bound.items()) if r] + formatter.comment("Temporary registers") + [formatter.regs(t, r) for t, r in sorted(self.temporary.items()) if r] + formatter.blank() - >>> std.asrt(std.asserts[0], ign=True, lvl=0) + for stmt in self._stmts: + if isinstance(stmt, Statement): + stmt.ptx_line = formatter.stmt(stmt) + formatter.entry_end() - at the start of your entry. Yes, it's a hacky solution. - """ - dp, l = ctx.mod.get_global('g_std_exit_err') - errs = cuda.from_device(dp, ctx.nthreads, np.uint32) - if np.sum(errs) != 0: - print "Some threads terminated unsuccessfully." - for i, msg in enumerate(self.asserts): - count = sum(np.equal(errs, i)) - if count: - print '%6d said "%s".' % (count, msg) - print - raise EnvironmentError("Abnormal thread termination") - - def to_inject(self): - # Set up the initial namespace - return dict( - _block=self.block, - block=Block(self.block), - op=Op(self.block), - reg=_RegFactory(self.block), - mem=_MemFactory(self.block), - addr=Mem.addr, - vec=Mem.vec, - label=_LabelFactory(self.block), - comment=Comment(self.block)) - -class PTXModule(object): - """ - Assembles PTX fragments into a module. The following properties are - available: - - `instances`: Mapping of type to instance for the PTXFragments used in - the creation of this PTXModule. - `entries`: List of PTXEntry types in this module, including any tests. - `tests`: List of PTXTest types in this module. - `source`: PTX source code for this module. - """ - max_compiles = 10 - - def __init__(self, entries, inject={}, build_tests=False, formatter=None): - """ - Construct a PTXModule. - - `entries`: List of PTXEntry types to include in this module. - `inject`: Dict of items to inject into the DSL namespace. - `build_tests`: If true, build tests into the module. - `formatter`: PTXFormatter instance, or None to use defaults. - """ - block = _Block() - insts, tests, all_deps, entry_deps = ( - self.deptrace(block, entries, build_tests)) - self.instances = insts - self.entry_deps = entry_deps - self.tests = tests - - inject = dict(inject) - inject.update(insts[_PTXStdLib].to_inject()) - self._safeupdate(inject, 'module', self) - for inst in all_deps: - if inst.shortname: - self._safeupdate(inject, inst.shortname, inst) - [block.inject(k, v) for k, v in inject.items()] - - self.__needs_recompilation = True - self.compiles = 0 - while self.__needs_recompilation: - self.compiles += 1 - self.__needs_recompilation = False - self.assemble(block, all_deps, entry_deps) - self.instances.pop(_PTXStdLib) - - if not formatter: - formatter = PTXFormatter() - self.source = formatter.format(block.outer_ctx.code) - self.entries = list(set(entries + tests)) - - def deporder(self, unsorted_instances, instance_map): - """ - Do a DFS on PTXFragment dependencies, and return an ordered list of - instances where no fragment depends on any before it in the list. - - `unsorted_instances` is the list of instances to sort. - `instance_map` is a dict of types to instances. 
- """ - seen = {} - def rec(inst): - if inst in seen: return seen[inst] - if inst is None: return 0 - deps = filter(lambda d: d is not inst, - map(instance_map.get, inst.deps())) - return seen.setdefault(inst, 1+max([0]+map(rec, deps))) - map(rec, unsorted_instances) - return sorted(unsorted_instances, key=seen.get) - - def _safeupdate(self, dst, k, v): - if k in dst: raise KeyError("Duplicate key %s" % k) - dst[k] = v - - def deptrace(self, block, entries, build_tests): - instances = {_PTXStdLib: _PTXStdLib(block)} - unvisited_entries = list(entries) - tests = set() - entry_deps = {} - - # For each PTXEntry or PTXTest, use a BFS to recursively find and - # instantiate all fragments that are dependencies. If tests are - # discovered, add those to the list of entries. - while unvisited_entries: - ent = unvisited_entries.pop(0) - seen, unvisited = set(), [ent] - while unvisited: - frag = unvisited.pop(0) - seen.add(frag) - # setdefault doesn't work because of _PTXStdLib - if frag not in instances: - inst = frag() - instances[frag] = inst - else: - inst = instances[frag] - for dep in inst.deps(): - if dep not in seen: - unvisited.append(dep) - if build_tests: - for test in inst.tests(): - if test not in tests: - tests.add(test) - if test not in instances: - unvisited_entries.append(test) - # For this entry, store insts of all dependencies in order. - entry_deps[ent] = self.deporder(map(instances.get, seen), - instances) - # Find the order for all dependencies in the program. - all_deps = self.deporder(instances.values(), instances) - - return instances, sorted(tests, key=str), all_deps, entry_deps - - def assemble(self, block, all_deps, entry_deps): - # Rebind to local namespace to allow proper retrieval - _block = block - - for inst in all_deps: - inst.module_setup() - - for ent, insts in entry_deps.items(): - # This is kind of hackish compared to everything else - params = [Reg('.param.' + str(type), name) - for (type, name) in ent.entry_params] - _block.code(op='.entry %s' % ent.entry_name, semi=False, - vars=['(', ', '.join(['%s %s' % (r.type, r.name) - for r in params]), ')']) - if ent.maxnreg: - _block.code(op='.maxnreg %d' % ent.maxnreg, semi=False) - with Block(_block): - [_block.inject(r.name, r) for r in params] - for dep in insts: - dep.entry_setup() - self.instances[ent].entry() - for dep in reversed(insts): - dep.entry_teardown() - - for inst in all_deps: - inst.finalize_code() - - def set_needs_recompilation(self): - if not self.__needs_recompilation: - if self.compiles >= self.max_compiles: - raise ValueError("Too many recompiles scheduled!") - self.__needs_recompilation = True - - def print_source(self): - if not hasattr(self, 'source'): - raise ValueError("Not assembled yet!") - print '\n'.join(["%03d %s" % (i+1, l) for (i, l) in - enumerate(self.source.split('\n'))]) - -def _flatten(val): - if isinstance(val, (list, tuple)): - return ''.join(map(_flatten, val)) - return str(val) class PTXFormatter(object): - """ - Formats PTXStmt items into beautiful code. Well, the beautiful part is - postponed for now. 
- """ - def __init__(self, indent_amt=4, oplen_max=20, varlen_max=12): - self.idamt, self.opm, self.vm = indent_amt, oplen_max, varlen_max - def format(self, code): - out = [] - indent = 0 - idx = 0 - while idx < len(code): - opl, vl = 0, 0 - flat = [] - while idx < len(code): - pfx, op, vars, semi, indent_change = code[idx] - idx += 1 - if indent_change: break - pfx, op = _flatten(pfx), _flatten(op) - vars = map(_flatten, vars) - if len(op) <= self.opm: - opl = max(opl, len(op)+2) - for var in vars: - if len(var) <= self.vm: - vl = max(vl, len(var)+1) - flat.append((pfx, op, vars, semi)) - for pfx, op, vars, semi in flat: - if pfx: - line = ('%%-%ds ' % (indent-1)) % pfx - else: - line = ' '*indent - line = ('%%-%ds ' % (indent+opl)) % (line+op) - for i, var in enumerate(vars): - line = ('%%-%ds ' % (indent+opl+vl*(i+1))) % (line+var) - if semi: - line = line.rstrip() + ';' - out.append(line) - indent = max(0, indent + self.idamt * indent_change) - return '\n'.join(out) + def __init__(self, ptxver=PTX_VERSION, target='sm_21'): + self.indent_level = 0 + self.lines = ['.version %d.%d' % ptxver, '.target %s' % target] + + def blank(self): + self.lines.append('') + + def comment(self, text): + self.lines.append(' ' * self.indent_level + '// ' + text) + + def regs(self, type, names): + # TODO: indenting, length limits, etc. + self.lines.append(' ' * self.indent_level + '.reg .%s ' % (type.name) + + ', '.join(sorted(names)) + ';') + + def stmt(self, stmt): + res = ('%s, ' % stmt.result.get_name()) if stmt.result else '' + args = [o.get_name() for o in stmt.operands] + # Wrap the arg in brackets if needed (no good place to put this) + if stmt.fullname[0] in ('ld ldu st prefetch prefetchu isspacep ' + 'atom red'.split()): + args[0] = '[%s]' % args[0] + + self.lines.append(''.join([' ' * self.indent_level, + '%-12s ' % '.'.join(stmt.fullname), res, ', '.join(args), ';'])) + return len(self.lines) + + def entry_start(self, name, params, **directives): + """ + Define the start of an entry point. ``name`` and ``params`` should be + obvious, ``directives`` is a dictionary of performance tuning directive + strings. As a special case, if a ``directive`` value is a tuple, it + will be converted to a comma-separated string. + """ + for k, v in directives.items(): + if isinstance(v, tuple): + directives[k] = ','.join(map(str, v)) + dstr = ' '.join(['.%s %s' % i for i in directives.items()]) + # TODO: support full param options like alignment and array decls + # (base the param type off a memory type) + pstrs = ['.param .%s %s' % (p.type.name, p.name) for p in params] + pstr = '(%s)' % ', '.join(pstrs) + self.lines.append(' '.join(['.entry', name, pstr, dstr])) + self.lines.append('{') + self.indent_level += 4 + + def entry_end(self): + self.indent_level += 4 + self.lines.append('}') + + def get_source(self): + return '\n'.join(self.lines) _TExp = namedtuple('_TExp', 'type exprlist') _DataCell = namedtuple('_DataCell', 'offset size texp') -class DataStream(PTXFragment): +class DataStream(object): """ Simple interface between Python and PTX, designed to create and tightly pack control structs. @@ -1058,7 +705,6 @@ class DataStream(PTXFragment): Inside DSL functions, you can retrieve arbitrary Python expressions from the data stream. - >>> @ptx_func >>> def example_func(): >>> reg.u32('reg1 reg2 regA') >>> op.mov.u32(regA, some_device_allocation_base_address) @@ -1076,7 +722,6 @@ class DataStream(PTXFragment): access times when taking device caching into account. 
This also implies that the evaluated expressions should not modify any state. - >>> @ptx_func >>> def example_func_2(): >>> reg.u32('reg1 reg2') >>> reg.f32('regf') @@ -1092,7 +737,6 @@ class DataStream(PTXFragment): fancy things like multiplying two DelayVars aren't implemented yet. >>> class Whatever(PTXFragment): - >>> @ptx_func >>> def module_setup(self): >>> mem.global_.u32('ex_streams', ex.stream_size*1000) """ @@ -1155,7 +799,6 @@ class DataStream(PTXFragment): fsize *= 2 return offset - @ptx_func def _stream_get_internal(self, areg, dregs, exprs, ifp, ifnotp): size, type = self._get_type(dregs) vsize = size * len(dregs) @@ -1171,17 +814,14 @@ class DataStream(PTXFragment): dregs = vec(*dregs) op._call(opname, dregs, addr(areg, offset), ifp=ifp, ifnotp=ifnotp) - @ptx_func def get(self, areg, dreg, expr, ifp=None, ifnotp=None): self._stream_get_internal(areg, [dreg], [expr], ifp, ifnotp) - @ptx_func def get_v2(self, areg, dreg1, expr1, dreg2, expr2, ifp=None, ifnotp=None): self._stream_get_internal(areg, [dreg1, dreg2], [expr1, expr2], ifp, ifnotp) # The interleaved signature makes calls easier to read - @ptx_func def get_v4(self, areg, d1, e1, d2, e2, d3, e3, d4, e4, ifp=None, ifnotp=None): self._stream_get_internal(areg, [d1, d2, d3, d4], [e1, e2, e3, e4], @@ -1202,7 +842,6 @@ class DataStream(PTXFragment): print "Finalized stream:" self._print_format() - @instmethod def pack(self, ctx, _out_file_ = None, **kwargs): """ Evaluates all statements in the context of **kwargs. Take this code, @@ -1227,7 +866,6 @@ class DataStream(PTXFragment): cls.pack_into(out, kwargs) return out.read() - @instmethod def pack_into(self, ctx, outfile, **kwargs): """ Like pack(), but write data to a file-like object at the file's current @@ -1255,9 +893,7 @@ class DataStream(PTXFragment): cell.texp.exprlist[0]) for exp in cell.texp.exprlist[1:]: print '%11s %s' % ('', exp) - print_format = instmethod(_print_format) - @instmethod def print_record(self, ctx, stream, limit=None): for i in range(0, len(stream), self._size): for cell in self.cells: diff --git a/main.py b/main.py index f4284c7..0040e84 100644 --- a/main.py +++ b/main.py @@ -16,6 +16,8 @@ from ctypes import * import numpy as np +np.set_printoptions(precision=5, edgeitems=20) + from cuburn.device_code import * from cuburn.cuda import LaunchContext from fr0stlib.pyflam3 import *
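
A rough sketch of how the formatter path added above is meant to be driven.
The method names are the ones introduced in this diff, but ``entry`` itself
(and the argument values) are placeholders rather than anything defined here,
so read this as a sketch of the intended flow, not the final API:

    # Hypothetical driver code; 'entry' is a placeholder object, not a real
    # class from this patch.
    entry.add_ptr_param('out_buf', 'u32')  # pointer param into global memory
    entry.finalize()                       # run tail callbacks, name temp regs

    fmt = PTXFormatter()                   # emits the .version/.target header
    entry.format_source(fmt)               # .entry decl, .reg decls, statements
    print fmt.get_source()                 # PTX text, starting with ".version"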