mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 11:40:04 -05:00
Use functions for palette instead of silly objects
This commit is contained in:
parent
3b29bb2dc2
commit
529bf48982
@ -80,15 +80,28 @@ class SplEval(object):
|
|||||||
return self.knots[1][0]
|
return self.knots[1][0]
|
||||||
return list(self.knots.T.flat)
|
return list(self.knots.T.flat)
|
||||||
|
|
||||||
class Palette(object):
|
def palette_decode(datastrs):
|
||||||
def __init__(self, datastrs):
|
"""
|
||||||
if datastrs[0] != 'rgb8':
|
Decode a palette (stored as a list suitable for JSON packing) into a
|
||||||
raise NotImplementedError
|
palette. Internal palette format is simply as a (256,4) array of [0,1]
|
||||||
self.width = 256
|
RGBA floats.
|
||||||
raw = base64.b64decode(''.join(datastrs[1:]))
|
"""
|
||||||
pal = np.reshape(np.fromstring(raw, np.uint8), (256, 3))
|
if datastrs[0] != 'rgb8':
|
||||||
self.data = np.ones((256, 4), np.float32)
|
raise NotImplementedError
|
||||||
self.data[:,:3] = pal / 255.0
|
raw = base64.b64decode(''.join(datastrs[1:]))
|
||||||
|
pal = np.reshape(np.fromstring(raw, np.uint8), (256, 3))
|
||||||
|
data = np.ones((256, 4), np.float32)
|
||||||
|
data[:,:3] = pal / 255.0
|
||||||
|
return data
|
||||||
|
|
||||||
|
def palette_encode(data, format='rgb8'):
|
||||||
|
"""
|
||||||
|
Encode an internal-format palette to an external representation.
|
||||||
|
"""
|
||||||
|
if format != 'rgb8':
|
||||||
|
raise NotImplementedError
|
||||||
|
enc = base64.b64encode(np.uint8(data*255.0))
|
||||||
|
return ['rgb8'] + [enc[i:i+64] for i in range(0, len(enc), 64)]
|
||||||
|
|
||||||
class _AttrDict(dict):
|
class _AttrDict(dict):
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
@ -126,7 +139,7 @@ class Genome(_AttrDict):
|
|||||||
_AttrDict._wrap(v)
|
_AttrDict._wrap(v)
|
||||||
self[k] = v
|
self[k] = v
|
||||||
|
|
||||||
self.decoded_palettes = map(Palette, self.palettes)
|
self.decoded_palettes = map(palette_decode, self.palettes)
|
||||||
pal = self.color.palette_times
|
pal = self.color.palette_times
|
||||||
if isinstance(pal, basestring):
|
if isinstance(pal, basestring):
|
||||||
self.palette_times = [(0.0, int(pal)), (1.0, int(pal))]
|
self.palette_times = [(0.0, int(pal)), (1.0, int(pal))]
|
||||||
|
@ -251,7 +251,7 @@ class Renderer(object):
|
|||||||
palint_times.fill(1e10)
|
palint_times.fill(1e10)
|
||||||
palint_times[:len(ptimes)] = ptimes
|
palint_times[:len(ptimes)] = ptimes
|
||||||
d_palint_times = cuda.to_device(palint_times)
|
d_palint_times = cuda.to_device(palint_times)
|
||||||
pvals = [genome.decoded_palettes[i].data for i in pidxs]
|
pvals = [genome.decoded_palettes[i] for i in pidxs]
|
||||||
d_palint_vals = cuda.to_device(np.concatenate(pvals))
|
d_palint_vals = cuda.to_device(np.concatenate(pvals))
|
||||||
|
|
||||||
if self.acc_mode in ('deferred', 'atomic'):
|
if self.acc_mode in ('deferred', 'atomic'):
|
||||||
|
Loading…
Reference in New Issue
Block a user