Add new palette modes; use 'yuv' by default.

This commit is contained in:
Steven Robertson 2011-12-23 09:50:03 -05:00
parent 693a7a6dc3
commit de56383a61
3 changed files with 143 additions and 38 deletions

View File

@ -250,34 +250,9 @@ void interp_{{tname}}(
{{endfor}}
}
__global__
void interp_palette_hsv_flat(mwc_st *rctxs,
        const float *times, const float4 *sources,
        float tstart, float tstep) {
    // Each thread produces one palette entry; blockIdx.x picks both the
    // interpolation time and the surface row that entry is written to.
    // No bounds check is performed, so the launch must match the surface
    // dimensions exactly.
    const int slot = blockDim.x * blockIdx.x + threadIdx.x;
    mwc_st rctx = rctxs[slot];

    const float when = tstart + blockIdx.x * tstep;
    const float4 rgba = interp_color_hsv(times, sources, when);

    // TODO: use YUV; pack Y at full precision, UV at quarter
    // Quantize each channel to 8 bits, adding a random offset before
    // truncation. The draws must stay in R, G, B order so the generator
    // sequence is unchanged.
    const uint32_t red   = min(255, (uint32_t) (rgba.x * 255.0f + 0.49f * mwc_next_11(rctx)));
    const uint32_t green = min(255, (uint32_t) (rgba.y * 255.0f + 0.49f * mwc_next_11(rctx)));
    const uint32_t blue  = min(255, (uint32_t) (rgba.z * 255.0f + 0.49f * mwc_next_11(rctx)));

    // Pack into two 32-bit words: .x holds green (shifted by 18) and blue,
    // .y holds red (shifted by 4) plus an always-set bit 22.
    // NOTE(review): the meaning of bit 22 is defined by whatever reads
    // `flatpal`; confirm against the accumulator's unpacking code.
    uint2 packed;
    packed.x = (green << 18) | blue;
    packed.y = (1 << 22) | (red << 4);

    // surf2Dwrite takes its x coordinate in bytes; a uint2 is 8 bytes.
    surf2Dwrite(packed, flatpal, 8 * threadIdx.x, blockIdx.x);

    // Persist the advanced generator state for the next launch.
    rctxs[slot] = rctx;
}
""")
_decls = Template(r"""
surface<void, cudaSurfaceType2D> flatpal;
typedef struct {
{{for name in packed}}
float {{'_'.join(name)}};
@ -334,18 +309,83 @@ void test_cr(const float *times, const float *knots, const float *t, float *r) {
int i = threadIdx.x + blockDim.x * blockIdx.x;
r[i] = catmull_rom(times, knots, t[i]);
}
""")
class Palette(HunkOCode):
# The JPEG YUV full-range matrix, without bias into the positive regime.
# This assumes input color space is CIERGB D65, encoded with gamma 2.2.
# Note that some interpolated colors may exceed the sRGB and YUV gamuts.
YUV = np.matrix([[ 0.299, 0.587, 0.114],
[-0.168736, -0.331264, 0.5],
[ 0.5, -0.418688, -0.081312]])
def __init__(self, interp_mode):
    """
    Select the palette interpolation mode and render the device code for it.

    ``interp_mode`` must be one of the names listed in ``self.modes``; the
    rendered template text is stored on ``self.defs``.
    """
    assert interp_mode in self.modes
    self.mode = interp_mode
    # Specialize the code template for the chosen mode.
    self.defs = self._defs.substitute(mode=self.mode)
def prepare(self, palettes):
"""
Produce palettes suitable for uploading to the device. Returns an
array of palettes in the same size and shape as the input.
This function will never modify its argument, but may return it
unmodified for certain interpolation modes.
"""
if self.mode == 'yuvpolar':
ys, uvrs, uvts, alphas = zip(*map(self.rgbtoyuvpolar, palettes))
# Center all medians as closely to 0 as possible
means = np.mean(uvts, axis=1)
newmeans = (means + np.pi) % (2 * np.pi) - np.pi
uvts = (newmeans - means).reshape((-1, 1)) + uvts
zipped = zip(ys, uvrs, uvts, alphas)
return np.array(zipped, dtype='f4').transpose((0, 2, 1))
return palettes
@classmethod
def rgbtoyuvpolar(cls, pal):
# TODO: premultiply alpha or some nonsense like that?
y, u, v = np.array(cls.YUV * pal.T[:3])
uvr = np.hypot(u, v)
uvt = np.arctan2(v, u)
cls.monotonify(uvt)
return y, uvr, uvt, pal.T[3]
@classmethod
def yuvpolartorgb(cls, y, uvr, uvt, a):
u = uvr * np.cos(uvt)
v = uvr * np.sin(uvt)
r, g, b = np.array(cls.YUV.I * np.array([y, u, v]))
# Ensure Fortran order so that the memory gets laid out correctly
return np.array([r, g, b, a], order='F').T
@staticmethod
def monotonify(uvt):
"""Eliminate sign-flips in an array of radian angles (in-place)."""
diff = np.diff(uvt)
for i in np.nonzero(np.abs(diff) > np.pi)[0]:
uvt[i:] -= np.sign(diff[i]) * 2 * np.pi
modes = ['hsv', 'yuv', 'yuvpolar']
decls = "surface<void, cudaSurfaceType2D> flatpal;\n"
_defs = Template(r"""
__device__
float4 interp_color_hsv(const float *times, const float4 *sources, float time) {
float4 interp_color(const float *times, const float4 *sources, float time) {
int idx = fmaxf(bitwise_binsearch(times, time) + 1, 1);
float lf = (times[idx] - time) / (times[idx] - times[idx-1]);
float rf = 1.0f - lf;
float4 left = sources[blockDim.x * (idx - 1) + threadIdx.x];
float4 right = sources[blockDim.x * (idx) + threadIdx.x];
float3 rgb;
float3 lhsv = rgb2hsv(make_float3(left.x, left.y, left.z));
float3 rhsv = rgb2hsv(make_float3(right.x, right.y, right.z));
float3 l3 = make_float3(left.x, left.y, left.z);
float3 r3 = make_float3(right.x, right.y, right.z);
{{if mode == 'hsv'}}
float3 lhsv = rgb2hsv(l3);
float3 rhsv = rgb2hsv(r3);
if (fabs(lhsv.x - rhsv.x) > 3.0f)
if (lhsv.x < rhsv.x)
@ -363,16 +403,36 @@ float4 interp_color_hsv(const float *times, const float4 *sources, float time) {
if (hsv.x < 0.0f)
hsv.x += 6.0f;
float3 rgb = hsv2rgb(hsv);
rgb = hsv2rgb(hsv);
{{elif mode.startswith('yuv')}}
float3 yuv;
{{if mode == 'yuv'}}
float3 lyuv = rgb2yuv(l3);
float3 ryuv = rgb2yuv(r3);
yuv.x = lyuv.x * lf + ryuv.x * rf;
yuv.y = lyuv.y * lf + ryuv.y * rf;
yuv.z = lyuv.z * lf + ryuv.z * rf;
{{elif mode == 'yuvpolar'}}
yuv.x = l3.x * lf + r3.x * rf;
float radius = l3.y * lf + r3.y * rf;
float angle = l3.z * lf + r3.z * rf;
yuv.y = radius * cosf(angle);
yuv.z = radius * sinf(angle);
{{endif}}
rgb = yuv2rgb(yuv);
{{endif}}
return make_float4(rgb.x, rgb.y, rgb.z, left.w * lf + right.w * rf);
}
__global__
void interp_palette_hsv(uchar4 *outs,
void interp_palette(uchar4 *outs,
const float *times, const float4 *sources,
float tstart, float tstep) {
float time = tstart + blockIdx.x * tstep;
float4 rgba = interp_color_hsv(times, sources, time);
float4 rgba = interp_color(times, sources, time);
uchar4 out;
out.x = rgba.x * 255.0f;
@ -382,5 +442,27 @@ void interp_palette_hsv(uchar4 *outs,
outs[blockDim.x * blockIdx.x + threadIdx.x] = out;
}
__global__
void interp_palette_flat(mwc_st *rctxs,
        const float *times, const float4 *sources,
        float tstart, float tstep) {
    // Each thread produces one palette entry; blockIdx.x picks both the
    // interpolation time and the surface row that entry is written to.
    // No bounds check is performed, so the launch must match the surface
    // dimensions exactly.
    const int slot = blockDim.x * blockIdx.x + threadIdx.x;
    mwc_st rctx = rctxs[slot];

    const float when = tstart + blockIdx.x * tstep;
    const float4 rgba = interp_color(times, sources, when);

    // TODO: use YUV; pack Y at full precision, UV at quarter
    // Quantize each channel to 8 bits, adding a random offset before
    // truncation. The draws must stay in R, G, B order so the generator
    // sequence is unchanged.
    const uint32_t red   = min(255, (uint32_t) (rgba.x * 255.0f + 0.49f * mwc_next_11(rctx)));
    const uint32_t green = min(255, (uint32_t) (rgba.y * 255.0f + 0.49f * mwc_next_11(rctx)));
    const uint32_t blue  = min(255, (uint32_t) (rgba.z * 255.0f + 0.49f * mwc_next_11(rctx)));

    // Pack into two 32-bit words: .x holds green (shifted by 18) and blue,
    // .y holds red (shifted by 4) plus an always-set bit 22.
    // NOTE(review): the meaning of bit 22 is defined by whatever reads
    // `flatpal`; confirm against the accumulator's unpacking code.
    uint2 packed;
    packed.x = (green << 18) | blue;
    packed.y = (1 << 22) | (red << 4);

    // surf2Dwrite takes its x coordinate in bytes; a uint2 is 8 bytes.
    surf2Dwrite(packed, flatpal, 8 * threadIdx.x, blockIdx.x);

    // Persist the advanced generator state for the next launch.
    rctxs[slot] = rctx;
}
""")

View File

@ -148,6 +148,25 @@ void write_half(float &xy, float x, float y, float den) {
: "=f"(xy) : "f"(x), "f"(y), "f"(den));
}
/* RGB -> YUV using the JPEG full-range coefficients. Unlike JPEG proper, no
 * offset is added to U and V, so both chroma channels are centered on zero
 * and may be negative. Returned as (Y, U, V) in (x, y, z). */
__device__
float3 rgb2yuv(float3 rgb) {
    float3 yuv;
    yuv.x =  0.299f * rgb.x + 0.587f * rgb.y + 0.114f * rgb.z;
    yuv.y = -0.168736f * rgb.x - 0.331264f * rgb.y + 0.5f * rgb.z;
    yuv.z =  0.5f * rgb.x - 0.418688f * rgb.y - 0.081312f * rgb.z;
    return yuv;
}
__device__
float3 yuv2rgb(float3 yuv) {
    // Takes (Y, U, V) in (x, y, z) with zero-centered chroma and returns
    // (R, G, B). Red depends only on Y and V, blue only on Y and U.
    const float red   = yuv.x + 1.402f * yuv.z;
    const float green = yuv.x - 0.34414f * yuv.y - 0.71414f * yuv.z;
    const float blue  = yuv.x + 1.772f * yuv.y;
    return make_float3(red, green, blue);
}
__device__
float3 rgb2hsv(float3 rgb) {
float M = fmaxf(fmaxf(rgb.x, rgb.y), rgb.z);

View File

@ -17,7 +17,7 @@ import pycuda.tools
import cuburn.genome
from cuburn import affine
from cuburn.code import util, mwc, iter, filtering, sort
from cuburn.code import util, mwc, iter, interp, filtering, sort
RenderedImage = namedtuple('RenderedImage', 'buf idx gpu_time')
Dimensions = namedtuple('Dimensions', 'w h aw ah astride')
@ -48,6 +48,9 @@ class Renderer(object):
# pre-dithered surfaces.
palette_height = 64
# Palette color interpolation mode (see code.interp.Palette)
palette_interp_mode = 'yuv'
# Maximum width of DE and other spatial filters, and thus in turn the
# amount of padding applied. Note that, for now, this must not be changed!
# The filtering code makes deep assumptions about this value.
@ -63,7 +66,7 @@ class Renderer(object):
keep = False
def __init__(self):
self._iter = self.src = self.cubin = self.mod = None
self._iter = self.pal = self.src = self.cubin = self.mod = None
# Ensure class options don't get contaminated on an instance
self.cmp_options = list(self.cmp_options)
@ -85,8 +88,9 @@ class Renderer(object):
self._iter = iter.IterCode(self, genome)
self._iter.packer.finalize()
self.pal = interp.Palette(self.palette_interp_mode)
self.src = util.assemble_code(util.BaseCode, mwc.MWC, self._iter.packer,
self._iter)
self.pal, self._iter)
with open(os.path.join(tempfile.gettempdir(), 'kernel.cu'), 'w') as fp:
fp.write(self.src)
self.cubin = pycuda.compiler.compile(
@ -251,11 +255,11 @@ class Renderer(object):
palint_times.fill(1e10)
palint_times[:len(ptimes)] = ptimes
d_palint_times = cuda.to_device(palint_times)
pvals = [genome.decoded_palettes[i] for i in pidxs]
pvals = self.pal.prepare([genome.decoded_palettes[i] for i in pidxs])
d_palint_vals = cuda.to_device(np.concatenate(pvals))
if self.acc_mode in ('deferred', 'atomic'):
palette_fun = self.mod.get_function("interp_palette_hsv_flat")
palette_fun = self.mod.get_function("interp_palette_flat")
dsc = argset(cuda.ArrayDescriptor3D(), height=self.palette_height,
width=256, depth=0, format=cuda.array_format.SIGNED_INT32,
num_channels=2, flags=cuda.array3d_flags.SURFACE_LDST)
@ -264,7 +268,7 @@ class Renderer(object):
tref = self.mod.get_surfref('flatpal')
tref.set_array(palarray, 0)
else:
palette_fun = self.mod.get_function("interp_palette_hsv")
palette_fun = self.mod.get_function("interp_palette")
dsc = argset(cuda.ArrayDescriptor(), height=self.palette_height,
width=256, format=cuda.array_format.UNSIGNED_INT8,
num_channels=4)