diff --git a/cuburn/code/sort.py b/cuburn/code/sort.py
index b0536f2..b24547c 100644
--- a/cuburn/code/sort.py
+++ b/cuburn/code/sort.py
@@ -35,7 +35,8 @@ void prefix_scan(
     int *offsets,
     int *pfxs,
     const unsigned int *keys,
-    const int lo_bit
+    const int lo_bit,
+    const int ignore_max
 ) {
     const int tid = threadIdx.x;
     __shared__ int shr_pfxs[RDXSZ];
@@ -56,9 +57,11 @@ void prefix_scan(
         // TODO: load 2 at once, compute, use a BFI to pack the two offsets
         // into an int to halve storage / bandwidth
         int key = keys[idx];
-        int radix;
-        get_radix(radix, key, lo_bit);
-        offsets[idx] = atomicAdd(shr_pfxs + radix, 1);
+        if (!ignore_max || key != -1) {
+            int radix;
+            get_radix(radix, key, lo_bit);
+            offsets[idx] = atomicAdd(shr_pfxs + radix, 1);
+        }
         idx += BLKSZ;
     }
 
@@ -272,13 +275,14 @@ void prefix_sum_inner(
 
     __syncthreads();
     sum = 0;
-    // Intentionally exclusive indexing here
+    // Intentionally exclusive indexing here, fixed below
     for (int i = 0; i < tid; i++) sum += sums[i];
+    glob_pfxs[tid] = sum + sums[tid];
 
     __syncthreads();
 
-    sums[tid] = glob_pfxs[tid] = sum;
-    idx = tid;
+    sums[tid] = sum;
+    idx = tid;
     for (int i = 0; i < ncondensed; i++) {
         int c = condensed[idx];
         condensed[idx] = sum;
@@ -336,7 +340,8 @@ void radix_sort(
     const int *offsets,
     const int *pfxs,
     const int *locals,
-    const int lo_bit
+    const int lo_bit,
+    const int ignore_max
 ) {
     const int tid = threadIdx.x;
     const int blk_offset = GRPSZ * blockIdx.x;
@@ -347,8 +352,12 @@ void radix_sort(
 
     if (tid < RDXSZ) shr_offs[tid] = locals[pfx_i];
     __syncthreads();
+    if (ignore_max)
+        for (int i = tid; i < GRPSZ; i += BLKSZ) defer[i] = -1;
+
     for (int i = tid; i < GRPSZ; i += BLKSZ) {
         int key = keys[i+blk_offset];
+        if (ignore_max && key == -1) continue;
         int radix;
         get_radix(radix, key, lo_bit);
         int offset = offsets[i+blk_offset] + shr_offs[radix];
@@ -363,6 +372,7 @@ void radix_sort(
     #pragma unroll
     for (int j = 0; j < GRP_BLK_FACTOR; j++) {
         int key = defer[i];
+        if (ignore_max && key == -1) continue;
         int radix;
         get_radix(radix, key, lo_bit);
         int offset = shr_offs[radix] + i;
@@ -430,12 +440,17 @@ class Sorter(object):
                          RuntimeWarning, stacklevel=3)
             self.warn_issued = True
 
-    def sort(self, dst, src, size, lo_bit=0,
+    def sort(self, dst, src, size, lo_bit=0, ignore_max=False,
             prev_lo_bit=None, prev_bits=None, stream=None):
         """
         Sort 'src' by the bits from lo_bit+radix_bits to lo_bit, where 0 is
         the LSB. Store the result in 'dst'.
 
+        If 'ignore_max' is True, any key with the value 0xffffffff will be
+        efficiently discarded. The number of valid results in the final array
+        can be determined by examining the last item in the device array
+        pointed to by this class's 'dglobal' property.
+
        GIANT ENORMOUS WARNING. The single pass sort is not quite stable, even
        with the multi-pass repair kernel. This means that multi-pass sort is bugged.
        Don't use it unless your application can handle non-monotonic
@@ -452,7 +467,8 @@ class Sorter(object):
 
         grids = size / self.group_size
         self.prefix_scan(self.doffsets, self.dpfxs, src, np.int32(lo_bit),
-                         block=(512, 1, 1), grid=(grids, 1), stream=stream)
+                         np.int32(ignore_max), block=(512, 1, 1),
+                         grid=(grids, 1), stream=stream)
 
         # Intentionally ignore prev_bits=0
         if prev_bits:
@@ -483,8 +499,8 @@ class Sorter(object):
         self.prefix_sum_distribute(self.dpfxs, self.dcond, ngrps, grpwidth,
                 block=(self.radix_size, 1, 1), grid=(self.ncond, 1),
                 stream=stream)
-        self.radix_sort(dst, src,
-                self.doffsets, self.dpfxs, self.dlocals, np.int32(lo_bit),
+        self.radix_sort(dst, src, self.doffsets, self.dpfxs, self.dlocals,
+                np.int32(lo_bit), np.int32(ignore_max),
                 block=(self.group_size / 8, 1, 1), grid=(grids, 1), stream=stream)
 
     def multisort(self, scratch_a, scratch_b, src, size, lo_bit=0,
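
Usage note (illustrative, not part of the patch): the sketch below shows one way the new ignore_max path could be driven from the host. It assumes an already-constructed Sorter instance bound to the name `sorter`, an active PyCUDA context, and a key count that is a multiple of sorter.group_size; the buffer names and the guess that 'dglobal' holds sorter.radix_size 32-bit entries are assumptions made for the example, not guarantees of this diff.

    import numpy as np
    import pycuda.driver as cuda

    # Build some test keys; every seventh key gets the 0xffffffff sentinel
    # that ignore_max discards.
    count = 1 << 20                      # assumed multiple of sorter.group_size
    keys = np.random.randint(0, 1 << 16, count).astype(np.uint32)
    keys[::7] = 0xffffffff

    d_src = cuda.to_device(keys)         # device copy of the keys
    d_dst = cuda.mem_alloc(keys.nbytes)  # destination buffer of the same size

    # Sort by the bits from lo_bit+radix_bits down to lo_bit, discarding the
    # sentinel keys (per the patched docstring).
    sorter.sort(d_dst, d_src, count, lo_bit=0, ignore_max=True)

    # The number of surviving keys is the last entry of the device array
    # behind 'dglobal'; its length is assumed here to be sorter.radix_size.
    glob = cuda.from_device(sorter.dglobal, (sorter.radix_size,), np.uint32)
    n_valid = int(glob[-1])

The kernel-side changes keep this cheap: prefix_scan simply skips counting keys equal to 0xffffffff, and radix_sort pre-fills its defer buffer with the sentinel and skips those keys in both its staging and scatter loops, so they never reach the output.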