import numpy as np

# The maximum number of coefficients that will ever be retained on the device
FWIDTH = 15

# The number of points on either side of the center in one dimension
F2 = int(FWIDTH/2)
# The minimum relative size (as a fraction of the filter sum) a coefficient
# must have to be retained
COEFF_EPS = 0.0001
dists2d = np.fromfunction(lambda i, j: np.hypot(i-F2, j-F2), (FWIDTH, FWIDTH))
dists = dists2d.flatten()
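# Illustrative sanity check (added, not part of the original script): dists2d
# holds Euclidean distances from the filter center, so the center entry is
# zero and a corner is hypot(F2, F2) away.
assert dists2d[F2, F2] == 0.0
assert abs(dists2d[0, 0] - np.hypot(F2, F2)) < 1e-12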
# This translates to a cap on DE filter radius of 50. Even this fits very
# comfortably within the chosen COEFF_EPS.
MAX_SCALE = -3/25.
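# Added note (an assumption, not from the original): the coefficients below
# are computed as exp(d*d * scale), so a Gaussian exp(-d*d / (2*sigma^2))
# corresponds to scale = -1/(2*sigma^2); MAX_SCALE = -0.12 would then
# correspond to a sigma of roughly 2.04 grid points.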
# Below this scale, we'd be directly clamping to one bin anyway
MIN_SCALE = np.log(0.0001)
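# Added note: at MIN_SCALE the nearest neighbor's raw coefficient is
# exp(1 * MIN_SCALE) = 0.0001 = COEFF_EPS, so any narrower filter keeps only
# its center point.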
# Everything above this scale uses one approximation; below, another.
SPLIT_SCALE = -1.1
# The upper endpoints are overshot by this proportion to reduce error
OVERSHOOT = 1.01
# No longer 'scale'-related, but we call it that anyway
scales_hi = np.linspace(SPLIT_SCALE, MAX_SCALE * OVERSHOOT, num=1000)
scales_lo = np.linspace(MIN_SCALE, SPLIT_SCALE * OVERSHOOT, num=1000)
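# For reference (added note): with the constants above,
#   scales_hi covers roughly [-1.1, -0.1212]   (SPLIT_SCALE to MAX_SCALE * OVERSHOOT)
#   scales_lo covers roughly [-9.2103, -1.111] (MIN_SCALE to SPLIT_SCALE * OVERSHOOT)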
def eval_scales(scales, name, nterms):
    # Calculate the filter sums for each scale value
    sums = []
    for scale in scales:
        coeffs = np.exp(dists**2 * scale)
        # Note that this is the sum of the coefficients at all coordinates,
        # though it should actually be the result of the polynomial
        # approximation. We could do a feedback loop to improve accuracy, but
        # I don't think the difference is worth worrying about.
        sum = np.sum(coeffs)
        sums.append(1./np.sum(filter(lambda v: v / sum > COEFF_EPS, coeffs)))
    print 'Evaluating %s:' % name
    poly, resid, rank, sing, rcond = np.polyfit(scales, sums, nterms, full=True)
    print 'Fit for %s:' % name, poly, resid, rank, sing, rcond
    return sums, poly
import matplotlib.pyplot as plt

sums_hi, poly_hi = eval_scales(scales_hi, 'hi', 7)
sums_lo, poly_lo = eval_scales(scales_lo, 'lo', 7)
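# Illustrative check (added, not in the original script): the fitted polynomial
# approximates the reciprocal of the retained coefficient sum, so scaling the
# raw Gaussian coefficients by it should give weights that sum to roughly one.
demo_scale = -0.5    # arbitrary value inside the 'hi' range
demo_weights = np.exp(dists**2 * demo_scale) * np.polyval(poly_hi, demo_scale)
print 'demo weight sum (should be close to 1):', np.sum(demo_weights)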
num_overshoots = len(filter(lambda v: v > MAX_SCALE, scales_hi))
scales_hi = scales_hi[num_overshoots:]
sums_hi = sums_hi[num_overshoots:]

num_overshoots = len(filter(lambda v: v > SPLIT_SCALE, scales_lo))
scales_lo = scales_lo[num_overshoots:]
sums_lo = sums_lo[num_overshoots:]
polyf_hi = np.float32(poly_hi)
vals_hi = np.polyval(polyf_hi, scales_hi)
polyf_lo = np.float32(poly_lo)
vals_lo = np.polyval(polyf_lo, scales_lo)
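# Illustrative only (added): the float32() casts above truncate the fitted
# coefficients to the single precision they will presumably have on the device,
# so the error plotted below includes that truncation as well as the fit error.
trunc_err_hi = np.max(np.abs(vals_hi - np.polyval(poly_hi, scales_hi)))
print 'max absolute error from float32 coefficient truncation (hi):', trunc_err_hi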
def print_filt(filts):
    print ' filtsum = %4.8ef;' % filts[0]
    for f in filts[1:]:
        print ' filtsum = filtsum * scale + % 16.8ef;' % f
print '\n\nFor your convenience:'
print '#define MIN_SCALE %.8gf' % MIN_SCALE
print '#define MAX_SCALE %.8gf' % MAX_SCALE
print 'if (scale < %gf) {' % SPLIT_SCALE
print_filt(polyf_lo)
print '} else {'
print_filt(polyf_hi)
print '}'
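# For illustration (added; <c7>..<c0> are placeholders, not real fit output):
# the block printed above evaluates the fitted polynomial with Horner's rule
# and has this general shape:
#
#   #define MIN_SCALE -9.2103404f
#   #define MAX_SCALE -0.12f
#   if (scale < -1.1f) {
#    filtsum = <c7>f;
#    filtsum = filtsum * scale + <c6>f;
#    ... (one line per remaining 'lo' coefficient, down to <c0>)
#   } else {
#    ... (same form, using the 'hi' coefficients)
#   }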
scales = np.concatenate([scales_lo, scales_hi])
sums = np.concatenate([sums_lo, sums_hi])
vals = np.concatenate([vals_lo, vals_hi])
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.plot(scales, sums)
ax.plot(scales, vals)
ax.set_xlabel('stdev')
ax.set_ylabel('filter sum')

ax = ax.twinx()
ax.plot(scales, [abs((s-v)/v) for s, v in zip(sums, vals)])
ax.set_ylabel('rel err')

plt.show()