mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-04-21 00:51:31 -04:00
Fix velocity matching
This commit is contained in:
parent
09725ba794
commit
22c1ec872c
@ -15,7 +15,7 @@ class SplEval(object):
|
|||||||
|
|
||||||
def __init__(self, knots, v0=None, v1=None):
|
def __init__(self, knots, v0=None, v1=None):
|
||||||
if isinstance(knots, (int, float)):
|
if isinstance(knots, (int, float)):
|
||||||
knots = [-0.1, knots, 0.0, knots, 1.0, knots, 1.1, knots]
|
knots = [0.0, knots, 1.0, knots]
|
||||||
elif not np.all(np.diff(np.float32(np.asarray(knots))[::2]) > 0):
|
elif not np.all(np.diff(np.float32(np.asarray(knots))[::2]) > 0):
|
||||||
raise ValueError("Spline times are non-monotonic. (Use "
|
raise ValueError("Spline times are non-monotonic. (Use "
|
||||||
"nextafterf()-spaced times to anchor tangents.)")
|
"nextafterf()-spaced times to anchor tangents.)")
|
||||||
@ -24,11 +24,11 @@ class SplEval(object):
|
|||||||
# [0,1] interval, add them.
|
# [0,1] interval, add them.
|
||||||
if knots[0] >= 0:
|
if knots[0] >= 0:
|
||||||
if v0 is None:
|
if v0 is None:
|
||||||
v0 = (knots[3] - knots[1]) / (knots[2] - knots[0])
|
v0 = (knots[3] - knots[1]) / float(knots[2] - knots[0])
|
||||||
knots = [-2, knots[3] + (knots[2] - 2) * v0] + knots
|
knots = [-2, knots[3] - (knots[2] + 2) * v0] + knots
|
||||||
if knots[-2] <= 1:
|
if knots[-2] <= 1:
|
||||||
if v1 is None:
|
if v1 is None:
|
||||||
v1 = (knots[-1] - knots[-3]) / (knots[-2] - knots[-4])
|
v1 = (knots[-1] - knots[-3]) / float(knots[-2] - knots[-4])
|
||||||
knots.extend([3, knots[-3] + (3 - knots[-4]) * v1])
|
knots.extend([3, knots[-3] + (3 - knots[-4]) * v1])
|
||||||
|
|
||||||
self.knots = np.zeros((2, len(knots)/2))
|
self.knots = np.zeros((2, len(knots)/2))
|
||||||
@ -79,8 +79,14 @@ class SplEval(object):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def knotlist(self):
|
def knotlist(self):
|
||||||
|
# TODO: scale error constants proportional to RMS?
|
||||||
|
# If everything is constant, return a constant
|
||||||
if np.std(self.knots[1]) < 1e-6:
|
if np.std(self.knots[1]) < 1e-6:
|
||||||
return self.knots[1][0]
|
return self.knots[1][0]
|
||||||
|
# If constant slope, omit the end knots
|
||||||
|
slopes = np.diff(self.knots[1]) / np.diff(self.knots[0])
|
||||||
|
if np.std(slopes) < 1e-6:
|
||||||
|
return list(self.knots.T.flat)[2:-2]
|
||||||
return list(self.knots.T.flat)
|
return list(self.knots.T.flat)
|
||||||
|
|
||||||
def palette_decode(datastrs):
|
def palette_decode(datastrs):
|
||||||
|
Loading…
Reference in New Issue
Block a user