mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 03:30:05 -05:00
Use functions for palette instead of silly objects
This commit is contained in:
parent
3b29bb2dc2
commit
529bf48982
@ -80,15 +80,28 @@ class SplEval(object):
|
||||
return self.knots[1][0]
|
||||
return list(self.knots.T.flat)
|
||||
|
||||
class Palette(object):
|
||||
def __init__(self, datastrs):
|
||||
if datastrs[0] != 'rgb8':
|
||||
raise NotImplementedError
|
||||
self.width = 256
|
||||
raw = base64.b64decode(''.join(datastrs[1:]))
|
||||
pal = np.reshape(np.fromstring(raw, np.uint8), (256, 3))
|
||||
self.data = np.ones((256, 4), np.float32)
|
||||
self.data[:,:3] = pal / 255.0
|
||||
def palette_decode(datastrs):
|
||||
"""
|
||||
Decode a palette (stored as a list suitable for JSON packing) into a
|
||||
palette. Internal palette format is simply as a (256,4) array of [0,1]
|
||||
RGBA floats.
|
||||
"""
|
||||
if datastrs[0] != 'rgb8':
|
||||
raise NotImplementedError
|
||||
raw = base64.b64decode(''.join(datastrs[1:]))
|
||||
pal = np.reshape(np.fromstring(raw, np.uint8), (256, 3))
|
||||
data = np.ones((256, 4), np.float32)
|
||||
data[:,:3] = pal / 255.0
|
||||
return data
|
||||
|
||||
def palette_encode(data, format='rgb8'):
|
||||
"""
|
||||
Encode an internal-format palette to an external representation.
|
||||
"""
|
||||
if format != 'rgb8':
|
||||
raise NotImplementedError
|
||||
enc = base64.b64encode(np.uint8(data*255.0))
|
||||
return ['rgb8'] + [enc[i:i+64] for i in range(0, len(enc), 64)]
|
||||
|
||||
class _AttrDict(dict):
|
||||
def __getattr__(self, name):
|
||||
@ -126,7 +139,7 @@ class Genome(_AttrDict):
|
||||
_AttrDict._wrap(v)
|
||||
self[k] = v
|
||||
|
||||
self.decoded_palettes = map(Palette, self.palettes)
|
||||
self.decoded_palettes = map(palette_decode, self.palettes)
|
||||
pal = self.color.palette_times
|
||||
if isinstance(pal, basestring):
|
||||
self.palette_times = [(0.0, int(pal)), (1.0, int(pal))]
|
||||
|
@ -251,7 +251,7 @@ class Renderer(object):
|
||||
palint_times.fill(1e10)
|
||||
palint_times[:len(ptimes)] = ptimes
|
||||
d_palint_times = cuda.to_device(palint_times)
|
||||
pvals = [genome.decoded_palettes[i].data for i in pidxs]
|
||||
pvals = [genome.decoded_palettes[i] for i in pidxs]
|
||||
d_palint_vals = cuda.to_device(np.concatenate(pvals))
|
||||
|
||||
if self.acc_mode in ('deferred', 'atomic'):
|
||||
|
Loading…
Reference in New Issue
Block a user