Add SplEval.insert_knot()

This commit is contained in:
Steven Robertson 2011-12-28 15:39:17 -05:00
parent de56383a61
commit 808bd66138

View File

@ -14,6 +14,10 @@ class SplEval(object):
_deriv = np.matrix(np.diag([3,2,1], 1)) _deriv = np.matrix(np.diag([3,2,1], 1))
def __init__(self, knots, v0=None, v1=None): def __init__(self, knots, v0=None, v1=None):
self.knots = self.normalize(knots, v0, v1)
@staticmethod
def normalize(knots, v0=None, v1=None):
if isinstance(knots, (int, float)): if isinstance(knots, (int, float)):
knots = [0.0, knots, 1.0, knots] knots = [0.0, knots, 1.0, knots]
elif not np.all(np.diff(np.float32(np.asarray(knots))[::2]) > 0): elif not np.all(np.diff(np.float32(np.asarray(knots))[::2]) > 0):
@ -31,8 +35,9 @@ class SplEval(object):
v1 = (knots[-1] - knots[-3]) / float(knots[-2] - knots[-4]) v1 = (knots[-1] - knots[-3]) / float(knots[-2] - knots[-4])
knots.extend([3, knots[-3] + (3 - knots[-4]) * v1]) knots.extend([3, knots[-3] + (3 - knots[-4]) * v1])
self.knots = np.zeros((2, len(knots)/2)) knotarray = np.zeros((2, len(knots)/2))
self.knots.T.flat[:] = knots knotarray.T.flat[:] = knots
return knotarray
def find_knots(self, itime): def find_knots(self, itime):
idx = np.searchsorted(self.knots[0], itime) - 2 idx = np.searchsorted(self.knots[0], itime) - 2
@ -89,6 +94,10 @@ class SplEval(object):
return list(self.knots.T.flat)[2:-2] return list(self.knots.T.flat)[2:-2]
return list(self.knots.T.flat) return list(self.knots.T.flat)
def insert_knot(self, t, v):
knots = list(sum(sorted(zip(*self.knots) + [(t,v)]), ()))
self.knots = self.normalize(knots, self(0, 1), self(1, 1))
def palette_decode(datastrs): def palette_decode(datastrs):
""" """
Decode a palette (stored as a list suitable for JSON packing) into a Decode a palette (stored as a list suitable for JSON packing) into a