Palette interpolation on device

This commit is contained in:
Steven Robertson 2011-10-25 22:56:19 -04:00
parent e793527c29
commit 376cd752d6
4 changed files with 102 additions and 16 deletions

View File

@ -313,6 +313,33 @@ void test_cr(const float *times, const float *knots, const float *t, float *r) {
r[i] = catmull_rom(times, knots, t[i]); r[i] = catmull_rom(times, knots, t[i]);
} }
__global__
void interp_palette_hsv(uchar4 *outs, const float *times, const float4 *sources,
float tstart, float tstep) {
float time = tstart + blockIdx.x * tstep;
int idx = fmaxf(bitwise_binsearch(times, time), 1);
float4 left = sources[blockDim.x * (idx - 1) + threadIdx.x];
float4 right = sources[blockDim.x * (idx) + threadIdx.x];
float lf = (times[idx] - time) / (times[idx] - times[idx-1]);
float rf = 1.0f - lf;
float3 lhsv = rgb2hsv(make_float3(left.x, left.y, left.z));
float3 rhsv = rgb2hsv(make_float3(right.x, right.y, right.z));
float3 hsv;
hsv.x = lhsv.x * lf + rhsv.x * rf;
hsv.y = lhsv.y * lf + rhsv.y * rf;
hsv.z = lhsv.z * lf + rhsv.z * rf;
float3 rgb = hsv2rgb(hsv);
uchar4 out;
out.x = rgb.x * 255.0f;
out.y = rgb.y * 255.0f;
out.z = rgb.z * 255.0f;
out.w = 255.0f * (left.z * lf + right.z * rf);
outs[blockDim.x * blockIdx.x + threadIdx.x] = out;
}
""") """)

View File

@ -35,6 +35,11 @@ class BaseCode(HunkOCode):
#include<cuda.h> #include<cuda.h>
#include<stdint.h> #include<stdint.h>
#include<stdio.h> #include<stdio.h>
"""
decls = """
float3 rgb2hsv(float3 rgb);
float3 hsv2rgb(float3 hsv);
""" """
defs = r""" defs = r"""
@ -134,7 +139,49 @@ void write_half(float &xy, float x, float y, float den) {
: "=f"(xy) : "f"(x), "f"(y), "f"(den)); : "=f"(xy) : "f"(x), "f"(y), "f"(den));
} }
__device__
float3 rgb2hsv(float3 rgb) {
float M = fmaxf(fmaxf(rgb.x, rgb.y), rgb.z);
float m = fminf(fminf(rgb.x, rgb.y), rgb.z);
float C = M - m;
float s = M > 0.0f ? C / M : 0.0f;
float h;
if (s != 0.0f) {
C = 1.0f / C;
float rc = (M - rgb.x) * C;
float gc = (M - rgb.y) * C;
float bc = (M - rgb.z) * C;
if (rgb.x == M) h = bc - gc;
else if (rgb.y == M) h = 2 + gc - bc;
else h = 4 + gc - rc;
if (h < 0) h += 6;
}
return make_float3(h, s, M);
}
__device__
float3 hsv2rgb(float3 hsv) {
float whole = floorf(hsv.x);
float frac = hsv.x - whole;
float val = hsv.z;
float min = val * (1 - hsv.y);
float mid = val * (1 - (hsv.y * frac));
float alt = val * (1 - (hsv.y * (1 - frac)));
float3 out;
if (whole == 0.0f) { out.x = val; out.y = alt; out.z = min; }
else if (whole == 1.0f) { out.x = mid; out.y = val; out.z = min; }
else if (whole == 2.0f) { out.x = min; out.y = val; out.z = alt; }
else if (whole == 3.0f) { out.x = min; out.y = mid; out.z = val; }
else if (whole == 4.0f) { out.x = alt; out.y = min; out.z = val; }
else { out.x = val; out.y = min; out.z = mid; }
return out;
}
""" """
@staticmethod @staticmethod

View File

@ -1,4 +1,5 @@
import json import json
import base64
import numpy as np import numpy as np
import scipy.interpolate import scipy.interpolate
from cuburn import affine from cuburn import affine
@ -124,6 +125,13 @@ class RenderInfo(object):
# Deref genome # Deref genome
self.genome = self.db.genomes[self.genome] self.genome = self.db.genomes[self.genome]
for k, v in self.db.palettes.items():
pal = np.fromstring(base64.b64decode(v), np.uint8)
pal = np.reshape(pal, (256, 3))
pal_a = np.ones((256, 4), np.float32)
pal_a[:,:3] = pal / 255.0
self.db.palettes[k] = pal_a
class _AttrDict(dict): class _AttrDict(dict):
def __getattr__(self, name): def __getattr__(self, name):
return self[name] return self[name]

View File

@ -9,7 +9,6 @@ from ctypes import *
from cStringIO import StringIO from cStringIO import StringIO
import numpy as np import numpy as np
from scipy import ndimage from scipy import ndimage
import base64
from fr0stlib import pyflam3 from fr0stlib import pyflam3
from fr0stlib.pyflam3._flam3 import * from fr0stlib.pyflam3._flam3 import *
@ -142,19 +141,28 @@ class Renderer(object):
d_genome_knots = cuda.to_device(genome_knots) d_genome_knots = cuda.to_device(genome_knots)
info_size = 4 * len(self._iter.packer) * cps_per_block info_size = 4 * len(self._iter.packer) * cps_per_block
d_infos = cuda.mem_alloc(info_size) d_infos = cuda.mem_alloc(info_size)
pals = info.genome.color.palette
if isinstance(pals, basestring):
pals = [0.0, pals, 1.0, pals]
palint_times = np.empty(len(genome_times[0]), np.float32)
palint_times.fill(100.0)
palint_times[:len(pals)/2] = pals[::2]
d_palint_times = cuda.to_device(palint_times)
d_palint_vals = cuda.to_device(
np.concatenate(map(info.db.palettes.get, pals[1::2])))
d_palmem = cuda.mem_alloc(256 * info.palette_height * 4) d_palmem = cuda.mem_alloc(256 * info.palette_height * 4)
seeds = mwc.MWC.make_seeds(self._iter.NTHREADS * cps_per_block) seeds = mwc.MWC.make_seeds(self._iter.NTHREADS * cps_per_block)
d_seeds = cuda.to_device(seeds) d_seeds = cuda.to_device(seeds)
h_palmem = cuda.pagelocked_empty(
(info.palette_height, 256, 4), np.uint8)
h_out = cuda.pagelocked_empty((info.acc_height, info.acc_stride, 4), h_out = cuda.pagelocked_empty((info.acc_height, info.acc_stride, 4),
np.float32) np.float32)
filter_done_event = None filter_done_event = None
packer_fun = self.mod.get_function("interp_iter_params") packer_fun = self.mod.get_function("interp_iter_params")
palette_fun = self.mod.get_function("interp_palette_hsv")
iter_fun = self.mod.get_function("iter") iter_fun = self.mod.get_function("iter")
#iter_fun.set_cache_config(cuda.func_cache.PREFER_L1) #iter_fun.set_cache_config(cuda.func_cache.PREFER_L1)
@ -162,18 +170,18 @@ class Renderer(object):
last_time = times[0][0] last_time = times[0][0]
# TODO: move palette stuff to separate class; do interp
pal = np.fromstring(base64.b64decode(info.db.palettes.values()[0]),
np.uint8)
pal = np.reshape(pal, (256, 3))
h_palmem[0,:,:3] = pal
h_palmem[1:] = h_palmem[0]
for start, stop in times: for start, stop in times:
cen_cp = cuburn.genome.HacketyGenome(info.genome, (start+stop)/2) cen_cp = cuburn.genome.HacketyGenome(info.genome, (start+stop)/2)
# "Interp" already done above, but... if filter_done_event:
cuda.memcpy_htod_async(d_palmem, h_palmem, iter_stream) iter_stream.wait_for_event(filter_done_event)
width = np.float32((stop-start) / info.palette_height)
palette_fun(d_palmem, d_palint_times, d_palint_vals,
np.float32(start), width,
block=(256,1,1), grid=(info.palette_height,1),
stream=iter_stream)
tref = self.mod.get_texref('palTex') tref = self.mod.get_texref('palTex')
array_info = cuda.ArrayDescriptor() array_info = cuda.ArrayDescriptor()
array_info.height = info.palette_height array_info.height = info.palette_height
@ -186,10 +194,6 @@ class Renderer(object):
tref.set_flags(cuda.TRSF_NORMALIZED_COORDINATES) tref.set_flags(cuda.TRSF_NORMALIZED_COORDINATES)
tref.set_filter_mode(cuda.filter_mode.LINEAR) tref.set_filter_mode(cuda.filter_mode.LINEAR)
if filter_done_event:
iter_stream.wait_for_event(filter_done_event)
width = np.float32((stop-start) / cps_per_block) width = np.float32((stop-start) / cps_per_block)
packer_fun(d_infos, d_genome_times, d_genome_knots, packer_fun(d_infos, d_genome_times, d_genome_knots,
np.float32(start), width, d_seeds, np.float32(start), width, d_seeds,