mirror of
				https://github.com/stevenrobertson/cuburn.git
				synced 2025-11-03 18:00:55 -05:00 
			
		
		
		
	Add -1-skipping to sort.
This commit is contained in:
		@ -35,7 +35,8 @@ void prefix_scan(
 | 
			
		||||
        int *offsets,
 | 
			
		||||
        int *pfxs,
 | 
			
		||||
        const unsigned int *keys,
 | 
			
		||||
        const int lo_bit
 | 
			
		||||
        const int lo_bit,
 | 
			
		||||
        const int ignore_max
 | 
			
		||||
) {
 | 
			
		||||
    const int tid = threadIdx.x;
 | 
			
		||||
    __shared__ int shr_pfxs[RDXSZ];
 | 
			
		||||
@ -56,9 +57,11 @@ void prefix_scan(
 | 
			
		||||
        // TODO: load 2 at once, compute, use a BFI to pack the two offsets
 | 
			
		||||
        //       into an int to halve storage / bandwidth
 | 
			
		||||
        int key = keys[idx];
 | 
			
		||||
        int radix;
 | 
			
		||||
        get_radix(radix, key, lo_bit);
 | 
			
		||||
        offsets[idx] = atomicAdd(shr_pfxs + radix, 1);
 | 
			
		||||
        if (!ignore_max || key != -1) {
 | 
			
		||||
            int radix;
 | 
			
		||||
            get_radix(radix, key, lo_bit);
 | 
			
		||||
                offsets[idx] = atomicAdd(shr_pfxs + radix, 1);
 | 
			
		||||
        }
 | 
			
		||||
        idx += BLKSZ;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -272,13 +275,14 @@ void prefix_sum_inner(
 | 
			
		||||
    __syncthreads();
 | 
			
		||||
    sum = 0;
 | 
			
		||||
 | 
			
		||||
    // Intentionally exclusive indexing here
 | 
			
		||||
    // Intentionally exclusive indexing here, fixed below
 | 
			
		||||
    for (int i = 0; i < tid; i++) sum += sums[i];
 | 
			
		||||
    glob_pfxs[tid] = sum + sums[tid];
 | 
			
		||||
    __syncthreads();
 | 
			
		||||
 | 
			
		||||
    sums[tid] = glob_pfxs[tid] = sum;
 | 
			
		||||
    idx = tid;
 | 
			
		||||
    sums[tid] = sum;
 | 
			
		||||
 | 
			
		||||
    idx = tid;
 | 
			
		||||
    for (int i = 0; i < ncondensed; i++) {
 | 
			
		||||
        int c = condensed[idx];
 | 
			
		||||
        condensed[idx] = sum;
 | 
			
		||||
@ -336,7 +340,8 @@ void radix_sort(
 | 
			
		||||
        const int *offsets,
 | 
			
		||||
        const int *pfxs,
 | 
			
		||||
        const int *locals,
 | 
			
		||||
        const int lo_bit
 | 
			
		||||
        const int lo_bit,
 | 
			
		||||
        const int ignore_max
 | 
			
		||||
) {
 | 
			
		||||
    const int tid = threadIdx.x;
 | 
			
		||||
    const int blk_offset = GRPSZ * blockIdx.x;
 | 
			
		||||
@ -347,8 +352,12 @@ void radix_sort(
 | 
			
		||||
    if (tid < RDXSZ) shr_offs[tid] = locals[pfx_i];
 | 
			
		||||
    __syncthreads();
 | 
			
		||||
 | 
			
		||||
    if (ignore_max)
 | 
			
		||||
        for (int i = tid; i < GRPSZ; i += BLKSZ) defer[i] = -1;
 | 
			
		||||
 | 
			
		||||
    for (int i = tid; i < GRPSZ; i += BLKSZ) {
 | 
			
		||||
        int key = keys[i+blk_offset];
 | 
			
		||||
        if (ignore_max && key == -1) continue;
 | 
			
		||||
        int radix;
 | 
			
		||||
        get_radix(radix, key, lo_bit);
 | 
			
		||||
        int offset = offsets[i+blk_offset] + shr_offs[radix];
 | 
			
		||||
@ -363,6 +372,7 @@ void radix_sort(
 | 
			
		||||
#pragma unroll
 | 
			
		||||
    for (int j = 0; j < GRP_BLK_FACTOR; j++) {
 | 
			
		||||
        int key = defer[i];
 | 
			
		||||
        if (ignore_max && key == -1) continue;
 | 
			
		||||
        int radix;
 | 
			
		||||
        get_radix(radix, key, lo_bit);
 | 
			
		||||
        int offset = shr_offs[radix] + i;
 | 
			
		||||
@ -430,12 +440,17 @@ class Sorter(object):
 | 
			
		||||
                          RuntimeWarning, stacklevel=3)
 | 
			
		||||
            self.warn_issued = True
 | 
			
		||||
 | 
			
		||||
    def sort(self, dst, src, size, lo_bit=0,
 | 
			
		||||
    def sort(self, dst, src, size, lo_bit=0, ignore_max=False,
 | 
			
		||||
             prev_lo_bit=None, prev_bits=None, stream=None):
 | 
			
		||||
        """
 | 
			
		||||
        Sort 'src' by the bits from lo_bit+radix_bits to lo_bit, where 0 is
 | 
			
		||||
        the LSB. Store the result in 'dst'.
 | 
			
		||||
 | 
			
		||||
        If 'ignore_max' is True, any key with the value 0xffffffff will be
 | 
			
		||||
        effeciently discarded. The number of valid results in the final array
 | 
			
		||||
        can be determined by examining the last item in the device array
 | 
			
		||||
        pointed to by this class's 'dglobal' property.
 | 
			
		||||
 | 
			
		||||
        GIANT ENORMOUS WARNING. The single pass sort is not quite stable, even
 | 
			
		||||
        with the multi-pass repair kernel. This means that multi-pass sort is
 | 
			
		||||
        bugged. Don't use it unless your application can handle non-monotonic
 | 
			
		||||
@ -452,7 +467,8 @@ class Sorter(object):
 | 
			
		||||
        grids = size / self.group_size
 | 
			
		||||
 | 
			
		||||
        self.prefix_scan(self.doffsets, self.dpfxs, src, np.int32(lo_bit),
 | 
			
		||||
            block=(512, 1, 1), grid=(grids, 1), stream=stream)
 | 
			
		||||
                         np.int32(ignore_max), block=(512, 1, 1),
 | 
			
		||||
                         grid=(grids, 1), stream=stream)
 | 
			
		||||
 | 
			
		||||
        # Intentionally ignore prev_bits=0
 | 
			
		||||
        if prev_bits:
 | 
			
		||||
@ -483,8 +499,8 @@ class Sorter(object):
 | 
			
		||||
        self.prefix_sum_distribute(self.dpfxs, self.dcond, ngrps, grpwidth,
 | 
			
		||||
            block=(self.radix_size, 1, 1), grid=(self.ncond, 1), stream=stream)
 | 
			
		||||
 | 
			
		||||
        self.radix_sort(dst, src,
 | 
			
		||||
            self.doffsets, self.dpfxs, self.dlocals, np.int32(lo_bit),
 | 
			
		||||
        self.radix_sort(dst, src, self.doffsets, self.dpfxs, self.dlocals,
 | 
			
		||||
            np.int32(lo_bit), np.int32(ignore_max),
 | 
			
		||||
            block=(self.group_size / 8, 1, 1), grid=(grids, 1), stream=stream)
 | 
			
		||||
 | 
			
		||||
    def multisort(self, scratch_a, scratch_b, src, size, lo_bit=0,
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user