From 808bd66138b1eede6c15c8b3ef91925329c00439 Mon Sep 17 00:00:00 2001 From: Steven Robertson Date: Wed, 28 Dec 2011 15:39:17 -0500 Subject: [PATCH] Add SplEval.insert_knot() --- cuburn/genome.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/cuburn/genome.py b/cuburn/genome.py index 3280c60..d734b0d 100644 --- a/cuburn/genome.py +++ b/cuburn/genome.py @@ -14,6 +14,10 @@ class SplEval(object): _deriv = np.matrix(np.diag([3,2,1], 1)) def __init__(self, knots, v0=None, v1=None): + self.knots = self.normalize(knots, v0, v1) + + @staticmethod + def normalize(knots, v0=None, v1=None): if isinstance(knots, (int, float)): knots = [0.0, knots, 1.0, knots] elif not np.all(np.diff(np.float32(np.asarray(knots))[::2]) > 0): @@ -31,8 +35,9 @@ class SplEval(object): v1 = (knots[-1] - knots[-3]) / float(knots[-2] - knots[-4]) knots.extend([3, knots[-3] + (3 - knots[-4]) * v1]) - self.knots = np.zeros((2, len(knots)/2)) - self.knots.T.flat[:] = knots + knotarray = np.zeros((2, len(knots)/2)) + knotarray.T.flat[:] = knots + return knotarray def find_knots(self, itime): idx = np.searchsorted(self.knots[0], itime) - 2 @@ -89,6 +94,10 @@ class SplEval(object): return list(self.knots.T.flat)[2:-2] return list(self.knots.T.flat) + def insert_knot(self, t, v): + knots = list(sum(sorted(zip(*self.knots) + [(t,v)]), ())) + self.knots = self.normalize(knots, self(0, 1), self(1, 1)) + def palette_decode(datastrs): """ Decode a palette (stored as a list suitable for JSON packing) into a