mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 11:40:04 -05:00
Add -1-skipping to sort.
This commit is contained in:
parent
54f411878b
commit
05e1d08681
@ -35,7 +35,8 @@ void prefix_scan(
|
||||
int *offsets,
|
||||
int *pfxs,
|
||||
const unsigned int *keys,
|
||||
const int lo_bit
|
||||
const int lo_bit,
|
||||
const int ignore_max
|
||||
) {
|
||||
const int tid = threadIdx.x;
|
||||
__shared__ int shr_pfxs[RDXSZ];
|
||||
@ -56,9 +57,11 @@ void prefix_scan(
|
||||
// TODO: load 2 at once, compute, use a BFI to pack the two offsets
|
||||
// into an int to halve storage / bandwidth
|
||||
int key = keys[idx];
|
||||
int radix;
|
||||
get_radix(radix, key, lo_bit);
|
||||
offsets[idx] = atomicAdd(shr_pfxs + radix, 1);
|
||||
if (!ignore_max || key != -1) {
|
||||
int radix;
|
||||
get_radix(radix, key, lo_bit);
|
||||
offsets[idx] = atomicAdd(shr_pfxs + radix, 1);
|
||||
}
|
||||
idx += BLKSZ;
|
||||
}
|
||||
|
||||
@ -272,13 +275,14 @@ void prefix_sum_inner(
|
||||
__syncthreads();
|
||||
sum = 0;
|
||||
|
||||
// Intentionally exclusive indexing here
|
||||
// Intentionally exclusive indexing here, fixed below
|
||||
for (int i = 0; i < tid; i++) sum += sums[i];
|
||||
glob_pfxs[tid] = sum + sums[tid];
|
||||
__syncthreads();
|
||||
|
||||
sums[tid] = glob_pfxs[tid] = sum;
|
||||
idx = tid;
|
||||
sums[tid] = sum;
|
||||
|
||||
idx = tid;
|
||||
for (int i = 0; i < ncondensed; i++) {
|
||||
int c = condensed[idx];
|
||||
condensed[idx] = sum;
|
||||
@ -336,7 +340,8 @@ void radix_sort(
|
||||
const int *offsets,
|
||||
const int *pfxs,
|
||||
const int *locals,
|
||||
const int lo_bit
|
||||
const int lo_bit,
|
||||
const int ignore_max
|
||||
) {
|
||||
const int tid = threadIdx.x;
|
||||
const int blk_offset = GRPSZ * blockIdx.x;
|
||||
@ -347,8 +352,12 @@ void radix_sort(
|
||||
if (tid < RDXSZ) shr_offs[tid] = locals[pfx_i];
|
||||
__syncthreads();
|
||||
|
||||
if (ignore_max)
|
||||
for (int i = tid; i < GRPSZ; i += BLKSZ) defer[i] = -1;
|
||||
|
||||
for (int i = tid; i < GRPSZ; i += BLKSZ) {
|
||||
int key = keys[i+blk_offset];
|
||||
if (ignore_max && key == -1) continue;
|
||||
int radix;
|
||||
get_radix(radix, key, lo_bit);
|
||||
int offset = offsets[i+blk_offset] + shr_offs[radix];
|
||||
@ -363,6 +372,7 @@ void radix_sort(
|
||||
#pragma unroll
|
||||
for (int j = 0; j < GRP_BLK_FACTOR; j++) {
|
||||
int key = defer[i];
|
||||
if (ignore_max && key == -1) continue;
|
||||
int radix;
|
||||
get_radix(radix, key, lo_bit);
|
||||
int offset = shr_offs[radix] + i;
|
||||
@ -430,12 +440,17 @@ class Sorter(object):
|
||||
RuntimeWarning, stacklevel=3)
|
||||
self.warn_issued = True
|
||||
|
||||
def sort(self, dst, src, size, lo_bit=0,
|
||||
def sort(self, dst, src, size, lo_bit=0, ignore_max=False,
|
||||
prev_lo_bit=None, prev_bits=None, stream=None):
|
||||
"""
|
||||
Sort 'src' by the bits from lo_bit+radix_bits to lo_bit, where 0 is
|
||||
the LSB. Store the result in 'dst'.
|
||||
|
||||
If 'ignore_max' is True, any key with the value 0xffffffff will be
|
||||
efficiently discarded. The number of valid results in the final array
|
||||
can be determined by examining the last item in the device array
|
||||
pointed to by this class's 'dglobal' property.
|
||||
|
||||
GIANT ENORMOUS WARNING. The single pass sort is not quite stable, even
|
||||
with the multi-pass repair kernel. This means that multi-pass sort is
|
||||
bugged. Don't use it unless your application can handle non-monotonic
|
||||
@ -452,7 +467,8 @@ class Sorter(object):
|
||||
grids = size / self.group_size
|
||||
|
||||
self.prefix_scan(self.doffsets, self.dpfxs, src, np.int32(lo_bit),
|
||||
block=(512, 1, 1), grid=(grids, 1), stream=stream)
|
||||
np.int32(ignore_max), block=(512, 1, 1),
|
||||
grid=(grids, 1), stream=stream)
|
||||
|
||||
# Intentionally ignore prev_bits=0
|
||||
if prev_bits:
|
||||
@ -483,8 +499,8 @@ class Sorter(object):
|
||||
self.prefix_sum_distribute(self.dpfxs, self.dcond, ngrps, grpwidth,
|
||||
block=(self.radix_size, 1, 1), grid=(self.ncond, 1), stream=stream)
|
||||
|
||||
self.radix_sort(dst, src,
|
||||
self.doffsets, self.dpfxs, self.dlocals, np.int32(lo_bit),
|
||||
self.radix_sort(dst, src, self.doffsets, self.dpfxs, self.dlocals,
|
||||
np.int32(lo_bit), np.int32(ignore_max),
|
||||
block=(self.group_size / 8, 1, 1), grid=(grids, 1), stream=stream)
|
||||
|
||||
def multisort(self, scratch_a, scratch_b, src, size, lo_bit=0,
|
||||
|
Loading…
Reference in New Issue
Block a user