Add -1-skipping to sort.

This commit is contained in:
Steven Robertson 2011-11-11 17:34:43 -05:00
parent 54f411878b
commit 05e1d08681

View File

@ -35,7 +35,8 @@ void prefix_scan(
int *offsets,
int *pfxs,
const unsigned int *keys,
const int lo_bit
const int lo_bit,
const int ignore_max
) {
const int tid = threadIdx.x;
__shared__ int shr_pfxs[RDXSZ];
@ -56,9 +57,11 @@ void prefix_scan(
// TODO: load 2 at once, compute, use a BFI to pack the two offsets
// into an int to halve storage / bandwidth
int key = keys[idx];
if (!ignore_max || key != -1) {
int radix;
get_radix(radix, key, lo_bit);
offsets[idx] = atomicAdd(shr_pfxs + radix, 1);
}
idx += BLKSZ;
}
@ -272,13 +275,14 @@ void prefix_sum_inner(
__syncthreads();
sum = 0;
// Intentionally exclusive indexing here
// Intentionally exclusive indexing here, fixed below
for (int i = 0; i < tid; i++) sum += sums[i];
glob_pfxs[tid] = sum + sums[tid];
__syncthreads();
sums[tid] = glob_pfxs[tid] = sum;
idx = tid;
sums[tid] = sum;
idx = tid;
for (int i = 0; i < ncondensed; i++) {
int c = condensed[idx];
condensed[idx] = sum;
@ -336,7 +340,8 @@ void radix_sort(
const int *offsets,
const int *pfxs,
const int *locals,
const int lo_bit
const int lo_bit,
const int ignore_max
) {
const int tid = threadIdx.x;
const int blk_offset = GRPSZ * blockIdx.x;
@ -347,8 +352,12 @@ void radix_sort(
if (tid < RDXSZ) shr_offs[tid] = locals[pfx_i];
__syncthreads();
if (ignore_max)
for (int i = tid; i < GRPSZ; i += BLKSZ) defer[i] = -1;
for (int i = tid; i < GRPSZ; i += BLKSZ) {
int key = keys[i+blk_offset];
if (ignore_max && key == -1) continue;
int radix;
get_radix(radix, key, lo_bit);
int offset = offsets[i+blk_offset] + shr_offs[radix];
@ -363,6 +372,7 @@ void radix_sort(
#pragma unroll
for (int j = 0; j < GRP_BLK_FACTOR; j++) {
int key = defer[i];
if (ignore_max && key == -1) continue;
int radix;
get_radix(radix, key, lo_bit);
int offset = shr_offs[radix] + i;
@ -430,12 +440,17 @@ class Sorter(object):
RuntimeWarning, stacklevel=3)
self.warn_issued = True
def sort(self, dst, src, size, lo_bit=0,
def sort(self, dst, src, size, lo_bit=0, ignore_max=False,
prev_lo_bit=None, prev_bits=None, stream=None):
"""
Sort 'src' by the bits from lo_bit+radix_bits to lo_bit, where 0 is
the LSB. Store the result in 'dst'.
If 'ignore_max' is True, any key with the value 0xffffffff will be
effeciently discarded. The number of valid results in the final array
can be determined by examining the last item in the device array
pointed to by this class's 'dglobal' property.
GIANT ENORMOUS WARNING. The single pass sort is not quite stable, even
with the multi-pass repair kernel. This means that multi-pass sort is
bugged. Don't use it unless your application can handle non-monotonic
@ -452,7 +467,8 @@ class Sorter(object):
grids = size / self.group_size
self.prefix_scan(self.doffsets, self.dpfxs, src, np.int32(lo_bit),
block=(512, 1, 1), grid=(grids, 1), stream=stream)
np.int32(ignore_max), block=(512, 1, 1),
grid=(grids, 1), stream=stream)
# Intentionally ignore prev_bits=0
if prev_bits:
@ -483,8 +499,8 @@ class Sorter(object):
self.prefix_sum_distribute(self.dpfxs, self.dcond, ngrps, grpwidth,
block=(self.radix_size, 1, 1), grid=(self.ncond, 1), stream=stream)
self.radix_sort(dst, src,
self.doffsets, self.dpfxs, self.dlocals, np.int32(lo_bit),
self.radix_sort(dst, src, self.doffsets, self.dpfxs, self.dlocals,
np.int32(lo_bit), np.int32(ignore_max),
block=(self.group_size / 8, 1, 1), grid=(grids, 1), stream=stream)
def multisort(self, scratch_a, scratch_b, src, size, lo_bit=0,