Code builder, RNG working

Steven Robertson 2010-08-28 00:28:00 -04:00
parent a23ebdcf5f
commit 907cbb273f

main.py (397 changed lines)

@@ -11,6 +11,7 @@
import os
import sys
import time
import ctypes
import struct
@@ -32,57 +33,152 @@ import numpy as np
from fr0stlib import pyflam3
def ppr_ptx(src):
    # TODO: Add variable realignment
    indent = 0
    out = []
    for line in [l.strip() for l in src.split('\n')]:
        if not line:
            continue
        if len(line.split()) == 1 and line.endswith(':'):
            out.append(line)
            continue
        if '}' in line and '{' not in line:
            indent -= 1
        out.append(' ' * (indent * 4) + line)
        if '{' in line and '}' not in line:
            indent += 1
    return '\n'.join(out)
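# Aside (not in the original commit): a doctest-style sketch of ppr_ptx.
# Braces drive the indent depth; single-token label lines stay flush-left.
#
#     >>> ppr_ptx(".entry f ()\n{\nloop:\nret;\n}")
#     '.entry f ()\n{\nloop:\n    ret;\n}'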
def multisub(tmpl, subs):
    while '{{' in tmpl:
        tmpl = tempita.Template(tmpl).substitute(subs)
    return tmpl
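# Aside (not in the original commit): one substitution pass can itself emit
# new '{{ }}' markers, which is why multisub loops until none remain:
#
#     >>> multisub('{{ a }}', {'a': '{{ b }}', 'b': 'done'})
#     'done'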
class CUGenome(pyflam3.Genome):
    def _render(self, frame, trans):
        obuf = (ctypes.c_ubyte * ((3+trans)*self.width*self.height))()
        stats = pyflam3.RenderStats()
        pyflam3.flam3_render(ctypes.byref(frame), obuf,
                             pyflam3.flam3_field_both,
                             trans+3, trans, ctypes.byref(stats))
        return obuf, stats, frame
class LaunchContext(object):
    """
    Context collecting the information needed to create, run, and gather the
    results of a device computation.

    To create the fastest device code across multiple device families, this
    context may decide to iteratively refine the final PTX by regenerating
    and recompiling it several times to optimize certain parameters of the
    launch, such as the distribution of threads throughout the device.

    The properties of this device which are tuned are listed below. Any PTX
    fragments which use this information must emit valid PTX for any state
    given below, but the PTX is only required to actually run with the final,
    fixed values of all tuned parameters below.

      `block`:   3-tuple of (x,y,z); dimensions of each CTA.
      `grid`:    2-tuple of (x,y); dimensions of the grid of CTAs.
      `threads`: Number of active threads on device as a whole.
      `mod`:     Final compiled module. Unavailable during assembly.
    """
    def __init__(self, block=(1,1,1), grid=(1,1), seed=None, tests=False):
        self.block, self.grid, self.tests = block, grid, tests
        self.stream = cuda.Stream()
        self.rand = np.random.mtrand.RandomState(seed)

    @property
    def threads(self):
        return reduce(lambda a, b: a*b, self.block + self.grid)
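    # Aside (not in the original commit): block and grid concatenate into a
    # 5-tuple, so block=(256,1,1), grid=(64,1) -- the config used in main()
    # below -- gives threads = 256*1*1*64*1 = 16384.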
    def _deporder(self, unsorted_instances, instance_map):
        # Do a DFS on the mapping of PTXFragment types to instances, returning
        # a list of instances ordered such that nothing depends on anything
        # before it in the list
        seen = {}
        def rec(inst):
            if inst in seen: return seen[inst]
            deps = filter(lambda d: d is not inst, map(instance_map.get,
                callable(inst.deps) and inst.deps(self) or inst.deps))
            return seen.setdefault(inst, 1+max([0]+map(rec, deps)))
        map(rec, unsorted_instances)
        return sorted(unsorted_instances, key=seen.get)
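    # Aside (not in the original commit): rec() ranks each instance one
    # deeper than its deepest dependency, so for hypothetical fragments
    # A -> B -> C (A depends on B, B on C) the depths are C=1, B=2, A=3,
    # and the sorted result [C, B, A] places every fragment after its deps.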
    def _safeupdate(self, dst, src):
        for key, val in src.items():
            if key in dst:
                raise KeyError("Duplicate key %s" % key)
            dst[key] = val
    def assemble(self, entries):
        # Get a property, dealing with the callable-or-data thing
        def pget(prop):
            if callable(prop): return prop(self)
            return prop
        instances = {}
        entries_unvisited = list(entries)
        tests = set()
        parsed_entries = []
        while entries_unvisited:
            ent = entries_unvisited.pop(0)
            seen, unvisited = set(), [ent]
            while unvisited:
                frag = unvisited.pop(0)
                seen.add(frag)
                inst = instances.setdefault(frag, frag())
                for dep in pget(inst.deps):
                    if dep not in seen:
                        unvisited.append(dep)
            tmpl_namespace = {'ctx': self}
            entry_start, entry_end = [], []
            for inst in self._deporder(map(instances.get, seen), instances):
                self._safeupdate(tmpl_namespace, pget(inst.subs))
                entry_start.append(pget(inst.entry_start))
                entry_end.append(pget(inst.entry_end))
            entry_start_tmpl = '\n'.join(filter(None, entry_start))
            entry_end_tmpl = '\n'.join(filter(None, reversed(entry_end)))
            name, args, body = pget(instances[ent].entry)
            tmpl_namespace.update({'_entry_name_': name, '_entry_args_': args,
                '_entry_body_': body, '_entry_start_': entry_start_tmpl,
                '_entry_end_': entry_end_tmpl})
            entry_tmpl = """
.entry {{ _entry_name_ }} ({{ _entry_args_ }})
{
    {{ _entry_start_ }}
    {{ _entry_body_ }}
    {{ _entry_end_ }}
}
"""
            parsed_entries.append(multisub(entry_tmpl, tmpl_namespace))

        prelude = []
        tmpl_namespace = {'ctx': self}
        for inst in self._deporder(instances.values(), instances):
            prelude.append(pget(inst.prelude))
            self._safeupdate(tmpl_namespace, pget(inst.subs))
        tmpl_namespace['_prelude_'] = '\n'.join(filter(None, prelude))
        tmpl_namespace['_entries_'] = '\n\n'.join(parsed_entries)
        tmpl = "{{ _prelude_ }}\n\n{{ _entries_ }}\n"
        return instances, multisub(tmpl, tmpl_namespace)
    def compile(self, entries):
        # For now, do no optimization.
        self.instances, self.src = self.assemble(entries)
        self.src = ppr_ptx(self.src)
        try:
            self.mod = cuda.module_from_buffer(self.src)
        except (cuda.CompileError, cuda.RuntimeError), e:
            print "Aww, dang, compile error. Here's the source:"
            print '\n'.join(["%03d %s" % (i+1, l)
                             for (i, l) in enumerate(self.src.split('\n'))])
            raise e
class PTXFragment(object):
    """
@@ -99,14 +195,19 @@ class PTXFragment(object):
    Template code will be processed recursively until all "{{" instances have
    been replaced, using the same namespace each time.

    Note that any method which does not depend on 'ctx' can be replaced with
    an instance of the appropriate return type. So, for example, the 'deps'
    property can be a flat list instead of a function.
    """
    def deps(self, ctx):
        """
        Returns a list of PTXFragment objects on which this object depends
        for successful compilation. Circular dependencies are forbidden,
        but multi-level dependencies should be fine.
        """
        return [DeviceHelpers]
    def subs(self, ctx):
        """
@@ -124,33 +225,40 @@
"""
return ""
    def entry_start(self, ctx):
        """
        Returns a template string that should be inserted at the top of any
        entry point which depends on this method. The entry starts of all
        deps will be inserted above this entry start.
        """
        return ""
    def entry_end(self, ctx):
        """
        As above, but at the end of the calling function, and with the order
        reversed (all dependencies will be inserted after this).
        """
        return ""

    def set_up(self, ctx):
        """
        Do start-of-stream initialization, such as copying data to the device.
        """
        pass
    # A list of PTXTest classes which will test this fragment
    tests = []
class PTXEntryPoint(PTXFragment):
    # Human-readable entry point name
    name = ""

    def entry(self, ctx):
        """
        Returns a 3-tuple of (name, args, body), which will be assembled into
        a function.
        """
        raise NotImplementedError
    def call(self, ctx):
        """
@@ -159,71 +267,190 @@ class PTXEntryPoint(PTXFragment):
"""
pass
class PTXTest(PTXEntryPoint):
    """PTXTests are semantically equivalent to PTXEntryPoints, but they
    differ slightly in use. In particular:

    * The "name" property should describe the test being performed,
    * ctx.stream will be synchronized before 'call' is run, and should be
      synchronized afterwards (i.e. sync it yourself or don't use it),
    * call() should return True to indicate that a test passed, or
      False (or raise an exception) if it failed.
    """
    pass
class DeviceHelpers(PTXFragment):
    """This one's included by default, no need to depend on it"""
    prelude = ".version 2.1\n.target sm_20\n\n"

    def _get_gtid(self, dst):
        return "{\n// Load GTID into " + dst + """
        .reg .u16 tmp;
        .reg .u32 cta, ncta, tid, gtid;
        mov.u16 tmp, %ctaid.x;
        cvt.u32.u16 cta, tmp;
        mov.u16 tmp, %ntid.x;
        cvt.u32.u16 ncta, tmp;
        mul.lo.u32 gtid, cta, ncta;
        mov.u16 tmp, %tid.x;
        cvt.u32.u16 tid, tmp;
        add.u32 gtid, gtid, tid;
        mov.b32 """ + dst + ", gtid;\n}"

    def subs(self, ctx):
        return {
            'PTRT': ctypes.sizeof(ctypes.c_void_p) == 8 and '.u64' or '.u32',
            'get_gtid': self._get_gtid
        }
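# Aside (not in the original commit): since assemble() merges each fragment's
# subs into the template namespace, a callable sub can be invoked from inside
# a PTX template; e.g. {{ get_gtid('mwc_off') }} in MWCRNG.entry_start below
# expands to the brace-delimited block above, loading the global thread ID
# into the 'mwc_off' register.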
class MWCRNG(PTXFragment):
    def __init__(self):
        if not os.path.isfile('primes.bin'):
            raise EnvironmentError('primes.bin not found')
    prelude = """
    .global .u32 mwc_rng_mults[{{ctx.threads}}];
    .global .u64 mwc_rng_state[{{ctx.threads}}];"""
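    # Aside (not in the original commit): each thread owns one u32 multiplier
    # and one packed u64 (state, carry) pair, hence the 4- and 8-byte strides
    # in entry_start/entry_end below.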
    def _next_b32(self, dreg):
        # TODO: make sure PTX optimizes away superfluous move instrs
        return """
        {
        // MWC next b32
        .reg .u64 mwc_out;
        cvt.u64.u32 mwc_out, mwc_car;
        mad.wide.u32 mwc_out, mwc_st, mwc_mult, mwc_out;
        mov.b64 {mwc_st, mwc_car}, mwc_out;
        mov.u32 %s, mwc_st;
        }
        """ % dreg

    def subs(self, ctx):
        return {'mwc_next_b32': self._next_b32}
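    # Aside (not in the original commit): a pure-Python model of the step
    # emitted by _next_b32. mad.wide.u32 computes the 64-bit value
    # mult*state + carry; the low word is both the thread's output and its
    # next state, and the high word becomes the next carry (this mirrors
    # the host-side check in MWCRNGTest.call).
    @staticmethod
    def _mwc_step_model(mult, state, carry):
        t = mult * state + carry
        return t & 0xffffffff, t >> 32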
    entry_start = """
        .reg .u32 mwc_st, mwc_mult, mwc_car;
        {
        // MWC load multipliers and RNG states
        .reg .u32 mwc_off, mwc_addr;
        {{ get_gtid('mwc_off') }}
        mov.u32 mwc_addr, mwc_rng_mults;
        mad.lo.u32 mwc_addr, mwc_off, 4, mwc_addr;
        ld.global.u32 mwc_mult, [mwc_addr];
        mov.u32 mwc_addr, mwc_rng_state;
        mad.lo.u32 mwc_addr, mwc_off, 8, mwc_addr;
        ld.global.v2.u32 {mwc_st, mwc_car}, [mwc_addr];
        }
        """

    entry_end = """
        {
        // MWC save states
        .reg .u32 mwc_addr, mwc_off;
        {{ get_gtid('mwc_off') }}
        mov.u32 mwc_addr, mwc_rng_state;
        mad.lo.u32 mwc_addr, mwc_off, 8, mwc_addr;
        st.global.v2.u32 [mwc_addr], {mwc_st, mwc_car};
        }
        """
    def set_up(self, ctx):
        # Load raw big-endian u32 multipliers from primes.bin.
        with open('primes.bin') as primefp:
            dt = np.dtype(np.uint32).newbyteorder('B')
            mults = np.frombuffer(primefp.read(), dtype=dt)
        # Randomness in choosing multipliers is good, but larger multipliers
        # have longer periods, which is also good. This is a compromise.
        # TODO: fix mutability, enable shuffle here
        #ctx.rand.shuffle(mults[:ctx.threads*4])
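        # Aside (not in the original commit): np.frombuffer returns an array
        # backed by the read-only string, so in-place shuffle raises an
        # error; copying first (mults = np.array(mults)) would make the
        # shuffle above legal.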
        # Copy multipliers and seeds to the device
        multdp, multl = ctx.mod.get_global('mwc_rng_mults')
        # TODO: get async to work
        #cuda.memcpy_htod_async(multdp, mults.tostring()[:multl], ctx.stream)
        cuda.memcpy_htod(multdp, mults.tostring()[:multl])
        statedp, statel = ctx.mod.get_global('mwc_rng_state')
        #cuda.memcpy_htod_async(statedp, ctx.rand.bytes(statel), ctx.stream)
        cuda.memcpy_htod(statedp, ctx.rand.bytes(statel))
    def tests(self, ctx):
        return [MWCRNGTest]
class MWCRNGTest(PTXTest):
    name = "MWC RNG sum-of-threads test"
    deps = [MWCRNG]
    rounds = 200
    prelude = ".global .u64 mwc_rng_test_sums[{{ctx.threads}}];"
    def entry(self, ctx):
        return ('MWC_RNG_test', '', """
        .reg .u64 sum, addl;
        .reg .u32 addend;
        mov.u64 sum, 0;
        {{for round in range(%d)}}
        {{ mwc_next_b32('addend') }}
        cvt.u64.u32 addl, addend;
        add.u64 sum, sum, addl;
        {{endfor}}
        {
        .reg .u32 addr, offset;
        {{ get_gtid('offset') }}
        mov.u32 addr, mwc_rng_test_sums;
        mad.lo.u32 addr, offset, 8, addr;
        st.global.u64 [addr], sum;
        }
        """ % self.rounds)
    def call(self, ctx):
        # Get current multipliers and seeds from the device
        multdp, multl = ctx.mod.get_global('mwc_rng_mults')
        mults = cuda.from_device(multdp, ctx.threads, np.uint32)
        statedp, statel = ctx.mod.get_global('mwc_rng_state')
        fullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
        sums = np.zeros(ctx.threads, np.uint64)

        print "Running states forward %d rounds on CPU" % self.rounds
        ctime = time.time()
        for i in range(self.rounds):
            states = fullstates & 0xffffffff
            carries = fullstates >> 32
            fullstates = mults * states + carries
            sums = sums + (fullstates & 0xffffffff)
        ctime = time.time() - ctime
        print "Done on host, took %g seconds" % ctime

        print "Same thing on the device"
        func = ctx.mod.get_function('MWC_RNG_test')
        dtime = func(block=ctx.block, grid=ctx.grid, time_kernel=True)
        print "Done on device, took %g seconds" % dtime

        print "Comparing states and sums..."
        dfullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
        if not (dfullstates == fullstates).all():
            print "State discrepancy"
            print dfullstates
            print fullstates
            #return False
        sumdp, suml = ctx.mod.get_global('mwc_rng_test_sums')
        dsums = cuda.from_device(sumdp, ctx.threads, np.uint64)
        if not (dsums == sums).all():
            print "Sum discrepancy"
            print dsums
            print sums
            return False
        return True
def main(genome_path):
    ctx = LaunchContext(block=(256,1,1), grid=(64,1))
    ctx.compile([MWCRNGTest])
    ctx.instances[MWCRNG].set_up(ctx)
    ctx.instances[MWCRNGTest].call(ctx)