mirror of
				https://github.com/stevenrobertson/cuburn.git
				synced 2025-10-31 17:30:46 -04:00 
			
		
		
		
	Add -1-skipping to sort.
This commit is contained in:
		| @ -35,7 +35,8 @@ void prefix_scan( | |||||||
|         int *offsets, |         int *offsets, | ||||||
|         int *pfxs, |         int *pfxs, | ||||||
|         const unsigned int *keys, |         const unsigned int *keys, | ||||||
|         const int lo_bit |         const int lo_bit, | ||||||
|  |         const int ignore_max | ||||||
| ) { | ) { | ||||||
|     const int tid = threadIdx.x; |     const int tid = threadIdx.x; | ||||||
|     __shared__ int shr_pfxs[RDXSZ]; |     __shared__ int shr_pfxs[RDXSZ]; | ||||||
| @ -56,9 +57,11 @@ void prefix_scan( | |||||||
|         // TODO: load 2 at once, compute, use a BFI to pack the two offsets |         // TODO: load 2 at once, compute, use a BFI to pack the two offsets | ||||||
|         //       into an int to halve storage / bandwidth |         //       into an int to halve storage / bandwidth | ||||||
|         int key = keys[idx]; |         int key = keys[idx]; | ||||||
|  |         if (!ignore_max || key != -1) { | ||||||
|             int radix; |             int radix; | ||||||
|             get_radix(radix, key, lo_bit); |             get_radix(radix, key, lo_bit); | ||||||
|                 offsets[idx] = atomicAdd(shr_pfxs + radix, 1); |                 offsets[idx] = atomicAdd(shr_pfxs + radix, 1); | ||||||
|  |         } | ||||||
|         idx += BLKSZ; |         idx += BLKSZ; | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @ -272,13 +275,14 @@ void prefix_sum_inner( | |||||||
|     __syncthreads(); |     __syncthreads(); | ||||||
|     sum = 0; |     sum = 0; | ||||||
|  |  | ||||||
|     // Intentionally exclusive indexing here |     // Intentionally exclusive indexing here, fixed below | ||||||
|     for (int i = 0; i < tid; i++) sum += sums[i]; |     for (int i = 0; i < tid; i++) sum += sums[i]; | ||||||
|  |     glob_pfxs[tid] = sum + sums[tid]; | ||||||
|     __syncthreads(); |     __syncthreads(); | ||||||
|  |  | ||||||
|     sums[tid] = glob_pfxs[tid] = sum; |     sums[tid] = sum; | ||||||
|     idx = tid; |  | ||||||
|  |  | ||||||
|  |     idx = tid; | ||||||
|     for (int i = 0; i < ncondensed; i++) { |     for (int i = 0; i < ncondensed; i++) { | ||||||
|         int c = condensed[idx]; |         int c = condensed[idx]; | ||||||
|         condensed[idx] = sum; |         condensed[idx] = sum; | ||||||
| @ -336,7 +340,8 @@ void radix_sort( | |||||||
|         const int *offsets, |         const int *offsets, | ||||||
|         const int *pfxs, |         const int *pfxs, | ||||||
|         const int *locals, |         const int *locals, | ||||||
|         const int lo_bit |         const int lo_bit, | ||||||
|  |         const int ignore_max | ||||||
| ) { | ) { | ||||||
|     const int tid = threadIdx.x; |     const int tid = threadIdx.x; | ||||||
|     const int blk_offset = GRPSZ * blockIdx.x; |     const int blk_offset = GRPSZ * blockIdx.x; | ||||||
| @ -347,8 +352,12 @@ void radix_sort( | |||||||
|     if (tid < RDXSZ) shr_offs[tid] = locals[pfx_i]; |     if (tid < RDXSZ) shr_offs[tid] = locals[pfx_i]; | ||||||
|     __syncthreads(); |     __syncthreads(); | ||||||
|  |  | ||||||
|  |     if (ignore_max) | ||||||
|  |         for (int i = tid; i < GRPSZ; i += BLKSZ) defer[i] = -1; | ||||||
|  |  | ||||||
|     for (int i = tid; i < GRPSZ; i += BLKSZ) { |     for (int i = tid; i < GRPSZ; i += BLKSZ) { | ||||||
|         int key = keys[i+blk_offset]; |         int key = keys[i+blk_offset]; | ||||||
|  |         if (ignore_max && key == -1) continue; | ||||||
|         int radix; |         int radix; | ||||||
|         get_radix(radix, key, lo_bit); |         get_radix(radix, key, lo_bit); | ||||||
|         int offset = offsets[i+blk_offset] + shr_offs[radix]; |         int offset = offsets[i+blk_offset] + shr_offs[radix]; | ||||||
| @ -363,6 +372,7 @@ void radix_sort( | |||||||
| #pragma unroll | #pragma unroll | ||||||
|     for (int j = 0; j < GRP_BLK_FACTOR; j++) { |     for (int j = 0; j < GRP_BLK_FACTOR; j++) { | ||||||
|         int key = defer[i]; |         int key = defer[i]; | ||||||
|  |         if (ignore_max && key == -1) continue; | ||||||
|         int radix; |         int radix; | ||||||
|         get_radix(radix, key, lo_bit); |         get_radix(radix, key, lo_bit); | ||||||
|         int offset = shr_offs[radix] + i; |         int offset = shr_offs[radix] + i; | ||||||
| @ -430,12 +440,17 @@ class Sorter(object): | |||||||
|                           RuntimeWarning, stacklevel=3) |                           RuntimeWarning, stacklevel=3) | ||||||
|             self.warn_issued = True |             self.warn_issued = True | ||||||
|  |  | ||||||
|     def sort(self, dst, src, size, lo_bit=0, |     def sort(self, dst, src, size, lo_bit=0, ignore_max=False, | ||||||
|              prev_lo_bit=None, prev_bits=None, stream=None): |              prev_lo_bit=None, prev_bits=None, stream=None): | ||||||
|         """ |         """ | ||||||
|         Sort 'src' by the bits from lo_bit+radix_bits to lo_bit, where 0 is |         Sort 'src' by the bits from lo_bit+radix_bits to lo_bit, where 0 is | ||||||
|         the LSB. Store the result in 'dst'. |         the LSB. Store the result in 'dst'. | ||||||
|  |  | ||||||
|  |         If 'ignore_max' is True, any key with the value 0xffffffff will be | ||||||
|  |         efficiently discarded. The number of valid results in the final array | ||||||
|  |         can be determined by examining the last item in the device array | ||||||
|  |         pointed to by this class's 'dglobal' property. | ||||||
|  |  | ||||||
|         GIANT ENORMOUS WARNING. The single pass sort is not quite stable, even |         GIANT ENORMOUS WARNING. The single pass sort is not quite stable, even | ||||||
|         with the multi-pass repair kernel. This means that multi-pass sort is |         with the multi-pass repair kernel. This means that multi-pass sort is | ||||||
|         bugged. Don't use it unless your application can handle non-monotonic |         bugged. Don't use it unless your application can handle non-monotonic | ||||||
| @ -452,7 +467,8 @@ class Sorter(object): | |||||||
|         grids = size / self.group_size |         grids = size / self.group_size | ||||||
|  |  | ||||||
|         self.prefix_scan(self.doffsets, self.dpfxs, src, np.int32(lo_bit), |         self.prefix_scan(self.doffsets, self.dpfxs, src, np.int32(lo_bit), | ||||||
|             block=(512, 1, 1), grid=(grids, 1), stream=stream) |                          np.int32(ignore_max), block=(512, 1, 1), | ||||||
|  |                          grid=(grids, 1), stream=stream) | ||||||
|  |  | ||||||
|         # Intentionally ignore prev_bits=0 |         # Intentionally ignore prev_bits=0 | ||||||
|         if prev_bits: |         if prev_bits: | ||||||
| @ -483,8 +499,8 @@ class Sorter(object): | |||||||
|         self.prefix_sum_distribute(self.dpfxs, self.dcond, ngrps, grpwidth, |         self.prefix_sum_distribute(self.dpfxs, self.dcond, ngrps, grpwidth, | ||||||
|             block=(self.radix_size, 1, 1), grid=(self.ncond, 1), stream=stream) |             block=(self.radix_size, 1, 1), grid=(self.ncond, 1), stream=stream) | ||||||
|  |  | ||||||
|         self.radix_sort(dst, src, |         self.radix_sort(dst, src, self.doffsets, self.dpfxs, self.dlocals, | ||||||
|             self.doffsets, self.dpfxs, self.dlocals, np.int32(lo_bit), |             np.int32(lo_bit), np.int32(ignore_max), | ||||||
|             block=(self.group_size / 8, 1, 1), grid=(grids, 1), stream=stream) |             block=(self.group_size / 8, 1, 1), grid=(grids, 1), stream=stream) | ||||||
|  |  | ||||||
|     def multisort(self, scratch_a, scratch_b, src, size, lo_bit=0, |     def multisort(self, scratch_a, scratch_b, src, size, lo_bit=0, | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Steven Robertson
					Steven Robertson