mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 11:40:04 -05:00
Add derivative support to SplWrap.
This commit is contained in:
parent
a2c4c90cb2
commit
28e73d08ee
@ -1,10 +1,12 @@
|
||||
import json
|
||||
import base64
|
||||
import numpy as np
|
||||
import scipy.interpolate
|
||||
from cuburn import affine
|
||||
|
||||
class SplEval(object):
|
||||
_mat = np.matrix([[1.,-2, 1, 0], [2,-3, 0, 1],
|
||||
[1,-1, 0, 0], [-2, 3, 0, 0]])
|
||||
_deriv = np.matrix(np.diag([3,2,1], 1))
|
||||
|
||||
def __init__(self, knots):
|
||||
# If a single value is passed, normalize to a constant set of knots
|
||||
if isinstance(knots, (int, float)):
|
||||
@ -28,13 +30,9 @@ class SplEval(object):
|
||||
self.knots = np.zeros((2, len(knots)/2))
|
||||
self.knots.T.flat[:] = knots
|
||||
|
||||
def __call__(self, itime):
|
||||
try:
|
||||
return np.asarray(map(self, itime))
|
||||
except:
|
||||
pass
|
||||
idx = np.searchsorted(self.knots[0], itime) - 2
|
||||
idx = max(0, min(len(self.knots[0]) - 4, idx))
|
||||
def find_knots(self, itime):
|
||||
idx = np.searchsorted(self.knots[0][1:-1], itime)
|
||||
idx = min(idx, len(self.knots[0]) - 4)
|
||||
|
||||
times = self.knots[0][idx:idx+4]
|
||||
vals = self.knots[1][idx:idx+4]
|
||||
@ -43,20 +41,19 @@ class SplEval(object):
|
||||
times = times - times[1]
|
||||
t = t / times[2]
|
||||
times = times / times[2]
|
||||
return times, vals, t
|
||||
|
||||
return self._interp(times, vals, t)
|
||||
|
||||
@staticmethod
|
||||
def _interp(times, vals, t):
|
||||
t2 = t * t
|
||||
t3 = t * t2
|
||||
def __call__(self, itime, deriv=0):
|
||||
times, vals, t = self.find_knots(itime)
|
||||
|
||||
m1 = (vals[2] - vals[0]) / (1.0 - times[0])
|
||||
m2 = (vals[3] - vals[1]) / times[3]
|
||||
|
||||
r = ( m1 * (t3 - 2*t2 + t) + vals[1] * (2*t3 - 3*t2 + 1)
|
||||
+ m2 * (t3 - t2) + vals[2] * (-2*t3 + 3*t2) )
|
||||
return r
|
||||
mat = self._mat
|
||||
if deriv:
|
||||
mat *= self._deriv ** (deriv+1)
|
||||
val = [m1, vals[1], m2, vals[2]] * mat * np.array([[t**3, t**2, t, 1]]).T
|
||||
return val[0,0]
|
||||
|
||||
def __str__(self):
|
||||
return '[%g:%g]' % (self(0), self(1))
|
||||
|
Loading…
Reference in New Issue
Block a user