mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 11:40:04 -05:00
Fix spline derivative calculation.
This commit is contained in:
parent
bfff915b7e
commit
6204f36ebc
@ -39,33 +39,34 @@ class SplEval(object):
|
||||
# Normalize to [0,1]
|
||||
t = itime - times[1]
|
||||
times = times - times[1]
|
||||
t = t / times[2]
|
||||
times = times / times[2]
|
||||
return times, vals, t
|
||||
scale = 1 / times[2]
|
||||
t = t * scale
|
||||
times = times * scale
|
||||
return times, vals, t, scale
|
||||
|
||||
def __call__(self, itime, deriv=0):
|
||||
times, vals, t = self.find_knots(itime)
|
||||
times, vals, t, scale = self.find_knots(itime)
|
||||
|
||||
m1 = (vals[2] - vals[0]) / (1.0 - times[0])
|
||||
m2 = (vals[3] - vals[1]) / times[3]
|
||||
|
||||
mat = self._mat
|
||||
if deriv:
|
||||
mat = mat * self._deriv ** (deriv)
|
||||
mat = mat * (scale * self._deriv) ** deriv
|
||||
val = [m1, vals[1], m2, vals[2]] * mat * np.array([[t**3, t**2, t, 1]]).T
|
||||
return val[0,0]
|
||||
|
||||
def _plt(self, name='SplEval'):
|
||||
def _plt(self, name='SplEval', fig=111, show=True):
|
||||
import matplotlib.pyplot as plt
|
||||
x = np.linspace(-0.05, 1.05, 500)
|
||||
x = np.linspace(-0.0, 1.0, 500)
|
||||
r = x[1] - x[0]
|
||||
derivs = [(self(i+2*r)-self(i-2*r))/(4*r) for i in x]
|
||||
plt.figure(1)
|
||||
plt.figure(fig)
|
||||
plt.title(name)
|
||||
plt.plot(x,map(self,x),x,[self(i,1) for i in x],'--',x,derivs,'r.')
|
||||
plt.show()
|
||||
|
||||
|
||||
plt.plot(x,map(self,x),x,[self(i,1) for i in x],'--',
|
||||
self.knots[0],self.knots[1],'x')
|
||||
plt.xlim(0.0, 1.0)
|
||||
if show:
|
||||
plt.show()
|
||||
|
||||
def __str__(self):
|
||||
return '[%g:%g]' % (self(0), self(1))
|
||||
|
Loading…
Reference in New Issue
Block a user