mirror of
				https://github.com/stevenrobertson/cuburn.git
				synced 2025-11-03 18:00:55 -05:00 
			
		
		
		
	Use much more accurate filtsum estimation polynomials
This commit is contained in:
		@ -124,6 +124,9 @@ void logscale(float4 *pixbuf, float4 *outbuf, float k1, float k2) {
 | 
				
			|||||||
    outbuf[i] = pix;
 | 
					    outbuf[i] = pix;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#define MIN_SD 0.23299530
 | 
				
			||||||
 | 
					#define MAX_SD 4.33333333
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__global__
 | 
					__global__
 | 
				
			||||||
void density_est(float4 *pixbuf, float4 *outbuf, float *denbuf,
 | 
					void density_est(float4 *pixbuf, float4 *outbuf, float *denbuf,
 | 
				
			||||||
                 float est_sd, float neg_est_curve, float est_min,
 | 
					                 float est_sd, float neg_est_curve, float est_min,
 | 
				
			||||||
@ -147,61 +150,86 @@ void density_est(float4 *pixbuf, float4 *outbuf, float *denbuf,
 | 
				
			|||||||
            in.z *= ls;
 | 
					            in.z *= ls;
 | 
				
			||||||
            in.w *= ls;
 | 
					            in.w *= ls;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // Base index of destination for writes
 | 
				
			||||||
 | 
					            int si = (threadIdx.y + W2) * FW + threadIdx.x + W2;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            // Calculate standard deviation of Gaussian kernel. The base SD is
 | 
					            // Calculate standard deviation of Gaussian kernel. The base SD is
 | 
				
			||||||
            // then scaled in inverse proportion to the density of the point
 | 
					            // then scaled in inverse proportion to the density of the point
 | 
				
			||||||
            // being scaled.
 | 
					            // being scaled.
 | 
				
			||||||
            float sd = est_sd * powf(den+1.0f, neg_est_curve);
 | 
					            float sd = est_sd * powf(den+1.0f, neg_est_curve);
 | 
				
			||||||
            // Clamp the final standard deviation. Things will go badly if the
 | 
					            // Clamp the final standard deviation. Things will go badly if the
 | 
				
			||||||
            // minimum is undershot.
 | 
					            // minimum is undershot.
 | 
				
			||||||
            sd = fmaxf(sd, est_min);
 | 
					            sd = fminf(MAX_SD, fmaxf(sd, est_min));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            // This five-term polynomial approximates the sum of the filters
 | 
					            // Below a certain threshold, only one coeffecient would be
 | 
				
			||||||
            // with the clamping logic used here. See helpers/filt_err.py.
 | 
					            // retained anyway; we hop right to it.
 | 
				
			||||||
            float filtsum;
 | 
					            if (sd <= MIN_SD) {
 | 
				
			||||||
            filtsum = -0.20885075f  * sd +  0.90557721f;
 | 
					                de_add(si, 0,  0, in);
 | 
				
			||||||
            filtsum = filtsum       * sd +  5.28363054f;
 | 
					            } else {
 | 
				
			||||||
            filtsum = filtsum       * sd + -0.11733533f;
 | 
					                // These polynomials approximates the sum of the filters
 | 
				
			||||||
            filtsum = filtsum       * sd +  0.35670333f;
 | 
					                // with the clamping logic used here. See helpers/filt_err.py.
 | 
				
			||||||
            float filtscale = 1 / filtsum;
 | 
					                float filtsum;
 | 
				
			||||||
 | 
					                if (sd < 0.75) {
 | 
				
			||||||
 | 
					                    filtsum = -352.25061035f;
 | 
				
			||||||
 | 
					                    filtsum = filtsum * sd +    1117.09680176f;
 | 
				
			||||||
 | 
					                    filtsum = filtsum * sd +   -1372.48864746f;
 | 
				
			||||||
 | 
					                    filtsum = filtsum * sd +     779.15478516f;
 | 
				
			||||||
 | 
					                    filtsum = filtsum * sd +    -164.04229736f;
 | 
				
			||||||
 | 
					                    filtsum = filtsum * sd +     -12.04892635f;
 | 
				
			||||||
 | 
					                    filtsum = filtsum * sd +       9.04126644f;
 | 
				
			||||||
 | 
					                    filtsum = filtsum * sd +       0.10304667f;
 | 
				
			||||||
 | 
					                } else {
 | 
				
			||||||
 | 
					                    filtsum = -0.00403376f;
 | 
				
			||||||
 | 
					                    filtsum = filtsum * sd +       0.06608720f;
 | 
				
			||||||
 | 
					                    filtsum = filtsum * sd +      -0.38924992f;
 | 
				
			||||||
 | 
					                    filtsum = filtsum * sd +       0.84797901f;
 | 
				
			||||||
 | 
					                    filtsum = filtsum * sd +       0.34173131f;
 | 
				
			||||||
 | 
					                    filtsum = filtsum * sd +      -4.67077589f;
 | 
				
			||||||
 | 
					                    filtsum = filtsum * sd +      14.34595776f;
 | 
				
			||||||
 | 
					                    filtsum = filtsum * sd +      -5.80082798f;
 | 
				
			||||||
 | 
					                    filtsum = filtsum * sd +       1.54098487f;
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					                float filtscale = 1.0f / filtsum;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            // The reciprocal SD scaling coeffecient in the Gaussian exponent.
 | 
					                // The reciprocal SD scaling coeffecient in the Gaussian
 | 
				
			||||||
            // exp(-x^2/(2*sd^2)) = exp2f(x^2*rsd)
 | 
					                // exponent: exp(-x^2/(2*sd^2)) = exp2f(x^2*rsd)
 | 
				
			||||||
            float rsd = -0.5f * CUDART_L2E_F / (sd * sd);
 | 
					                float rsd = -0.5f * CUDART_L2E_F / (sd * sd);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            int si = (threadIdx.y + W2) * FW + threadIdx.x + W2;
 | 
					                for (int jj = 0; jj <= W2; jj++) {
 | 
				
			||||||
            for (int jj = 0; jj <= W2; jj++) {
 | 
					                    float jj2f = jj;
 | 
				
			||||||
                float jj2f = jj;
 | 
					                    jj2f *= jj2f;
 | 
				
			||||||
                jj2f *= jj2f;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                float iif = 0;
 | 
					                    float iif = 0;
 | 
				
			||||||
                for (int ii = 0; ii <= jj; ii++) {
 | 
					                    for (int ii = 0; ii <= jj; ii++) {
 | 
				
			||||||
                    iif += 1;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    float coeff = exp2f((jj2f + iif * iif) * rsd) * filtscale;
 | 
					                        float coeff = exp2f((jj2f + iif * iif) * rsd)
 | 
				
			||||||
                    if (coeff < 0.0001f) break;
 | 
					                                    * filtscale;
 | 
				
			||||||
 | 
					                        if (coeff < 0.0001f) break;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    float4 scaled;
 | 
					                        float4 scaled;
 | 
				
			||||||
                    scaled.x = in.x * coeff;
 | 
					                        scaled.x = in.x * coeff;
 | 
				
			||||||
                    scaled.y = in.y * coeff;
 | 
					                        scaled.y = in.y * coeff;
 | 
				
			||||||
                    scaled.z = in.z * coeff;
 | 
					                        scaled.z = in.z * coeff;
 | 
				
			||||||
                    scaled.w = in.w * coeff;
 | 
					                        scaled.w = in.w * coeff;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    de_add(si,  ii,  jj, scaled);
 | 
					                        de_add(si,  ii,  jj, scaled);
 | 
				
			||||||
                    if (jj == 0) continue;
 | 
					                        if (jj == 0) continue;
 | 
				
			||||||
                    de_add(si,  ii, -jj, scaled);
 | 
					                        de_add(si,  ii, -jj, scaled);
 | 
				
			||||||
                    if (ii != 0) {
 | 
					                        if (ii != 0) {
 | 
				
			||||||
                        de_add(si, -ii,  jj, scaled);
 | 
					                            de_add(si, -ii,  jj, scaled);
 | 
				
			||||||
                        de_add(si, -ii, -jj, scaled);
 | 
					                            de_add(si, -ii, -jj, scaled);
 | 
				
			||||||
                        if (ii == jj) continue;
 | 
					                            if (ii == jj) continue;
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					                        de_add(si,  jj,  ii, scaled);
 | 
				
			||||||
 | 
					                        de_add(si, -jj,  ii, scaled);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        if (ii == 0) continue;
 | 
				
			||||||
 | 
					                        de_add(si, -jj, -ii, scaled);
 | 
				
			||||||
 | 
					                        de_add(si,  jj, -ii, scaled);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        iif += 1;
 | 
				
			||||||
 | 
					                        // TODO: validate that the above avoids bank conflicts
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                    de_add(si,  jj,  ii, scaled);
 | 
					 | 
				
			||||||
                    de_add(si, -jj,  ii, scaled);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    if (ii == 0) continue;
 | 
					 | 
				
			||||||
                    de_add(si, -jj, -ii, scaled);
 | 
					 | 
				
			||||||
                    de_add(si,  jj, -ii, scaled);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    // TODO: validate that the above avoids bank conflicts
 | 
					 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
@ -262,7 +290,7 @@ void density_est(float4 *pixbuf, float4 *outbuf, float *denbuf,
 | 
				
			|||||||
            # (0.5/1.5)=1/3.
 | 
					            # (0.5/1.5)=1/3.
 | 
				
			||||||
            est_sd = np.float32(cp.estimator / 3.)
 | 
					            est_sd = np.float32(cp.estimator / 3.)
 | 
				
			||||||
            neg_est_curve = np.float32(-cp.estimator_curve)
 | 
					            neg_est_curve = np.float32(-cp.estimator_curve)
 | 
				
			||||||
            est_min = np.float32(max(cp.estimator_minimum / 3., 0.4))
 | 
					            est_min = np.float32(cp.estimator_minimum / 3.)
 | 
				
			||||||
            fun = mod.get_function("density_est")
 | 
					            fun = mod.get_function("density_est")
 | 
				
			||||||
            fun(abufd, obufd, dbufd, est_sd, neg_est_curve, est_min, k1, k2,
 | 
					            fun(abufd, obufd, dbufd, est_sd, neg_est_curve, est_min, k1, k2,
 | 
				
			||||||
                block=(32, 32, 1), grid=(self.features.acc_width/32, 1),
 | 
					                block=(32, 32, 1), grid=(self.features.acc_width/32, 1),
 | 
				
			||||||
 | 
				
			|||||||
@ -9,38 +9,96 @@ F2 = int(FWIDTH/2)
 | 
				
			|||||||
# The maximum size of any one coeffecient to be retained
 | 
					# The maximum size of any one coeffecient to be retained
 | 
				
			||||||
COEFF_EPS = 0.0001
 | 
					COEFF_EPS = 0.0001
 | 
				
			||||||
 | 
					
 | 
				
			||||||
dists = np.fromfunction(lambda i, j: np.hypot(i-F2, j-F2), (FWIDTH, FWIDTH))
 | 
					dists2d = np.fromfunction(lambda i, j: np.hypot(i-F2, j-F2), (FWIDTH, FWIDTH))
 | 
				
			||||||
dists = dists.flatten()
 | 
					dists = dists2d.flatten()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# A flam3 estimator radius corresponds to a Gaussian filter with a standard
 | 
					# A flam3 estimator radius corresponds to a Gaussian filter with a standard
 | 
				
			||||||
# deviation of 1/3 the radius. We choose 13 as an arbitrary upper bound for the
 | 
					# deviation of 1/3 the radius. We choose 13 as an arbitrary upper bound for the
 | 
				
			||||||
# max filter radius. Larger radii will work without
 | 
					# max filter radius. The filter should reject larger radii.
 | 
				
			||||||
MAX_SD = 13 / 3.
 | 
					MAX_SD = 13 / 3.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# The minimum estimator radius is 1. In flam3, this is effectively no
 | 
					# The minimum estimator radius can be set as low as 0, but below a certain
 | 
				
			||||||
# filtering, but since the cutoff structure is defined by COEFF_EPS in cuburn,
 | 
					# radius only one coeffecient is retained. Since things get unstable near 0,
 | 
				
			||||||
# we undershoot it a bit to make the polyfit behave better at high densities.
 | 
					# we explicitly set a minimum threshold below which no coeffecients are
 | 
				
			||||||
MIN_SD = 0.3
 | 
					# retained.
 | 
				
			||||||
 | 
					MIN_SD = np.sqrt(-1 / (2 * np.log(COEFF_EPS)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
sds = np.logspace(np.log10(MIN_SD), np.log10(MAX_SD), num=100)
 | 
					# Using two predicated three-term approximations is much more accurate than
 | 
				
			||||||
 | 
					# using a very large number of terms, due to nonlinear behavior at low SD.
 | 
				
			||||||
 | 
					# Everything above this SD uses one approximation; below, another.
 | 
				
			||||||
 | 
					SPLIT_SD = 0.75
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Calculate the filter sums at each coordinate
 | 
					# The lower endpoints are undershot by this proportion to reduce error
 | 
				
			||||||
sums = []
 | 
					UNDERSHOOT = 0.98
 | 
				
			||||||
for sd in sds:
 | 
					
 | 
				
			||||||
    coeffs = np.exp(dists**2 / (-2 * sd ** 2))
 | 
					sds_hi = np.linspace(SPLIT_SD * UNDERSHOOT, MAX_SD, num=1000)
 | 
				
			||||||
    sums.append(np.sum(filter(lambda v: v / np.sum(coeffs) > COEFF_EPS, coeffs)))
 | 
					sds_lo = np.linspace(MIN_SD * UNDERSHOOT, SPLIT_SD, num=1000)
 | 
				
			||||||
print sums
 | 
					
 | 
				
			||||||
 | 
					print 'At MIN_SD = %g, these are the coeffs:' % MIN_SD
 | 
				
			||||||
 | 
					print np.exp(dists2d**2 / (-2 * MIN_SD ** 2))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def eval_sds(sds, name, nterms):
 | 
				
			||||||
 | 
					    # Calculate the filter sums at each coordinate
 | 
				
			||||||
 | 
					    sums = []
 | 
				
			||||||
 | 
					    for sd in sds:
 | 
				
			||||||
 | 
					        coeffs = np.exp(dists**2 / (-2 * sd ** 2))
 | 
				
			||||||
 | 
					        # Note that this sum is the sum of all coordinates, though it should
 | 
				
			||||||
 | 
					        # actually be the result of the polynomial approximation. We could do
 | 
				
			||||||
 | 
					        # a feedback loop to improve accuracy, but I don't think the difference
 | 
				
			||||||
 | 
					        # is worth worrying about.
 | 
				
			||||||
 | 
					        sum = np.sum(coeffs)
 | 
				
			||||||
 | 
					        sums.append(np.sum(filter(lambda v: v / sum > COEFF_EPS, coeffs)))
 | 
				
			||||||
 | 
					    print 'Evaluating %s:' % name
 | 
				
			||||||
 | 
					    poly, resid, rank, sing, rcond = np.polyfit(sds, sums, nterms, full=True)
 | 
				
			||||||
 | 
					    print 'Fit for %s:' % name, poly, resid, rank, sing, rcond
 | 
				
			||||||
 | 
					    return sums, poly
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import matplotlib.pyplot as plt
 | 
					import matplotlib.pyplot as plt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
poly, resid, rank, sing, rcond = np.polyfit(sds, sums, 4, full=True)
 | 
					sums_hi, poly_hi = eval_sds(sds_hi, 'hi', 8)
 | 
				
			||||||
print poly, resid, rank, sing, rcond
 | 
					sums_lo, poly_lo = eval_sds(sds_lo, 'lo', 7)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					num_undershoots = len(filter(lambda v: v < SPLIT_SD, sds_hi))
 | 
				
			||||||
 | 
					sds_hi = sds_hi[num_undershoots:]
 | 
				
			||||||
 | 
					sums_hi = sums_hi[num_undershoots:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					num_undershoots = len(filter(lambda v: v < MIN_SD, sds_lo))
 | 
				
			||||||
 | 
					sds_lo = sds_lo[num_undershoots:]
 | 
				
			||||||
 | 
					sums_lo = sums_lo[num_undershoots:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					polyf_hi = np.float32(poly_hi)
 | 
				
			||||||
 | 
					vals_hi = np.polyval(polyf_hi, sds_hi)
 | 
				
			||||||
 | 
					polyf_lo = np.float32(poly_lo)
 | 
				
			||||||
 | 
					vals_lo = np.polyval(polyf_lo, sds_lo)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def print_filt(filts):
 | 
				
			||||||
 | 
					    print '    filtsum = %4.8ff;' % filts[0]
 | 
				
			||||||
 | 
					    for f in filts[1:]:
 | 
				
			||||||
 | 
					        print '    filtsum = filtsum * sd + % 16.8ff;' % f
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					print '\n\nFor your convenience:'
 | 
				
			||||||
 | 
					print '#define MIN_SD %.8f' % MIN_SD
 | 
				
			||||||
 | 
					print '#define MAX_SD %.8f' % MAX_SD
 | 
				
			||||||
 | 
					print 'if (sd < %g) {' % SPLIT_SD
 | 
				
			||||||
 | 
					print_filt(polyf_lo)
 | 
				
			||||||
 | 
					print '} else {'
 | 
				
			||||||
 | 
					print_filt(polyf_hi)
 | 
				
			||||||
 | 
					print '}'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					sds = np.concatenate([sds_lo, sds_hi])
 | 
				
			||||||
 | 
					sums = np.concatenate([sums_lo, sums_hi])
 | 
				
			||||||
 | 
					vals = np.concatenate([vals_lo, vals_hi])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					fig = plt.figure()
 | 
				
			||||||
 | 
					ax = fig.add_subplot(1,1,1)
 | 
				
			||||||
 | 
					ax.plot(sds, sums)
 | 
				
			||||||
 | 
					ax.plot(sds, vals)
 | 
				
			||||||
 | 
					ax.set_xlabel('stdev')
 | 
				
			||||||
 | 
					ax.set_ylabel('filter sum')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					ax = ax.twinx()
 | 
				
			||||||
 | 
					ax.plot(sds, [abs((s-v)/v) for s, v in zip(sums, vals)])
 | 
				
			||||||
 | 
					ax.set_ylabel('rel err')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
polyf = np.float32(poly)
 | 
					 | 
				
			||||||
plt.plot(sds, sums)
 | 
					 | 
				
			||||||
plt.plot(sds, np.polyval(polyf, sds))
 | 
					 | 
				
			||||||
plt.show()
 | 
					plt.show()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
print np.polyval(poly, 1.1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# TODO: calculate error more fully, verify all this logic
 | 
					 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user