mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 03:30:05 -05:00
Add SplEval.insert_knot()
This commit is contained in:
parent
de56383a61
commit
808bd66138
@ -14,6 +14,10 @@ class SplEval(object):
|
||||
_deriv = np.matrix(np.diag([3,2,1], 1))
|
||||
|
||||
def __init__(self, knots, v0=None, v1=None):
|
||||
self.knots = self.normalize(knots, v0, v1)
|
||||
|
||||
@staticmethod
|
||||
def normalize(knots, v0=None, v1=None):
|
||||
if isinstance(knots, (int, float)):
|
||||
knots = [0.0, knots, 1.0, knots]
|
||||
elif not np.all(np.diff(np.float32(np.asarray(knots))[::2]) > 0):
|
||||
@ -31,8 +35,9 @@ class SplEval(object):
|
||||
v1 = (knots[-1] - knots[-3]) / float(knots[-2] - knots[-4])
|
||||
knots.extend([3, knots[-3] + (3 - knots[-4]) * v1])
|
||||
|
||||
self.knots = np.zeros((2, len(knots)/2))
|
||||
self.knots.T.flat[:] = knots
|
||||
knotarray = np.zeros((2, len(knots)/2))
|
||||
knotarray.T.flat[:] = knots
|
||||
return knotarray
|
||||
|
||||
def find_knots(self, itime):
|
||||
idx = np.searchsorted(self.knots[0], itime) - 2
|
||||
@ -89,6 +94,10 @@ class SplEval(object):
|
||||
return list(self.knots.T.flat)[2:-2]
|
||||
return list(self.knots.T.flat)
|
||||
|
||||
def insert_knot(self, t, v):
|
||||
knots = list(sum(sorted(zip(*self.knots) + [(t,v)]), ()))
|
||||
self.knots = self.normalize(knots, self(0, 1), self(1, 1))
|
||||
|
||||
def palette_decode(datastrs):
|
||||
"""
|
||||
Decode a palette (stored as a list suitable for JSON packing) into a
|
||||
|
Loading…
Reference in New Issue
Block a user