diff --git a/cuburn/genome.py b/cuburn/genome.py index 6c4d687..5411ca2 100644 --- a/cuburn/genome.py +++ b/cuburn/genome.py @@ -31,8 +31,8 @@ class SplEval(object): self.knots.T.flat[:] = knots def find_knots(self, itime): - idx = np.searchsorted(self.knots[0][1:-1], itime) - idx = min(idx, len(self.knots[0]) - 4) + idx = np.searchsorted(self.knots[0], itime) - 2 + idx = max(0, min(idx, len(self.knots[0]) - 4)) times = self.knots[0][idx:idx+4] vals = self.knots[1][idx:idx+4] @@ -51,15 +51,33 @@ class SplEval(object): mat = self._mat if deriv: - mat *= self._deriv ** (deriv+1) + mat = mat * self._deriv ** (deriv) val = [m1, vals[1], m2, vals[2]] * mat * np.array([[t**3, t**2, t, 1]]).T return val[0,0] + def _plt(self, name='SplEval'): + import matplotlib.pyplot as plt + x = np.linspace(-0.05, 1.05, 500) + r = x[1] - x[0] + derivs = [(self(i+2*r)-self(i-2*r))/(4*r) for i in x] + plt.figure(1) + plt.title(name) + plt.plot(x,map(self,x),x,[self(i,1) for i in x],'--',x,derivs,'r.') + plt.show() + + + def __str__(self): return '[%g:%g]' % (self(0), self(1)) def __repr__(self): return '' % (self(0), self(1)) + @property + def knotlist(self): + if np.std(self.knots[1]) < 1e-6: + return self.knots[1][0] + return list(self.knots.T.flat) + @classmethod def wrap(cls, obj): """