From 54f411878bc01cd3f11dc7407435f61b4fdf3e92 Mon Sep 17 00:00:00 2001
From: Steven Robertson
Date: Thu, 10 Nov 2011 10:49:35 -0500
Subject: [PATCH] Experiments with multi-pass sort (still has bugs)

---
 cuburn/code/sort.py | 267 +++++++++++++++++++++++++++++++++++++-------
 1 file changed, 225 insertions(+), 42 deletions(-)

diff --git a/cuburn/code/sort.py b/cuburn/code/sort.py
index bd7854f..b0536f2 100644
--- a/cuburn/code/sort.py
+++ b/cuburn/code/sort.py
@@ -1,4 +1,6 @@
+import warnings
+
 import numpy as np
 import pycuda.driver as cuda
 import pycuda.compiler
@@ -18,7 +20,16 @@ _CODE = tempita.Template(r"""
 #define get_radix(r, k, l) \
     asm("bfe.u32 %0, %1, %2, {{radix_bits}};" : "=r"(r) : "r"(k), "r"(l))

-// TODO: experiment with different block / group sizes
+// This kernel conducts a prefix scan of the 'keys' array. As each radix is
+// read, it is immediately added to the corresponding accumulator, so the
+// resulting value is unique among keys in this block with the same radix. We
+// will use this offset later (along with the more typical prefix sums) to
+// insert the keys into a shared memory array by radix, so that they can be
+// written in fewer transactions to global memory (which is important given
+// the larger radix sizes used here).
+//
+// Note that the indices generated here are unique but not necessarily
+// monotonic, so using them directly leads to a mildly unstable sort.
 __global__
 void prefix_scan(
     int *offsets,
@@ -40,10 +51,10 @@ void prefix_scan(
     int idx = tid + GRPSZ * blockIdx.x;

+#pragma unroll
     for (int i = 0; i < GRP_BLK_FACTOR; i++) {
         // TODO: load 2 at once, compute, use a BFI to pack the two offsets
         // into an int to halve storage / bandwidth
-        // TODO: separate or integrated loop vars? unrolling?
         int key = keys[idx];
         int radix;
         get_radix(radix, key, lo_bit);
@@ -60,7 +71,108 @@ void prefix_scan(
     pfxs[tid + {{i}} + RDXSZ * blockIdx.x] = shr_pfxs[tid + {{i}}];
 {{endfor}}
 {{endif}}
 }

+// Populate 'indices' so that indices[i] holds the largest index 'x' at which
+// the masked previous-pass bits of keys[x] are smaller than the radix value
+// 'i' (that is, smaller than i << prev_lo_bit). If no key qualifies,
+// indices[i] is 0.
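+//
+// For example (a sketch, with sizes kept tiny for illustration): with
+// prev_lo_bit=8 and prev_bits=2, so that the mask covers 0x300, eight keys
+// whose masked values are 0x000 0x000 0x100 0x100 0x100 0x200 0x300 0x300
+// yield indices = {0, 1, 4, 5}; e.g. indices[2] == 4 because keys[4] is the
+// last key whose masked value is below 2 << 8.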
+__global__
+void binary_search(
+    int *indices,
+    const unsigned int *keys,
+    const int prev_lo_bit,
+    const int mask,
+    const int length
+) {
+    int tid_full = blockIdx.x * blockDim.x + threadIdx.x;
+    int target = tid_full << prev_lo_bit;
+
+    int lo = 0;
+
+    // Length must be a power of two! (Guaranteed by runtime.)
+    for (int i = length >> 1; i > 0; i >>= 1) {
+        int mid = lo | i;
+        if (target > (keys[mid] & mask)) lo = mid;
+    }
+
+    indices[tid_full] = lo;
+}
+
+// When performing a sort by repeatedly applying smaller sorts, the
+// non-stable nature of the prefix scan done above will cause errors in the
+// output. These errors only affect the correctness of the sort inside those
+// groups which cover a transition in the radix of the previous sort passes.
+// We re-run those groups with a more careful algorithm here. This doesn't
+// make the sort stable in general, but it's enough to make multi-pass sorts
+// correct. (Or it should be, although it seems there's a bug either in this
+// code or in my head.)
+__global__
+void prefix_scan_repair(
+    int *offsets,
+    int *pfxs,
+    const unsigned int *keys,
+    const unsigned int *trans_points,
+    const int lo_bit,
+    const int prev_lo_bit,
+    const int mask
+) {
+    const int tid = threadIdx.x;
+    const int blkid = blockIdx.y * gridDim.x + blockIdx.x;
+    __shared__ int shr_pfxs[RDXSZ];
+    __shared__ int blk_starts[GRP_BLK_FACTOR];
+    __shared__ int blk_transitions[GRP_BLK_FACTOR];
+
+    // Never need to repair the start of the array.
+    if (blkid == 0) return;
+
+    // Get the start index of the group to repair.
+    int grp_start = trans_points[blkid] & ~(GRPSZ - 1);
+
+    // If the largest prev_radix in this group is not equal to our blkid, it
+    // means either that another thread block will also attend to this group
+    // (in which case we cede to it), or that the transition point happens to
+    // be on a group boundary. In either case, we should bail.
+    //
+    // Note that prev_* holds a masked but not shifted value.
+    int prev_max = keys[grp_start + (GRPSZ - 1)] & mask;
+    if (prev_max != (blkid << prev_lo_bit)) return;
+
+    int prev_incr = 1 << prev_lo_bit;
+
+    // For each block of keys that this thread block will analyze, determine
+    // how many transitions occur within that block.
+    if (tid < GRP_BLK_FACTOR) {
+        int prev_lo = keys[grp_start + tid * BLKSZ] & mask;
+        int prev_hi = keys[grp_start + tid * BLKSZ + BLKSZ - 1] & mask;
+        blk_starts[tid] = prev_lo;
+        blk_transitions[tid] = (prev_hi - prev_lo) >> prev_lo_bit;
+    }
+
+{{if radix_size <= 512}}
+    if (tid < RDXSZ) shr_pfxs[tid] = 0;
+{{else}}
+{{for i in range(0, radix_size, 512)}}
+    shr_pfxs[tid+{{i}}] = 0;
+{{endfor}}
+{{endif}}
+    __syncthreads();
+
+    int idx = grp_start + tid;
+    for (int i = 0; i < GRP_BLK_FACTOR; i++) {
+        int key = keys[idx];
+        int prev_radix = blk_starts[i];
+        int this_prev_radix = key & mask;
+        int radix;
+        get_radix(radix, key, lo_bit);
+        if (this_prev_radix == prev_radix)
+            offsets[idx] = atomicAdd(shr_pfxs + radix, 1);
+        for (int j = 0; j < blk_transitions[i]; j++) {
+            __syncthreads();
+            prev_radix += prev_incr;
+            if (this_prev_radix == prev_radix)
+                offsets[idx] = atomicAdd(shr_pfxs + radix, 1);
+        }
+        idx += BLKSZ;
+    }
+}

 // Calculate group-local exclusive prefix sums (the number of keys in the
@@ -265,6 +377,8 @@ class Sorter(object):
     group_size = 8192
     radix_bits = 8

+    warn_issued = False
+
     @classmethod
     def init_mod(cls):
         if cls.__dict__.get('mod') is None:
@@ -276,20 +390,31 @@ class Sorter(object):
             with open('/tmp/sort_kern.cubin', 'wb') as fp:
                 fp.write(cubin)
             for name in ['prefix_scan', 'prefix_sum_condense',
-                         'prefix_sum_inner', 'prefix_sum_distribute']:
+                         'prefix_sum_inner', 'prefix_sum_distribute',
+                         'binary_search', 'prefix_scan_repair']:
                 f = cls.mod.get_function(name)
                 setattr(cls, name, f)
                 f.set_cache_config(cuda.func_cache.PREFER_L1)
             cls.calc_local_pfxs = cls.mod.get_function('calc_local_pfxs')
             cls.radix_sort = cls.mod.get_function('radix_sort')

-    def __init__(self, max_size):
+    def __init__(self, max_size, offsets=None):
+        """
+        Create a sorter. The sorter will hold on to internal resources for
+        as long as it is alive, including an 'offsets' array of 4*max_size
+        bytes. To share this cost, you may pass in an array of at least that
+        size to __init__ (for instance, to share it across sorters of
+        different bit-widths in a multi-pass sort).
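+
+        A sketch of that sharing (names and sizes here are illustrative
+        only):
+
+            buf = cuda.mem_alloc(4 * (1 << 20))
+            a = Sorter(1 << 20, offsets=buf)
+            b = Sorter(1 << 19, offsets=buf)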
+        """
         self.init_mod()
         self.max_size = max_size
         assert max_size % self.group_size == 0
         max_grids = max_size / self.group_size

-        self.doffsets = cuda.mem_alloc(self.max_size * 4)
+        if offsets is None:
+            self.doffsets = cuda.mem_alloc(self.max_size * 4)
+        else:
+            self.doffsets = offsets
         self.dpfxs = cuda.mem_alloc(max_grids * self.radix_size * 4)
         self.dlocals = cuda.mem_alloc(max_grids * self.radix_size * 4)
@@ -299,15 +424,28 @@ class Sorter(object):
         self.dcond = cuda.mem_alloc(self.radix_size * self.ncond * 4)
         self.dglobal = cuda.mem_alloc(self.radix_size * 4)

-    def sort(self, dst, src, size, lo_bit=0, stream=None):
+    def warn(self):
+        if not self.warn_issued:
+            warnings.warn('You know multi-pass is broken, right?',
+                          RuntimeWarning, stacklevel=3)
+            self.warn_issued = True
+
+    def sort(self, dst, src, size, lo_bit=0,
+             prev_lo_bit=None, prev_bits=None, stream=None):
         """
         Sort 'src' by the bits from lo_bit+radix_bits to lo_bit, where 0 is
         the LSB. Store the result in 'dst'.

-        Note that this is *not* a stable sort! It won't jumble your data
-        haphazardly, but one- or two-position swaps are very common. This
-        will hopefully be resolved soon, but until then, it is unsuitable
-        for implementing larger sorts from multiple passes of this sort.
+        GIANT ENORMOUS WARNING. The single pass sort is not quite stable,
+        even with the multi-pass repair kernel. This means that multi-pass
+        sort is bugged. Don't use it unless your application can handle
+        non-monotonic values in your "sorted" array.
+
+        To perform a multi-pass sort, pass 'prev_lo_bit' and 'prev_bits',
+        indicating the lowest bit considered across the entire sort and the
+        number of bits previously sorted. This uses 2^(prev_bits+2) bytes of
+        memory and performs about group_size*2^(prev_bits+1) extra
+        operations, so it's useful up to three passes but probably not four.
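+
+        As a sketch, a 16-bit sort with radix_bits=8 would take two passes
+        over caller-allocated device buffers ('tmp' and 'dst' here are
+        illustrative names):
+
+            sorter.sort(tmp, src, size, lo_bit=0)
+            sorter.sort(dst, tmp, size, lo_bit=8, prev_lo_bit=0, prev_bits=8)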
""" assert size <= self.max_size and size % self.group_size == 0 @@ -316,6 +454,22 @@ class Sorter(object): self.prefix_scan(self.doffsets, self.dpfxs, src, np.int32(lo_bit), block=(512, 1, 1), grid=(grids, 1), stream=stream) + # Intentionally ignore prev_bits=0 + if prev_bits: + self.warn() + assert not (size & (size - 1)), \ + 'Size must be a power of two, due to my enduring laziness' + didx = cuda.mem_alloc(4 << prev_bits) + mask = np.uint32(((1 << prev_bits) - 1) << prev_lo_bit) + self.binary_search(didx, src, np.int32(prev_lo_bit), + mask, np.int32(size), + block=(128, 1, 1), grid=((1< 1: + self.warn() + for i in range(rounds): + b = i * self.radix_bits + self.sort(scratch_a, src, size, lo_bit + b, lo_bit, b, stream) + scratch_a, scratch_b, src = scratch_b, scratch_a, scratch_a + return src + @classmethod def test(cls, count, correctness=False): keys = np.uint32(np.random.randint(0, 1<> 13)) + for i in nz: + print i, out[i-1:i+2], sort[i-1:i+2] + assert False, 'Oh no' - print '\n\n%d bit sort' % cls.radix_bits - print 'Testing speed' - test_stub(0) - if '-s' not in sys.argv: - print '\nTesting correctness' - out = cuda.from_device(dout, (count,), np.uint32) - sort = np.sort(keys) - if np.all(out == sort): - print 'Correct' - else: - assert False, 'Oh no' - - print '\nTesting speed at shifts' - for b in range(cls.radix_bits - 1): - print 'Performance with %d sig bits' % (cls.radix_bits - b) + for b in range(cls.radix_bits - 3): + print '%2d (%2d sig bits),\t' % (cls.radix_bits, cls.radix_bits - b), test_stub(b) + if not correctness: + for r in range(2,3): + keys[:] = np.uint32( + np.random.randint(0, 1<<(cls.radix_bits*r), count)) + cuda.memcpy_htod(dkeys, keys) + print '%2d x %d,\t\t\t' % (cls.radix_bits, r), + test_stub(0, rounds=r) + print + if __name__ == "__main__": import sys import pycuda.autoinit - np.set_printoptions(precision=5, edgeitems=20, - linewidth=100, threshold=90) - count = 1 << 26 + np.set_printoptions(precision=5, edgeitems=200, + linewidth=95, threshold=9000) + count = 1 << 25 np.random.seed(42) - correct = '-s' not in sys.argv + correct = '-c' in sys.argv for g in (8192, 4096): - print '\n\n== GROUP SIZE %d ==' % g + print '\n\n== GROUP SIZE %d ==,\t msec,\tMK/s,\tMK/s norm' % g Sorter.group_size = g for b in [7,8,9,10]: if g == 4096 and b == 10: continue