Some amount of dynamic rendering

This commit is contained in:
Steven Robertson 2011-04-30 16:40:16 -04:00
parent 1302f31ec7
commit 088299423e
5 changed files with 165 additions and 137 deletions

View File

@ -1,19 +1,9 @@
"""
Contains the PTX fragments which will drive the device.
Contains the PTX fragments which will drive the device, and helper functions
to combine those fragments.
"""
# Basic headers, utility functions, and so on
base = """
#include<cuda.h>
#include<stdint.h>
// TODO: use launch parameter preconfig to eliminate unnecessary parts
__device__
uint32_t gtid() {
return threadIdx.x + blockDim.x *
(threadIdx.y + blockDim.y *
(threadIdx.z + blockDim.z *
(blockIdx.x + (gridDim.x * blockIdx.y))));
}
"""
import util
import mwc
import iter

View File

@ -7,62 +7,86 @@ from pycuda.driver import In, Out, InOut
from pycuda.compiler import SourceModule
import numpy as np
from cuburn import code
from cuburn.code import mwc
from cuburn.code.util import *
src = r"""
#define FUSE 20
#define MAXOOB 10
import tempita
typedef struct {
// Number of iterations to perform, *per thread*.
uint32_t niters;
class IterCode(HunkOCode):
def __init__(self, features):
self.features = features
self.packer = DataPacker('iter_info')
iterbody = self._iterbody()
bodies = [self._xfbody(i,x) for i,x in enumerate(self.features.xforms)]
bodies.append(iterbody)
self.defs = '\n'.join(bodies)
// Number of accumulators per row and column in the accum buffer
uint32_t accwidth, accheight;
} iter_info;
def _xfbody(self, xfid, xform):
px = self.packer.view('info', 'xf%d_' % xfid)
px.sub('xf', 'cp.xforms[%d]' % xfid)
tmpl = tempita.Template("""
__device__
void apply_xf{{xfid}}(float *ix, float *iy, float *icolor,
const iter_info *info) {
float tx, ty, ox = *ix, oy = *iy;
{{apply_affine('ox', 'oy', 'tx', 'ty', px, 'xf.c', 'pre')}}
// tiny little TODO: variations
*ix = tx;
*iy = ty;
float csp = {{px.get('xf.color_speed')}};
*icolor = *icolor * (1.0f - csp) + {{px.get('xf.color')}} * csp;
};
""")
g = dict(globals())
g.update(locals())
return tmpl.substitute(g)
def _iterbody(self):
tmpl = tempita.Template("""
__global__
void silly(mwc_st *msts, iter_info *infos, float *accbuf, float *denbuf) {
void iter(mwc_st *msts, const iter_info *infos, float *accbuf, float *denbuf) {
mwc_st rctx = msts[gtid()];
iter_info *info = &(infos[blockIdx.x]);
const iter_info *info = &(infos[blockIdx.x]);
float consec_bad = -FUSE;
float nsamps = info->niters;
int consec_bad = -{{features.fuse}};
int nsamps = 500;
float x, y, color;
x = mwc_next_11(&rctx);
y = mwc_next_11(&rctx);
color = mwc_next_01(&rctx);
while (nsamps > 0.0f) {
while (nsamps > 0) {
float xfsel = mwc_next_01(&rctx);
x *= 0.5f;
y *= 0.5f;
color *= 0.5f;
if (xfsel < 0.33f) {
color += 0.25f;
x += 0.5f;
} else if (xfsel < 0.66f) {
color += 0.5f;
y += 0.5f;
{{for xfid, xform in enumerate(features.xforms)}}
if (xfsel < {{packer.get('cp.norm_density[%d]' % xfid)}}) {
apply_xf{{xfid}}(&x, &y, &color, info);
} else
{{endfor}}
{
denbuf[0] = xfsel;
break; // TODO: fail here
}
if (consec_bad < 0.0f) {
if (consec_bad < 0) {
consec_bad++;
continue;
}
if (x <= -1.0f || x >= 1.0f || y <= -1.0f || y >= 1.0f
|| consec_bad < 0.0f) {
|| consec_bad < 0) {
consec_bad++;
if (consec_bad > MAXOOB) {
if (consec_bad > {{features.max_oob}}) {
x = mwc_next_11(&rctx);
y = mwc_next_11(&rctx);
color = mwc_next_01(&rctx);
consec_bad = -FUSE;
consec_bad = -{{features.fuse}};
}
continue;
}
@ -80,26 +104,28 @@ void silly(mwc_st *msts, iter_info *infos, float *accbuf, float *denbuf) {
nsamps--;
}
}
"""
""")
return tmpl.substitute(
features = self.features,
packer = self.packer.view('info'))
def silly():
mod = SourceModule(code.base + mwc.src + src)
def silly(features, cp):
abuf = np.zeros((512, 512, 4), dtype=np.float32)
dbuf = np.zeros((512, 512), dtype=np.float32)
seeds = mwc.build_mwc_seeds(512 * 24, seed=5)
seeds = mwc.MWC.make_seeds(512 * 24)
info = np.zeros(3, dtype=np.uint32)
info[0] = 5000
info[1] = 512
info[2] = 512
info = np.repeat([info], 24, axis=0)
iter = IterCode(features)
code = assemble_code(BaseCode, mwc.MWC, iter, iter.packer)
print code
mod = SourceModule(code)
fun = mod.get_function("silly")
info = iter.packer.pack(cp=cp)
print info
fun = mod.get_function("iter")
fun(InOut(seeds), In(info), InOut(abuf), InOut(dbuf),
block=(512,1,1), grid=(24,1), time_kernel=True)
block=(512,1,1), grid=(1,1), time_kernel=True)
print abuf
print dbuf
print sum(dbuf)
return abuf, dbuf

View File

@ -2,23 +2,21 @@
The multiply-with-carry random number generator.
"""
import time
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
import numpy as np
import tempita
from jinja2 import Template
from cuburn.code.util import *
from cuburn import code
src = r"""
class MWC(HunkOCode):
decls = """
typedef struct {
uint32_t mul;
uint32_t state;
uint32_t carry;
} mwc_st;
"""
defs = r"""
__device__ uint32_t mwc_next(mwc_st *st) {
asm("{\n\t.reg .u64 val;\n\t"
"cvt.u64.u32 val, %0;\n\t"
@ -35,10 +33,38 @@ __device__ float mwc_next_01(mwc_st *st) {
__device__ float mwc_next_11(mwc_st *st) {
return ((int32_t) mwc_next(st)) * (1.0f / 2147483648.0f);
}
"""
testsrc = code.base + src + """
@staticmethod
def make_seeds(nthreads, host_seed=None):
if host_seed:
rand = np.random.RandomState(host_seed)
else:
rand = np.random
# Load raw big-endian u32 multipliers from primes.bin.
with open('primes.bin') as primefp:
dt = np.dtype(np.uint32).newbyteorder('B')
mults = np.frombuffer(primefp.read(), dtype=dt)
# Create the seed structures. TODO: check that struct is 4-byte aligned
seeds = np.empty((3, nthreads), dtype=np.uint32, order='F')
# Randomness in choosing multipliers is good, but larger multipliers
# have longer periods, which is also good. This is a compromise.
mults = np.array(mults[:nthreads*4])
rand.shuffle(mults)
seeds[0][:] = mults[:nthreads]
# Intentionally excludes both 0 and (2^32-1), as they can lead to
# degenerate sequences of period 0
seeds[1] = rand.randint(1, 0xffffffff, size=nthreads)
seeds[2] = rand.randint(1, 0xffffffff, size=nthreads)
return seeds
class MWCTest(HunkOCode):
defs = """
__global__ void test_mwc(mwc_st *msts, uint64_t *sums, float nrounds) {
mwc_st rctx = msts[gtid()];
uint64_t sum = 0;
@ -48,69 +74,45 @@ __global__ void test_mwc(mwc_st *msts, uint64_t *sums, float nrounds) {
}
"""
def build_mwc_seeds(nthreads, seed=None):
if seed:
rand = np.random.RandomState(seed)
else:
rand = np.random
@classmethod
def test_mwc(cls, rounds=5000, nblocks=64, blockwidth=512):
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
import time
# Load raw big-endian u32 multipliers from primes.bin.
with open('primes.bin') as primefp:
dt = np.dtype(np.uint32).newbyteorder('B')
mults = np.frombuffer(primefp.read(), dtype=dt)
nthreads = blockwidth * nblocks
seeds = MWC.make_seeds(nthreads, host_seed = 5)
dseeds = cuda.to_device(seeds)
# Create the seed structures. TODO: check that struct is 4-byte aligned
seeds = np.empty((3, nthreads), dtype=np.uint32, order='F')
mod = SourceModule(assemble_code(BaseCode, MWC, cls))
# Randomness in choosing multipliers is good, but larger multipliers
# have longer periods, which is also good. This is a compromise.
mults = np.array(mults[:nthreads*4])
rand.shuffle(mults)
seeds[0][:] = mults[:nthreads]
for trial in range(2):
print "Trial %d, on CPU: " % trial,
sums = np.zeros(nthreads, dtype=np.uint64)
ctime = time.time()
mults = seeds[0].astype(np.uint64)
states = seeds[1]
carries = seeds[2]
for i in range(rounds):
step = np.frombuffer((mults * states + carries).data,
dtype=np.uint32).reshape((2, nthreads), order='F')
states[:] = step[0]
carries[:] = step[1]
sums += states
# Intentionally excludes both 0 and (2^32-1), as they can lead to
# degenerate sequences of period 0
seeds[1] = rand.randint(1, 0xffffffff, size=nthreads)
seeds[2] = rand.randint(1, 0xffffffff, size=nthreads)
ctime = time.time() - ctime
print "Took %g seconds." % ctime
return seeds
def test_mwc():
rounds = 5000
nblocks = 64
nthreads = 512 * nblocks
seeds = build_mwc_seeds(nthreads, seed = 5)
dseeds = cuda.to_device(seeds)
mod = SourceModule(testsrc)
for trial in range(2):
print "Trial %d, on CPU: " % trial,
sums = np.zeros(nthreads, dtype=np.uint64)
ctime = time.time()
mults = seeds[0].astype(np.uint64)
states = seeds[1]
carries = seeds[2]
for i in range(rounds):
step = np.frombuffer((mults * states + carries).data,
dtype=np.uint32).reshape((2, nthreads), order='F')
states[:] = step[0]
carries[:] = step[1]
sums += states
ctime = time.time() - ctime
print "Took %g seconds." % ctime
print "Trial %d, on device: " % trial,
dsums = cuda.mem_alloc(8*nthreads)
fun = mod.get_function("test_mwc")
dtime = fun(dseeds, dsums, np.float32(rounds),
block=(512,1,1), grid=(nblocks,1), time_kernel=True)
print "Took %g seconds." % dtime
dsums = cuda.from_device(dsums, nthreads, np.uint64)
if not np.all(np.equal(sums, dsums)):
print "Sum discrepancy!"
print sums
print dsums
print "Trial %d, on device: " % trial,
dsums = cuda.mem_alloc(8*nthreads)
fun = mod.get_function("test_mwc")
dtime = fun(dseeds, dsums, np.float32(rounds),
block=(blockwidth,1,1), grid=(nblocks,1),
time_kernel=True)
print "Took %g seconds." % dtime
dsums = cuda.from_device(dsums, nthreads, np.uint64)
if not np.all(np.equal(sums, dsums)):
print "Sum discrepancy!"
print sums
print dsums

View File

@ -14,7 +14,17 @@ from cuburn.variations import Variations
Point = lambda x, y: np.array([x, y], dtype=np.double)
class Genome(pyflam3.Genome):
pass
@classmethod
def from_string(cls, *args, **kwargs):
gnms = super(Genome, cls).from_string(*args, **kwargs)
for g in gnms: g._init()
return gnms
def _init(self):
self.xforms = [self.xform[i] for i in range(self.num_xforms)]
dens = np.array([x.density for x in self.xforms])
dens /= np.sum(dens)
self.norm_density = [np.sum(dens[:i+1]) for i in range(len(dens))]
class XForm(object):
"""
@ -99,7 +109,7 @@ class Frame(object):
cp.camera = Camera(self._frame, cp, filters)
cp.nsamples = (cp.camera.sample_density *
center.width * center.height) / ncps
cp.xforms = XForm.parse(cp)
print "Expected writes:", (
cp.camera.sample_density * center.width * center.height)
@ -190,9 +200,10 @@ class Features(object):
"""
# Constant parameters which control handling of out-of-frame samples:
# Number of iterations to iterate without write after new point
fuse = 2
# Maximum consecutive out-of-frame points before picking new point
max_bad = 3
fuse = 20
# Maximum consecutive out-of-bounds points before picking new point
max_oob = 10
max_nxforms = 12
# Height of the texture pallete which gets uploaded to the GPU (assuming
# that palette-from-texture is enabled). For most genomes, this doesn't
@ -205,7 +216,6 @@ class Features(object):
any = lambda l: bool(filter(None, map(l, genomes)))
self.max_ntemporal_samples = max(
[cp.nbatches * cp.ntemporal_samples for cp in genomes])
self.camera_rotation = any(lambda cp: cp.rotate)
self.non_box_temporal_filter = genomes[0].temporal_filter_type
self.palette_mode = genomes[0].palette_mode and "linear" or "nearest"
@ -214,6 +224,7 @@ class Features(object):
"number of xforms! (try running through flam3-genome first)")
self.xforms = [XFormFeatures([x[i] for x in xforms], i)
for i in range(len(xforms[0]))]
self.nxforms = len(self.xforms)
if any(lambda cp: cp.final_xform_enable):
raise NotImplementedError("Final xform")

View File

@ -24,16 +24,15 @@ import pyglet
import pycuda.autoinit
from cuburn.render import *
from cuburn.code.mwc import test_mwc
from cuburn.code.mwc import MWCTest
from cuburn.code.iter import silly
def main(args):
#MWCTest.test_mwc()
with open(args[-1]) as fp:
genomes = Genome.from_string(fp.read())
anim = Animation(genomes)
accum, den = silly()
accum, den = silly(anim.features, genomes[0])
if False:
bins = anim.render_frame()