Add derivative support to SplWrap.

This commit is contained in:
Steven Robertson 2011-10-28 18:51:33 -04:00
parent a2c4c90cb2
commit 28e73d08ee

View File

@ -1,10 +1,12 @@
import json
import base64
import numpy as np
import scipy.interpolate
from cuburn import affine
class SplEval(object):
_mat = np.matrix([[1.,-2, 1, 0], [2,-3, 0, 1],
[1,-1, 0, 0], [-2, 3, 0, 0]])
_deriv = np.matrix(np.diag([3,2,1], 1))
def __init__(self, knots):
# If a single value is passed, normalize to a constant set of knots
if isinstance(knots, (int, float)):
@ -28,13 +30,9 @@ class SplEval(object):
self.knots = np.zeros((2, len(knots)/2))
self.knots.T.flat[:] = knots
def __call__(self, itime):
try:
return np.asarray(map(self, itime))
except:
pass
idx = np.searchsorted(self.knots[0], itime) - 2
idx = max(0, min(len(self.knots[0]) - 4, idx))
def find_knots(self, itime):
idx = np.searchsorted(self.knots[0][1:-1], itime)
idx = min(idx, len(self.knots[0]) - 4)
times = self.knots[0][idx:idx+4]
vals = self.knots[1][idx:idx+4]
@ -43,20 +41,19 @@ class SplEval(object):
times = times - times[1]
t = t / times[2]
times = times / times[2]
return times, vals, t
return self._interp(times, vals, t)
@staticmethod
def _interp(times, vals, t):
t2 = t * t
t3 = t * t2
def __call__(self, itime, deriv=0):
times, vals, t = self.find_knots(itime)
m1 = (vals[2] - vals[0]) / (1.0 - times[0])
m2 = (vals[3] - vals[1]) / times[3]
r = ( m1 * (t3 - 2*t2 + t) + vals[1] * (2*t3 - 3*t2 + 1)
+ m2 * (t3 - t2) + vals[2] * (-2*t3 + 3*t2) )
return r
mat = self._mat
if deriv:
mat *= self._deriv ** (deriv+1)
val = [m1, vals[1], m2, vals[2]] * mat * np.array([[t**3, t**2, t, 1]]).T
return val[0,0]
def __str__(self):
return '[%g:%g]' % (self(0), self(1))