Code builder, RNG working

This commit is contained in:
Steven Robertson 2010-08-28 00:28:00 -04:00
parent a23ebdcf5f
commit 907cbb273f

397
main.py
View File

@ -11,6 +11,7 @@
import os import os
import sys import sys
import time
import ctypes import ctypes
import struct import struct
@ -32,57 +33,152 @@ import numpy as np
from fr0stlib import pyflam3 from fr0stlib import pyflam3
# PTX header and functions used for debugging. def ppr_ptx(src):
prelude = """ # TODO: Add variable realignment
.version 2.0 indent = 0
.target sm_20 out = []
for line in [l.strip() for l in src.split('\n')]:
if not line:
continue
if len(line.split()) == 1 and line.endswith(':'):
out.append(line)
continue
if '}' in line and '{' not in line:
indent -= 1
out.append(' ' * (indent * 4) + line)
if '{' in line and '}' not in line:
indent += 1
return '\n'.join(out)
.func (.reg .u32 $ret) get_gtid () def multisub(tmpl, subs):
{ while '{{' in tmpl:
.reg .u16 tmp; tmpl = tempita.Template(tmpl).substitute(subs)
.reg .u32 cta, ncta, tid, gtid; return tmpl
mov.u16 tmp, %ctaid.x;
cvt.u32.u16 cta, tmp;
mov.u16 tmp, %ntid.x;
cvt.u32.u16 ncta, tmp;
mul24.lo.u32 gtid, cta, ncta;
mov.u16 tmp, %tid.x;
cvt.u32.u16 tid, tmp;
add.u32 gtid, gtid, tid;
mov.b32 $ret, gtid;
ret;
}
.entry write_to_buffer ( .param .u32 bufbase )
{
.reg .u32 base, gtid, off;
ld.param.u32 base, [bufbase];
call.uni (off), get_gtid, ();
mad24.lo.u32 base, off, 4, base;
st.volatile.global.b32 [base], off;
}
"""
class CUGenome(pyflam3.Genome): class CUGenome(pyflam3.Genome):
def _render(self, frame, trans): def _render(self, frame, trans):
obuf = (ctypes.c_ubyte * ((3+trans)*self.width*self.height))() obuf = (ctypes.c_ubyte * ((3+trans)*self.width*self.height))()
stats = pyflam3.RenderStats() stats = pyflam3.RenderStats()
pyflam3.flam3_render(ctypes.byref(frame), obuf, pyflam3.flam3_field_both, pyflam3.flam3_render(ctypes.byref(frame), obuf,
pyflam3.flam3_field_both,
trans+3, trans, ctypes.byref(stats)) trans+3, trans, ctypes.byref(stats))
return obuf, stats, frame return obuf, stats, frame
class LaunchContext(self): class LaunchContext(object):
def __init__(self, seed=None): """
self.block, self.grid, self.threads = None, None, None Context collecting the information needed to create, run, and gather the
self.stream = cuda.Stream() results of a device computation.
self.rand = mtrand.RandomState(seed)
def set_size(self, block, grid): To create the fastest device code across multiple device families, this
self.block, self.grid = block, grid context may decide to iteratively refine the final PTX by regenerating
self.threads = reduce(lambda a, b: a*b, self.block + self.grid) and recompiling it several times to optimize certain parameters of the
launch, such as the distribution of threads throughout the device.
The properties of this device which are tuned are listed below. Any PTX
fragments which use this information must emit valid PTX for any state
given below, but the PTX is only required to actually run with the final,
fixed values of all tuned parameters below.
`block`: 3-tuple of (x,y,z); dimensions of each CTA.
`grid`: 2-tuple of (x,y); dimensions of the grid of CTAs.
`threads`: Number of active threads on device as a whole.
`mod`: Final compiled module. Unavailable during assembly.
"""
def __init__(self, block=(1,1,1), grid=(1,1), seed=None, tests=False):
self.block, self.grid, self.tests = block, grid, tests
self.stream = cuda.Stream()
self.rand = np.random.mtrand.RandomState(seed)
@property
def threads(self):
return reduce(lambda a, b: a*b, self.block + self.grid)
def _deporder(self, unsorted_instances, instance_map):
# Do a DFS on the mapping of PTXFragment types to instances, returning
# a list of instances ordered such that nothing depends on anything
# before it in the list
seen = {}
def rec(inst):
if inst in seen: return seen[inst]
deps = filter(lambda d: d is not inst, map(instance_map.get,
callable(inst.deps) and inst.deps(self) or inst.deps))
return seen.setdefault(inst, 1+max([0]+map(rec, deps)))
map(rec, unsorted_instances)
return sorted(unsorted_instances, key=seen.get)
def _safeupdate(self, dst, src):
for key, val in src.items():
if key in dst:
raise KeyError("Duplicate key %s" % key)
dst[key] = val
def assemble(self, entries):
# Get a property, dealing with the callable-or-data thing
def pget(prop):
if callable(prop): return prop(self)
return prop
instances = {}
entries_unvisited = list(entries)
tests = set()
parsed_entries = []
while entries_unvisited:
ent = entries_unvisited.pop(0)
seen, unvisited = set(), [ent]
while unvisited:
frag = unvisited.pop(0)
seen.add(frag)
inst = instances.setdefault(frag, frag())
for dep in pget(inst.deps):
if dep not in seen:
unvisited.append(dep)
tmpl_namespace = {'ctx': self}
entry_start, entry_end = [], []
for inst in self._deporder(map(instances.get, seen), instances):
self._safeupdate(tmpl_namespace, pget(inst.subs))
entry_start.append(pget(inst.entry_start))
entry_end.append(pget(inst.entry_end))
entry_start_tmpl = '\n'.join(filter(None, entry_start))
entry_end_tmpl = '\n'.join(filter(None, reversed(entry_end)))
name, args, body = pget(instances[ent].entry)
tmpl_namespace.update({'_entry_name_': name, '_entry_args_': args,
'_entry_body_': body, '_entry_start_': entry_start_tmpl,
'_entry_end_': entry_end_tmpl})
entry_tmpl = """
.entry {{ _entry_name_ }} ({{ _entry_args_ }})
{
{{ _entry_start_ }}
{{ _entry_body_ }}
{{ _entry_end_ }}
}
"""
parsed_entries.append(multisub(entry_tmpl, tmpl_namespace))
prelude = []
tmpl_namespace = {'ctx': self}
for inst in self._deporder(instances.values(), instances):
prelude.append(pget(inst.prelude))
self._safeupdate(tmpl_namespace, pget(inst.subs))
tmpl_namespace['_prelude_'] = '\n'.join(filter(None, prelude))
tmpl_namespace['_entries_'] = '\n\n'.join(parsed_entries)
tmpl = "{{ _prelude_ }}\n\n{{ _entries_ }}\n"
return instances, multisub(tmpl, tmpl_namespace)
def compile(self, entries):
# For now, do no optimization.
self.instances, self.src = self.assemble(entries)
self.src = ppr_ptx(self.src)
try:
self.mod = cuda.module_from_buffer(self.src)
except (cuda.CompileError, cuda.RuntimeError), e:
print "Aww, dang, compile error. Here's the source:"
print '\n'.join(["%03d %s" % (i+1, l)
for (i, l) in enumerate(self.src.split('\n'))])
raise e
class PTXFragment(object): class PTXFragment(object):
""" """
@ -99,14 +195,19 @@ class PTXFragment(object):
Template code will be processed recursively until all "{{" instances have Template code will be processed recursively until all "{{" instances have
been replaced, using the same namespace each time. been replaced, using the same namespace each time.
Note that any method which does not depend on 'ctx' can be replaced with
an instance of the appropriate return type. So, for example, the 'deps'
property can be a flat list instead of a function.
""" """
def deps(self, ctx): def deps(self, ctx):
""" """
Returns a list of PTXFragment objects on which this object depends Returns a list of PTXFragment objects on which this object depends
for successful compilation. Circular dependencies are forbidden. for successful compilation. Circular dependencies are forbidden,
but multi-level dependencies should be fine.
""" """
return [] return [DeviceHelpers]
def subs(self, ctx): def subs(self, ctx):
""" """
@ -124,33 +225,40 @@ class PTXFragment(object):
""" """
return "" return ""
def entryPrelude(self, ctx): def entry_start(self, ctx):
""" """
Returns a template string that should be inserted at the top of any Returns a template string that should be inserted at the top of any
entry point which depends on this method. The entry prelude of all entry point which depends on this method. The entry starts of all
deps will be inserted above this entry prelude. deps will be inserted above this entry prelude.
""" """
return "" return ""
def setUp(self, ctx): def entry_end(self, ctx):
"""
As above, but at the end of the calling function, and with the order
reversed (all dependencies will be inserted after this).
"""
return ""
def set_up(self, ctx):
""" """
Do start-of-stream initialization, such as copying data to the device. Do start-of-stream initialization, such as copying data to the device.
""" """
pass pass
def test(self, ctx): # A list of PTXTest classes which will test this fragment
""" tests = []
Perform device tests. Returns True on success, False on failure,
or raises an exception.
"""
return True
class PTXEntryPoint(PTXFragment): class PTXEntryPoint(PTXFragment):
# Human-readable entry point name
name = ""
def entry(self, ctx): def entry(self, ctx):
""" """
Returns a template string corresponding to a PTX entry point. Returns a 3-tuple of (name, args, body), which will be assembled into
a function.
""" """
pass raise NotImplementedError
def call(self, ctx): def call(self, ctx):
""" """
@ -159,71 +267,190 @@ class PTXEntryPoint(PTXFragment):
""" """
pass pass
class PTXTest(PTXEntryPoint):
"""PTXTests are semantically equivalent to PTXEntryPoints, but they
differ slightly in use. In particular:
* The "name" property should describe the test being performed,
* ctx.stream will be synchronized before 'call' is run, and should be
synchronized afterwards (i.e. sync it yourself or don't use it),
* call() should return True to indicate that a test passed, or
False (or raise an exception) if it failed.
"""
pass
class DeviceHelpers(PTXFragment): class DeviceHelpers(PTXFragment):
"""This one's included by default, no need to depend on it""" prelude = ".version 2.1\n.target sm_20\n\n"
def _get_gtid(self, dst):
return "{\n// Load GTID into " + dst + """
.reg .u16 tmp;
.reg .u32 cta, ncta, tid, gtid;
mov.u16 tmp, %ctaid.x;
cvt.u32.u16 cta, tmp;
mov.u16 tmp, %ntid.x;
cvt.u32.u16 ncta, tmp;
mul.lo.u32 gtid, cta, ncta;
mov.u16 tmp, %tid.x;
cvt.u32.u16 tid, tmp;
add.u32 gtid, gtid, tid;
mov.b32 """ + dst + ", gtid;\n}"
def subs(self, ctx): def subs(self, ctx):
return { return {
'PTRT': ctypes.sizeof(ctypes.c_void_p) == 8 and '.u64' or '.u32', 'PTRT': ctypes.sizeof(ctypes.c_void_p) == 8 and '.u64' or '.u32',
'get_gtid': self._get_gtid
} }
class MWCRandGen(PTXFragment): class MWCRNG(PTXFragment):
_prelude = """
.const {{PTRT}} mwc_rng_mults_p;
.const {{PTRT}} mwc_rng_seeds_p;
"""
def __init__(self): def __init__(self):
if not os.path.isfile(os.path.join(os.path.dirname(__FILE__), if not os.path.isfile('primes.bin'):
'primes.bin')):
raise EnvironmentError('primes.bin not found') raise EnvironmentError('primes.bin not found')
def prelude(self): prelude = """
return self._prelude .global .u32 mwc_rng_mults[{{ctx.threads}}];
.global .u64 mwc_rng_state[{{ctx.threads}}];"""
def setUp(self, ctx): def _next_b32(self, dreg):
# TODO: make sure PTX optimizes away superfluous move instrs
return """
{
// MWC next b32
.reg .u64 mwc_out;
cvt.u64.u32 mwc_out, mwc_car;
mad.wide.u32 mwc_out, mwc_st, mwc_mult, mwc_out;
mov.b64 {mwc_st, mwc_car}, mwc_out;
mov.u32 %s, mwc_st;
}
""" % dreg
def subs(self, ctx):
return {'mwc_next_b32': self._next_b32}
entry_start = """
.reg .u32 mwc_st, mwc_mult, mwc_car;
{
// MWC load multipliers and RNG states
.reg .u32 mwc_off, mwc_addr;
{{ get_gtid('mwc_off') }}
mov.u32 mwc_addr, mwc_rng_mults;
mad.lo.u32 mwc_addr, mwc_off, 4, mwc_addr;
ld.global.u32 mwc_mult, [mwc_addr];
mov.u32 mwc_addr, mwc_rng_state;
mad.lo.u32 mwc_addr, mwc_off, 8, mwc_addr;
ld.global.v2.u32 {mwc_st, mwc_car}, [mwc_addr];
}
"""
entry_end = """
{
// MWC save states
.reg .u32 mwc_addr, mwc_off;
{{ get_gtid('mwc_off') }}
mov.u32 mwc_addr, mwc_rng_state;
mad.lo.u32 mwc_addr, mwc_off, 8, mwc_addr;
st.global.v2.u32 [mwc_addr], {mwc_st, mwc_car};
}
"""
def set_up(self, ctx):
# Load raw big-endian u32 multipliers from primes.bin. # Load raw big-endian u32 multipliers from primes.bin.
with open('primes.bin') as primefp: with open('primes.bin') as primefp:
dt = np.dtype(np.uint32).newbyteorder('B') dt = np.dtype(np.uint32).newbyteorder('B')
mults = np.frombuffer(primefp.read(), dtype=dt) mults = np.frombuffer(primefp.read(), dtype=dt)
# Randomness in choosing multipliers is good, but larger multipliers # Randomness in choosing multipliers is good, but larger multipliers
# have longer periods, which is also good. This is a compromise. # have longer periods, which is also good. This is a compromise.
ctx.rand.shuffle(mults[:ctx.threads*4]) # TODO: fix mutability, enable shuffle here
#ctx.rand.shuffle(mults[:ctx.threads*4])
# Copy multipliers and seeds to the device # Copy multipliers and seeds to the device
devmp, devml = ctx.mod.get_global('mwc_rng_mults') multdp, multl = ctx.mod.get_global('mwc_rng_mults')
cuda.memcpy_htod_async(devmp, mults.tostring()[:devml], ctx.stream) # TODO: get async to work
devsp, devsl = ctx.mod.get_global('mwc_rng_seeds') #cuda.memcpy_htod_async(multdp, mults.tostring()[:multl], ctx.stream)
cuda.memcpy_htod_async(devsp, ctx.rand.bytes(devsl), ctx.stream) cuda.memcpy_htod(multdp, mults.tostring()[:multl])
statedp, statel = ctx.mod.get_global('mwc_rng_state')
#cuda.memcpy_htod_async(statedp, ctx.rand.bytes(statel), ctx.stream)
cuda.memcpy_htod(statedp, ctx.rand.bytes(statel))
def _next_b32(self, dreg): def tests(self, ctx):
return """ return [MWCRNGTest]
mul.wide.u32 mwc_rng_
mul.wide.u32
class MWCRNGTest(PTXTest):
name = "MWC RNG sum-of-threads test"
deps = [MWCRNG]
rounds = 200
prelude = ".global .u64 mwc_rng_test_sums[{{ctx.threads}}];"
def templates(self, ctx): def entry(self, ctx):
return {'mwc_next_b32', self._next_b32} return ('MWC_RNG_test', '', """
.reg .u64 sum, addl;
.reg .u32 addend;
mov.u64 sum, 0;
{{for round in range(%d)}}
{{ mwc_next_b32('addend') }}
cvt.u64.u32 addl, addend;
add.u64 sum, sum, addl;
{{endfor}}
def test(self, ctx): {
.reg .u32 addr, offset;
{{ get_gtid('offset') }}
mov.u32 addr, mwc_rng_test_sums;
mad.lo.u32 addr, offset, 8, addr;
st.global.u64 [addr], sum;
}
""" % self.rounds)
def call(self, ctx):
# Get current multipliers and seeds from the device
multdp, multl = ctx.mod.get_global('mwc_rng_mults')
mults = cuda.from_device(multdp, ctx.threads, np.uint32)
statedp, statel = ctx.mod.get_global('mwc_rng_state')
fullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
sums = np.zeros(ctx.threads, np.uint64)
print "Running states forward %d rounds on CPU" % self.rounds
ctime = time.time()
for i in range(self.rounds):
states = fullstates & 0xffffffff
carries = fullstates >> 32
fullstates = mults * states + carries
sums = sums + (fullstates & 0xffffffff)
ctime = time.time() - ctime
print "Done on host, took %g seconds" % ctime
print "Same thing on the device"
func = ctx.mod.get_function('MWC_RNG_test')
dtime = func(block=ctx.block, grid=ctx.grid, time_kernel=True)
print "Done on device, took %g seconds" % dtime
print "Comparing states and sums..."
dfullstates = cuda.from_device(statedp, ctx.threads, np.uint64)
if not (dfullstates == fullstates).all():
print "State discrepancy"
print dfullstates
print fullstates
#return False
sumdp, suml = ctx.mod.get_global('mwc_rng_test_sums')
dsums = cuda.from_device(sumdp, ctx.threads, np.uint64)
def launch(self, ctx): if not (dsums == sums).all():
if self.mults print "Sum discrepancy"
print dsums
print sums
return False
return True
def main(genome_path): def main(genome_path):
ctx = LaunchContext(block=(256,1,1), grid=(64,1))
ctx.compile([MWCRNGTest])
ctx.instances[MWCRNG].set_up(ctx)
ctx.instances[MWCRNGTest].call(ctx)