mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 11:40:04 -05:00
Add -1-skipping to sort.
This commit is contained in:
parent
54f411878b
commit
05e1d08681
@ -35,7 +35,8 @@ void prefix_scan(
|
|||||||
int *offsets,
|
int *offsets,
|
||||||
int *pfxs,
|
int *pfxs,
|
||||||
const unsigned int *keys,
|
const unsigned int *keys,
|
||||||
const int lo_bit
|
const int lo_bit,
|
||||||
|
const int ignore_max
|
||||||
) {
|
) {
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
__shared__ int shr_pfxs[RDXSZ];
|
__shared__ int shr_pfxs[RDXSZ];
|
||||||
@ -56,9 +57,11 @@ void prefix_scan(
|
|||||||
// TODO: load 2 at once, compute, use a BFI to pack the two offsets
|
// TODO: load 2 at once, compute, use a BFI to pack the two offsets
|
||||||
// into an int to halve storage / bandwidth
|
// into an int to halve storage / bandwidth
|
||||||
int key = keys[idx];
|
int key = keys[idx];
|
||||||
int radix;
|
if (!ignore_max || key != -1) {
|
||||||
get_radix(radix, key, lo_bit);
|
int radix;
|
||||||
offsets[idx] = atomicAdd(shr_pfxs + radix, 1);
|
get_radix(radix, key, lo_bit);
|
||||||
|
offsets[idx] = atomicAdd(shr_pfxs + radix, 1);
|
||||||
|
}
|
||||||
idx += BLKSZ;
|
idx += BLKSZ;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -272,13 +275,14 @@ void prefix_sum_inner(
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
sum = 0;
|
sum = 0;
|
||||||
|
|
||||||
// Intentionally exclusive indexing here
|
// Intentionally exclusive indexing here, fixed below
|
||||||
for (int i = 0; i < tid; i++) sum += sums[i];
|
for (int i = 0; i < tid; i++) sum += sums[i];
|
||||||
|
glob_pfxs[tid] = sum + sums[tid];
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
sums[tid] = glob_pfxs[tid] = sum;
|
sums[tid] = sum;
|
||||||
idx = tid;
|
|
||||||
|
|
||||||
|
idx = tid;
|
||||||
for (int i = 0; i < ncondensed; i++) {
|
for (int i = 0; i < ncondensed; i++) {
|
||||||
int c = condensed[idx];
|
int c = condensed[idx];
|
||||||
condensed[idx] = sum;
|
condensed[idx] = sum;
|
||||||
@ -336,7 +340,8 @@ void radix_sort(
|
|||||||
const int *offsets,
|
const int *offsets,
|
||||||
const int *pfxs,
|
const int *pfxs,
|
||||||
const int *locals,
|
const int *locals,
|
||||||
const int lo_bit
|
const int lo_bit,
|
||||||
|
const int ignore_max
|
||||||
) {
|
) {
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
const int blk_offset = GRPSZ * blockIdx.x;
|
const int blk_offset = GRPSZ * blockIdx.x;
|
||||||
@ -347,8 +352,12 @@ void radix_sort(
|
|||||||
if (tid < RDXSZ) shr_offs[tid] = locals[pfx_i];
|
if (tid < RDXSZ) shr_offs[tid] = locals[pfx_i];
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
if (ignore_max)
|
||||||
|
for (int i = tid; i < GRPSZ; i += BLKSZ) defer[i] = -1;
|
||||||
|
|
||||||
for (int i = tid; i < GRPSZ; i += BLKSZ) {
|
for (int i = tid; i < GRPSZ; i += BLKSZ) {
|
||||||
int key = keys[i+blk_offset];
|
int key = keys[i+blk_offset];
|
||||||
|
if (ignore_max && key == -1) continue;
|
||||||
int radix;
|
int radix;
|
||||||
get_radix(radix, key, lo_bit);
|
get_radix(radix, key, lo_bit);
|
||||||
int offset = offsets[i+blk_offset] + shr_offs[radix];
|
int offset = offsets[i+blk_offset] + shr_offs[radix];
|
||||||
@ -363,6 +372,7 @@ void radix_sort(
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < GRP_BLK_FACTOR; j++) {
|
for (int j = 0; j < GRP_BLK_FACTOR; j++) {
|
||||||
int key = defer[i];
|
int key = defer[i];
|
||||||
|
if (ignore_max && key == -1) continue;
|
||||||
int radix;
|
int radix;
|
||||||
get_radix(radix, key, lo_bit);
|
get_radix(radix, key, lo_bit);
|
||||||
int offset = shr_offs[radix] + i;
|
int offset = shr_offs[radix] + i;
|
||||||
@ -430,12 +440,17 @@ class Sorter(object):
|
|||||||
RuntimeWarning, stacklevel=3)
|
RuntimeWarning, stacklevel=3)
|
||||||
self.warn_issued = True
|
self.warn_issued = True
|
||||||
|
|
||||||
def sort(self, dst, src, size, lo_bit=0,
|
def sort(self, dst, src, size, lo_bit=0, ignore_max=False,
|
||||||
prev_lo_bit=None, prev_bits=None, stream=None):
|
prev_lo_bit=None, prev_bits=None, stream=None):
|
||||||
"""
|
"""
|
||||||
Sort 'src' by the bits from lo_bit+radix_bits to lo_bit, where 0 is
|
Sort 'src' by the bits from lo_bit+radix_bits to lo_bit, where 0 is
|
||||||
the LSB. Store the result in 'dst'.
|
the LSB. Store the result in 'dst'.
|
||||||
|
|
||||||
|
If 'ignore_max' is True, any key with the value 0xffffffff will be
|
||||||
|
effeciently discarded. The number of valid results in the final array
|
||||||
|
can be determined by examining the last item in the device array
|
||||||
|
pointed to by this class's 'dglobal' property.
|
||||||
|
|
||||||
GIANT ENORMOUS WARNING. The single pass sort is not quite stable, even
|
GIANT ENORMOUS WARNING. The single pass sort is not quite stable, even
|
||||||
with the multi-pass repair kernel. This means that multi-pass sort is
|
with the multi-pass repair kernel. This means that multi-pass sort is
|
||||||
bugged. Don't use it unless your application can handle non-monotonic
|
bugged. Don't use it unless your application can handle non-monotonic
|
||||||
@ -452,7 +467,8 @@ class Sorter(object):
|
|||||||
grids = size / self.group_size
|
grids = size / self.group_size
|
||||||
|
|
||||||
self.prefix_scan(self.doffsets, self.dpfxs, src, np.int32(lo_bit),
|
self.prefix_scan(self.doffsets, self.dpfxs, src, np.int32(lo_bit),
|
||||||
block=(512, 1, 1), grid=(grids, 1), stream=stream)
|
np.int32(ignore_max), block=(512, 1, 1),
|
||||||
|
grid=(grids, 1), stream=stream)
|
||||||
|
|
||||||
# Intentionally ignore prev_bits=0
|
# Intentionally ignore prev_bits=0
|
||||||
if prev_bits:
|
if prev_bits:
|
||||||
@ -483,8 +499,8 @@ class Sorter(object):
|
|||||||
self.prefix_sum_distribute(self.dpfxs, self.dcond, ngrps, grpwidth,
|
self.prefix_sum_distribute(self.dpfxs, self.dcond, ngrps, grpwidth,
|
||||||
block=(self.radix_size, 1, 1), grid=(self.ncond, 1), stream=stream)
|
block=(self.radix_size, 1, 1), grid=(self.ncond, 1), stream=stream)
|
||||||
|
|
||||||
self.radix_sort(dst, src,
|
self.radix_sort(dst, src, self.doffsets, self.dpfxs, self.dlocals,
|
||||||
self.doffsets, self.dpfxs, self.dlocals, np.int32(lo_bit),
|
np.int32(lo_bit), np.int32(ignore_max),
|
||||||
block=(self.group_size / 8, 1, 1), grid=(grids, 1), stream=stream)
|
block=(self.group_size / 8, 1, 1), grid=(grids, 1), stream=stream)
|
||||||
|
|
||||||
def multisort(self, scratch_a, scratch_b, src, size, lo_bit=0,
|
def multisort(self, scratch_a, scratch_b, src, size, lo_bit=0,
|
||||||
|
Loading…
Reference in New Issue
Block a user