Add -1-skipping to sort.

This commit is contained in:
Steven Robertson 2011-11-11 17:34:43 -05:00
parent 54f411878b
commit 05e1d08681

View File

@ -35,7 +35,8 @@ void prefix_scan(
int *offsets, int *offsets,
int *pfxs, int *pfxs,
const unsigned int *keys, const unsigned int *keys,
const int lo_bit const int lo_bit,
const int ignore_max
) { ) {
const int tid = threadIdx.x; const int tid = threadIdx.x;
__shared__ int shr_pfxs[RDXSZ]; __shared__ int shr_pfxs[RDXSZ];
@ -56,9 +57,11 @@ void prefix_scan(
// TODO: load 2 at once, compute, use a BFI to pack the two offsets // TODO: load 2 at once, compute, use a BFI to pack the two offsets
// into an int to halve storage / bandwidth // into an int to halve storage / bandwidth
int key = keys[idx]; int key = keys[idx];
if (!ignore_max || key != -1) {
int radix; int radix;
get_radix(radix, key, lo_bit); get_radix(radix, key, lo_bit);
offsets[idx] = atomicAdd(shr_pfxs + radix, 1); offsets[idx] = atomicAdd(shr_pfxs + radix, 1);
}
idx += BLKSZ; idx += BLKSZ;
} }
@ -272,13 +275,14 @@ void prefix_sum_inner(
__syncthreads(); __syncthreads();
sum = 0; sum = 0;
// Intentionally exclusive indexing here // Intentionally exclusive indexing here, fixed below
for (int i = 0; i < tid; i++) sum += sums[i]; for (int i = 0; i < tid; i++) sum += sums[i];
glob_pfxs[tid] = sum + sums[tid];
__syncthreads(); __syncthreads();
sums[tid] = glob_pfxs[tid] = sum; sums[tid] = sum;
idx = tid;
idx = tid;
for (int i = 0; i < ncondensed; i++) { for (int i = 0; i < ncondensed; i++) {
int c = condensed[idx]; int c = condensed[idx];
condensed[idx] = sum; condensed[idx] = sum;
@ -336,7 +340,8 @@ void radix_sort(
const int *offsets, const int *offsets,
const int *pfxs, const int *pfxs,
const int *locals, const int *locals,
const int lo_bit const int lo_bit,
const int ignore_max
) { ) {
const int tid = threadIdx.x; const int tid = threadIdx.x;
const int blk_offset = GRPSZ * blockIdx.x; const int blk_offset = GRPSZ * blockIdx.x;
@ -347,8 +352,12 @@ void radix_sort(
if (tid < RDXSZ) shr_offs[tid] = locals[pfx_i]; if (tid < RDXSZ) shr_offs[tid] = locals[pfx_i];
__syncthreads(); __syncthreads();
if (ignore_max)
for (int i = tid; i < GRPSZ; i += BLKSZ) defer[i] = -1;
for (int i = tid; i < GRPSZ; i += BLKSZ) { for (int i = tid; i < GRPSZ; i += BLKSZ) {
int key = keys[i+blk_offset]; int key = keys[i+blk_offset];
if (ignore_max && key == -1) continue;
int radix; int radix;
get_radix(radix, key, lo_bit); get_radix(radix, key, lo_bit);
int offset = offsets[i+blk_offset] + shr_offs[radix]; int offset = offsets[i+blk_offset] + shr_offs[radix];
@ -363,6 +372,7 @@ void radix_sort(
#pragma unroll #pragma unroll
for (int j = 0; j < GRP_BLK_FACTOR; j++) { for (int j = 0; j < GRP_BLK_FACTOR; j++) {
int key = defer[i]; int key = defer[i];
if (ignore_max && key == -1) continue;
int radix; int radix;
get_radix(radix, key, lo_bit); get_radix(radix, key, lo_bit);
int offset = shr_offs[radix] + i; int offset = shr_offs[radix] + i;
@ -430,12 +440,17 @@ class Sorter(object):
RuntimeWarning, stacklevel=3) RuntimeWarning, stacklevel=3)
self.warn_issued = True self.warn_issued = True
def sort(self, dst, src, size, lo_bit=0, def sort(self, dst, src, size, lo_bit=0, ignore_max=False,
prev_lo_bit=None, prev_bits=None, stream=None): prev_lo_bit=None, prev_bits=None, stream=None):
""" """
Sort 'src' by the bits from lo_bit+radix_bits to lo_bit, where 0 is Sort 'src' by the bits from lo_bit+radix_bits to lo_bit, where 0 is
the LSB. Store the result in 'dst'. the LSB. Store the result in 'dst'.
If 'ignore_max' is True, any key with the value 0xffffffff will be
effeciently discarded. The number of valid results in the final array
can be determined by examining the last item in the device array
pointed to by this class's 'dglobal' property.
GIANT ENORMOUS WARNING. The single pass sort is not quite stable, even GIANT ENORMOUS WARNING. The single pass sort is not quite stable, even
with the multi-pass repair kernel. This means that multi-pass sort is with the multi-pass repair kernel. This means that multi-pass sort is
bugged. Don't use it unless your application can handle non-monotonic bugged. Don't use it unless your application can handle non-monotonic
@ -452,7 +467,8 @@ class Sorter(object):
grids = size / self.group_size grids = size / self.group_size
self.prefix_scan(self.doffsets, self.dpfxs, src, np.int32(lo_bit), self.prefix_scan(self.doffsets, self.dpfxs, src, np.int32(lo_bit),
block=(512, 1, 1), grid=(grids, 1), stream=stream) np.int32(ignore_max), block=(512, 1, 1),
grid=(grids, 1), stream=stream)
# Intentionally ignore prev_bits=0 # Intentionally ignore prev_bits=0
if prev_bits: if prev_bits:
@ -483,8 +499,8 @@ class Sorter(object):
self.prefix_sum_distribute(self.dpfxs, self.dcond, ngrps, grpwidth, self.prefix_sum_distribute(self.dpfxs, self.dcond, ngrps, grpwidth,
block=(self.radix_size, 1, 1), grid=(self.ncond, 1), stream=stream) block=(self.radix_size, 1, 1), grid=(self.ncond, 1), stream=stream)
self.radix_sort(dst, src, self.radix_sort(dst, src, self.doffsets, self.dpfxs, self.dlocals,
self.doffsets, self.dpfxs, self.dlocals, np.int32(lo_bit), np.int32(lo_bit), np.int32(ignore_max),
block=(self.group_size / 8, 1, 1), grid=(grids, 1), stream=stream) block=(self.group_size / 8, 1, 1), grid=(grids, 1), stream=stream)
def multisort(self, scratch_a, scratch_b, src, size, lo_bit=0, def multisort(self, scratch_a, scratch_b, src, size, lo_bit=0,