Flat (pre-packed int) palettes in deferred mode.

This commit is contained in:
Steven Robertson 2011-12-08 20:44:02 -05:00
parent b76208078f
commit d3ee6f36c2
4 changed files with 147 additions and 98 deletions

View File

@ -254,10 +254,34 @@ void interp_{{tname}}({{tname}}* out, float *times, float *knots,
{{endfor}}
}
__global__
void interp_palette_hsv_flat(mwc_st *rctxs,
        const float *times, const float4 *sources,
        float tstart, float tstep) {
    // Interpolate one palette entry per thread (one block per palette row),
    // apply random dither, and pack the result as a pre-shifted integer
    // pair into the 'flatpal' surface:
    //   out.x = (1 << 22) | r << 4      (high bits appear to carry a
    //   out.y = g << 18   | b            sample count of one — consumed
    //                                    by the write_shmem accumulator)
    int gid = blockIdx.x * blockDim.x + threadIdx.x;
    mwc_st rctx = rctxs[gid];

    float time = tstart + blockIdx.x * tstep;
    float4 rgba = interp_color_hsv(times, sources, time);

    // TODO: use YUV; pack Y at full precision, UV at quarter
    uint2 out;
    // Clamp at zero BEFORE the unsigned cast: the dither term
    // 0.49f * mwc_next_11() can be negative (mwc_next_11 presumably
    // returns values in [-1, 1] — TODO confirm), and converting a
    // negative float to uint32_t is undefined behavior. Without the
    // clamp a near-zero channel could wrap to a huge value and then be
    // capped to 255 by min(), turning near-black into full intensity.
    uint32_t r = min(255, (uint32_t) fmaxf(0.0f,
            rgba.x * 255.0f + 0.49f * mwc_next_11(rctx)));
    uint32_t g = min(255, (uint32_t) fmaxf(0.0f,
            rgba.y * 255.0f + 0.49f * mwc_next_11(rctx)));
    uint32_t b = min(255, (uint32_t) fmaxf(0.0f,
            rgba.z * 255.0f + 0.49f * mwc_next_11(rctx)));
    out.x = (1 << 22) | (r << 4);
    out.y = (g << 18) | b;
    // uint2 write: byte x-coordinate is 8 * element index.
    surf2Dwrite(out, flatpal, 8 * threadIdx.x, blockIdx.x);
    rctxs[gid] = rctx;
}
""")
_decls = Template(r"""
surface<void, cudaSurfaceType2D> flatpal;
typedef struct {
{{for name in packed}}
float {{'_'.join(name)}};
@ -314,18 +338,15 @@ void test_cr(const float *times, const float *knots, const float *t, float *r) {
r[i] = catmull_rom(times, knots, t[i]);
}
__global__
void interp_palette_hsv(uchar4 *outs, const float *times, const float4 *sources,
float tstart, float tstep) {
float time = tstart + blockIdx.x * tstep;
__device__
float4 interp_color_hsv(const float *times, const float4 *sources, float time) {
int idx = fmaxf(bitwise_binsearch(times, time) + 1, 1);
float lf = (times[idx] - time) / (times[idx] - times[idx-1]);
float rf = 1.0f - lf;
float4 left = sources[blockDim.x * (idx - 1) + threadIdx.x];
float4 right = sources[blockDim.x * (idx) + threadIdx.x];
float lf = (times[idx] - time) / (times[idx] - times[idx-1]);
float rf = 1.0f - lf;
float3 lhsv = rgb2hsv(make_float3(left.x, left.y, left.z));
float3 rhsv = rgb2hsv(make_float3(right.x, right.y, right.z));
@ -346,13 +367,23 @@ void interp_palette_hsv(uchar4 *outs, const float *times, const float4 *sources,
hsv.x += 6.0f;
float3 rgb = hsv2rgb(hsv);
return make_float4(rgb.x, rgb.y, rgb.z, left.w * lf + right.w * rf);
}
__global__
void interp_palette_hsv(uchar4 *outs,
        const float *times, const float4 *sources,
        float tstart, float tstep) {
    // Interpolate one palette entry per thread (one block per palette
    // row) and store it as plain RGBA bytes. NOTE(review): the original
    // span interleaved stale pre-refactor lines referencing 'rgb' and
    // 'left', which are not in scope here (they live inside the extracted
    // interp_color_hsv helper); only the coherent post-refactor body is
    // kept.
    float time = tstart + blockIdx.x * tstep;
    float4 rgba = interp_color_hsv(times, sources, time);
    uchar4 out;
    out.x = rgba.x * 255.0f;
    out.y = rgba.y * 255.0f;
    out.z = rgba.z * 255.0f;
    out.w = rgba.w * 255.0f;
    outs[blockDim.x * blockIdx.x + threadIdx.x] = out;
}
""")

View File

@ -385,22 +385,6 @@ void iter(
#define SHAB 12
#define SHAW (1<<SHAB)
// Read from the shm accumulators and write to the global ones.
__device__
void write_shmem_helper(
    float4 *acc,
    const int glo_idx,
    const uint32_t dr,
    const uint32_t gb
) {
    // Unpack two halfword fields from each packed accumulator word and
    // fold them into the global float4 accumulator at glo_idx. The low
    // halfwords are fixed-point with a scale of 127 (presumably matching
    // the packing done at accumulation time — verify against caller);
    // the high halfword of 'dr' is an unscaled count.
    float4 dst = acc[glo_idx];
    float red   = (dr & 0xffff) / 127.0f;
    float green = (gb & 0xffff) / 127.0f;
    float blue  = (gb >> 16) / 127.0f;
    dst.x += red;
    dst.y += green;
    dst.z += blue;
    dst.w += dr >> 16;
    acc[glo_idx] = dst;
}
// Read the point log, accumulate in shared memory, and write the results.
// This kernel is to be launched with one block for every 4,096 addresses to
// be processed, and will handle those addresses.
@ -440,8 +424,10 @@ write_shmem(
s_acc_gb[tid+3*BS] = 0;
__syncthreads();
// This predicate is used for the horrible monkey-patching magic.
asm volatile(".reg .pred p; setp.lt.u32 p, %0, 42;" :: "r"(s_acc_dr[0]));
// This predicate is used for the horrible monkey-patching magic. Second
// variable is just to shut the compiler up.
asm volatile(".reg .pred p; setp.lt.u32 p, %0, 42;"
:: "r"(s_acc_dr[0]), "r"(s_acc_gb[0]));
// log_bounds[] holds inclusive prefix sums, so that log_bounds[0] is the
// largest index with radix 0, and so on.
@ -453,7 +439,7 @@ write_shmem(
else idx_lo = 0;
int idx_hi = (log_bounds[lb_idx_hi] & ~(BS - 1)) + BS;
float rnrounds = 1.0f / (idx_hi - idx_lo);
float rnrounds = {{'%d.0f' % info.palette_height}} / (idx_hi - idx_lo);
float time = tid * rnrounds;
float time_step = BS * rnrounds;
@ -468,65 +454,57 @@ write_shmem(
bfe_decl(glob_addr, entry, SHAB, 12);
if (glob_addr != bid) continue;
// Shared memory address, pre-shifted
int shr_addr;
asm("bfi.b32 %0, %1, 0, 2, 12;" : "=r"(shr_addr) : "r"(entry));
bfe_decl(color, entry, 24, 8);
float colorf = color / 255.0f;
float4 outcol = tex2D(palTex, colorf, time);
// TODO: change texture sampler to return shorts and avoid this
uint32_t r = 127.0f * outcol.x;
uint32_t g = 127.0f * outcol.y;
uint32_t b = 127.0f * outcol.z;
uint32_t dr = r + 0x10000, gb = g + (b << 16);
asm volatile ({{crep("""
{
.reg .pred q;
.reg .u32 d, r, g, b, dr, gb, drw, gbw, off;
.reg .u32 shoff, color, time, d, r, g, b, hi, lo, his, los, hiw, low;
.reg .u64 ptr;
.reg .f32 rf, gf, bf, df;
bfi.b32 shoff, %0, 0, 2, 12;
bfe.u32 color, %0, 24, 8;
shl.b32 color, color, 3;
cvt.rni.u32.f32 time, %1;
suld.b.2d.v2.b32.clamp {his, los}, [flatpal, {color, time}];
acc_write_start:
// This instruction will get replaced with a LDSLK that sets 'p'.
// The 0xffff is a signature to make sure we get the right instruction,
// and will get replaced with a 0-offset when patching.
ld.shared.volatile.u32 dr, [%2+0xffff];
@p ld.shared.volatile.u32 gb, [%2+0x4000];
@p add.u32 %0, dr, %0;
@p add.u32 %1, gb, %1;
// TODO: clever use of slct could remove an instruction?
@p setp.lo.u32 q, %0, 32768000;
@p selp.b32 drw, %0, 0, q;
@p selp.b32 gbw, %1, 0, q;
@p st.shared.volatile.u32 [%2+0x4000], gbw;
ld.shared.volatile.u32 low, [shoff+0xffff];
@p ld.shared.volatile.u32 hiw, [shoff+0x4000];
add.cc.u32 lo, los, low;
addc.u32 hi, his, hiw;
setp.lo.u32 q, hi, (1023 << 22);
selp.b32 hiw, hi, 0, q;
selp.b32 low, lo, 0, q;
@p st.shared.volatile.u32 [shoff+0x4000], hiw;
// This instruction will get replaced with an STSUL
@p st.shared.volatile.u32 [%2+0xffff], drw;
@p st.shared.volatile.u32 [shoff+0xffff], low;
@!p bra acc_write_start;
@q bra oflow_write_end;
shl.b32 %2, %2, 2;
cvt.u64.u32 ptr, %2;
add.u64 ptr, ptr, %3;
and.b32 r, %0, 0xffff;
shr.b32 g, %1, 16;
and.b32 b, %1, 0xffff;
shl.b32 shoff, shoff, 2;
cvt.u64.u32 ptr, shoff;
add.u64 ptr, ptr, %2;
bfe.u32 r, hi, 4, 18;
bfe.u32 g, lo, 18, 14;
bfi.b32 g, hi, g, 14, 4;
and.b32 b, lo, ((1<<18)-1);
cvt.rn.f32.u32 rf, r;
cvt.rn.f32.u32 gf, g;
cvt.rn.f32.u32 bf, b;
mul.ftz.f32 rf, rf, 0.007874015748031496;
mul.ftz.f32 gf, gf, 0.007874015748031496;
mul.ftz.f32 bf, bf, 0.007874015748031496;
red.add.f32 [ptr], 500.0;
red.add.f32 [ptr+4], bf;
red.add.f32 [ptr+8], gf;
red.add.f32 [ptr+12], rf;
mul.ftz.f32 rf, rf, (1.0/255.0);
mul.ftz.f32 gf, gf, (1.0/255.0);
mul.ftz.f32 bf, bf, (1.0/255.0);
red.add.f32 [ptr], rf;
red.add.f32 [ptr+4], gf;
red.add.f32 [ptr+8], bf;
red.add.f32 [ptr+12], 1023.0;
oflow_write_end:
}
""")}} :: "r"(dr), "r"(gb), "r"(shr_addr), "l"(glo_ptr));
""")}} :: "r"(entry), "f"(time), "l"(glo_ptr));
// TODO: go through the pain of manual address calculation for global ptr
time += time_step;
}
@ -534,7 +512,25 @@ oflow_write_end:
__syncthreads();
int idx = tid;
for (int i = 0; i < (SHAW / BS); i++) {
write_shmem_helper(acc, glo_base + idx, s_acc_dr[idx], s_acc_gb[idx]);
int d, r, g, b;
float4 pix = acc[glo_base + idx];
asm({{crep("""
{
.reg .u32 hi, lo;
ld.shared.u32 lo, [%4];
ld.shared.u32 hi, [%4+0x4000];
shr.u32 %0, hi, 22;
bfe.u32 %1, hi, 4, 18;
bfe.u32 %2, lo, 18, 14;
bfi.b32 %2, hi, %2, 14, 4;
and.b32 %3, lo, ((1<<18)-1);
}
""")}} : "=r"(d), "=r"(r), "=r"(g), "=r"(b) : "r"(idx*4));
pix.x += r / 255.0f;
pix.y += g / 255.0f;
pix.z += b / 255.0f;
pix.w += d;
acc[glo_base + idx] = pix;
idx += BS;
}
}

View File

@ -108,8 +108,10 @@ class RenderInfo(object):
# that palette-from-texture is enabled). For most genomes, this doesn't
# need to be very large at all. However, since only an easily-cached
# fraction of this will be accessed per SM, larger values shouldn't hurt
# performance too much. Power-of-two, please.
palette_height = 16
# performance too much. When using deferred accumulation, increasing this
# value increases the number of uniquely-dithered samples, which is nice.
# Power-of-two, please.
palette_height = 64
# Maximum width of DE and other spatial filters, and thus in turn the
# amount of padding applied. Note that, for now, this must not be changed!

View File

@ -113,7 +113,6 @@ class Renderer(object):
reset_rb_fun = self.mod.get_function("reset_rb")
packer_fun = self.mod.get_function("interp_iter_params")
palette_fun = self.mod.get_function("interp_palette_hsv")
iter_fun = self.mod.get_function("iter")
write_fun = self.mod.get_function("write_shmem")
@ -134,7 +133,8 @@ class Renderer(object):
event_b = None
nbins = info.acc_height * info.acc_stride
d_accum = cuda.mem_alloc(16 * nbins)
# Extra padding in accum helps with write_shmem overruns
d_accum = cuda.mem_alloc(16 * nbins + (1<<16))
d_out = cuda.mem_alloc(16 * nbins)
acc_size = np.array([info.acc_width, info.acc_height, info.acc_stride])
@ -168,7 +168,8 @@ class Renderer(object):
reset_rb_fun(np.int32(rb_size), block=(1,1,1))
d_points = cuda.mem_alloc(nslots * 16)
seeds = mwc.MWC.make_seeds(nslots)
# We may add extra seeds to simplify palette dithering.
seeds = mwc.MWC.make_seeds(max(nslots, 256 * info.palette_height))
d_seeds = cuda.to_device(seeds)
# We used to auto-calculate this to a multiple of the number of SMs on
@ -191,13 +192,34 @@ class Renderer(object):
d_palint_times = cuda.to_device(palint_times)
d_palint_vals = cuda.to_device(
np.concatenate(map(info.db.palettes.get, pals[1::2])))
if info.acc_mode == 'deferred':
palette_fun = self.mod.get_function("interp_palette_hsv_flat")
dsc = cuda.ArrayDescriptor3D()
dsc.height = info.palette_height
dsc.width = 256
dsc.depth = 0
dsc.format = cuda.array_format.SIGNED_INT32
dsc.num_channels = 2
dsc.flags = cuda.array3d_flags.SURFACE_LDST
palarray = cuda.Array(dsc)
tref = self.mod.get_surfref('flatpal')
tref.set_array(palarray, 0)
else:
palette_fun = self.mod.get_function("interp_palette_hsv")
dsc = cuda.ArrayDescriptor()
dsc.height = info.palette_height
dsc.width = 256
dsc.format = cuda.array_format.UNSIGNED_INT8
dsc.num_channels = 4
d_palmem = cuda.mem_alloc(256 * info.palette_height * 4)
pal_array_info = cuda.ArrayDescriptor()
pal_array_info.height = info.palette_height
pal_array_info.width = 256
pal_array_info.array_format = cuda.array_format.UNSIGNED_INT8
pal_array_info.num_channels = 4
tref = self.mod.get_texref('palTex')
tref.set_address_2d(d_palmem, dsc, 1024)
tref.set_format(cuda.array_format.UNSIGNED_INT8, 4)
tref.set_flags(cuda.TRSF_NORMALIZED_COORDINATES)
tref.set_filter_mode(cuda.filter_mode.LINEAR)
h_out_a = cuda.pagelocked_empty((info.acc_height, info.acc_stride, 4),
np.float32)
@ -206,18 +228,17 @@ class Renderer(object):
last_idx = None
for idx, start, stop in times:
width = np.float32((stop-start) / info.palette_height)
palette_fun(d_palmem, d_palint_times, d_palint_vals,
np.float32(start), width,
twidth = np.float32((stop-start) / info.palette_height)
if info.acc_mode == 'deferred':
palette_fun(d_seeds, d_palint_times, d_palint_vals,
np.float32(start), twidth,
block=(256,1,1), grid=(info.palette_height,1),
stream=write_stream)
else:
palette_fun(d_palmem, d_palint_times, d_palint_vals,
np.float32(start), twidth,
block=(256,1,1), grid=(info.palette_height,1),
stream=write_stream)
# TODO: do we need to do this each time in order to reset cache?
tref = self.mod.get_texref('palTex')
tref.set_address_2d(d_palmem, pal_array_info, 1024)
tref.set_format(cuda.array_format.UNSIGNED_INT8, 4)
tref.set_flags(cuda.TRSF_NORMALIZED_COORDINATES)
tref.set_filter_mode(cuda.filter_mode.LINEAR)
width = np.float32((stop-start) / ntemporal_samples)
packer_fun(d_infos, d_genome_times, d_genome_knots,
@ -251,12 +272,11 @@ class Renderer(object):
_sync_stream(iter_stream, write_stream)
write_fun(d_accum, d_log_sorted, sorter.dglobal,
block=(1024, 1, 1), grid=(nwriteblocks, 1),
texrefs=[tref], stream=write_stream)
stream=write_stream)
else:
iter_fun(np.uint64(d_accum), d_seeds, d_points, d_infos,
block=(32, self._iter.NTHREADS/32, 1),
grid=(ntemporal_samples, nrounds),
texrefs=[tref], stream=iter_stream)
grid=(ntemporal_samples, nrounds), stream=iter_stream)
util.BaseCode.fill_dptr(self.mod, d_out, 4 * nbins, filt_stream)
_sync_stream(filt_stream, write_stream)