Color palette (sort of)

This commit is contained in:
Steven Robertson 2011-05-01 15:23:45 -04:00
parent a43973f0ff
commit b710de4865

View File

@ -2,7 +2,7 @@
The main iteration loop.
"""
from ctypes import byref
from ctypes import byref, memset, sizeof
import pycuda.driver as cuda
from pycuda.driver import In, Out, InOut
@ -25,6 +25,11 @@ class IterCode(HunkOCode):
bodies.append(iterbody)
self.defs = '\n'.join(bodies)
decls = """
// Note: with cudaReadModeNormalizedFloat, fetches from this uchar4 texture return float4 values scaled to [0, 1]
texture<uchar4, cudaTextureType2D, cudaReadModeNormalizedFloat> palTex;
"""
def _xfbody(self, xfid, xform):
px = self.packer.view('info', 'xf%d_' % xfid)
px.sub('xf', 'cp.xforms[%d]' % xfid)
@ -65,7 +70,7 @@ void iter(mwc_st *msts, const iter_info *infos, float *accbuf, float *denbuf) {
const iter_info *info = &(infos[blockIdx.x]);
int consec_bad = -{{features.fuse}};
int nsamps = 500;
int nsamps = 2560;
float x, y, color;
x = mwc_next_11(&rctx);
@ -106,11 +111,14 @@ void iter(mwc_st *msts, const iter_info *infos, float *accbuf, float *denbuf) {
// TODO: dither?
int i = ((int)((y + 1.0f) * 255.0f) * 512)
+ (int)((x + 1.0f) * 255.0f);
accbuf[i*4] += color < 0.5f ? (1.0f - 2.0f * color) : 0.0f;
accbuf[i*4+1] += 1.0f - 2.0f * fabsf(0.5f - color);
accbuf[i*4+2] += color > 0.5f ? color * 2.0f - 1.0f : 0.0f;
accbuf[i*4+3] += 1.0f;
// info is declared const, so the C++ compiler rejects this unless the value is first loaded into a local
float cp_step_frac = {{packer.get('cp_step_frac')}};
float4 outcol = tex2D(palTex, cp_step_frac, color);
accbuf[i*4] += outcol.x;
accbuf[i*4+1] += outcol.y;
accbuf[i*4+2] += outcol.z;
accbuf[i*4+3] += outcol.w;
denbuf[i] += 1.0f;
}
@ -137,16 +145,33 @@ def silly(features, cps):
cps_as_array[i] = cp
cp = Genome()
memset(byref(cp), 0, sizeof(cp))
infos = []
# TODO: move this into a common function
pal = np.empty((16, 256, 4), dtype=np.uint8)
sampAt = [int(i/15.*(nsteps-1)) for i in range(16)]
for n in range(nsteps):
flam3_interpolate(cps_as_array, 2, (n - nsteps/2) * 0.001, 0, byref(cp))
cp._init()
infos.append(iter.packer.pack(cp=cp))
if n in sampAt:
pidx = sampAt.index(n)
for i, e in enumerate(cp.palette.entries):
pal[pidx][i] = np.uint8(np.array(e.color) * 255.0)
infos.append(iter.packer.pack(cp=cp, cp_step_frac=float(n)/nsteps))
infos = np.concatenate(infos)
dpal = cuda.make_multichannel_2d_array(pal, 'C')
tref = mod.get_texref('palTex')
tref.set_array(dpal)
tref.set_format(cuda.array_format.UNSIGNED_INT8, 4)
tref.set_flags(cuda.TRSF_NORMALIZED_COORDINATES)
fun = mod.get_function("iter")
fun(InOut(seeds), In(infos), InOut(abuf), InOut(dbuf),
t = fun(InOut(seeds), In(infos), InOut(abuf), InOut(dbuf),
block=(512,1,1), grid=(nsteps,1), time_kernel=True)
print "Completed render in %g seconds" % t
return abuf, dbuf