Mirror of https://github.com/stevenrobertson/cuburn.git

Commit: Experiments with multi-pass sort (still has bugs)
@@ -1,4 +1,6 @@
import warnings

import numpy as np
import pycuda.driver as cuda
import pycuda.compiler
@@ -18,7 +20,16 @@ _CODE = tempita.Template(r"""
#define get_radix(r, k, l) \
        asm("bfe.u32 %0, %1, %2, {{radix_bits}};" : "=r"(r) : "r"(k), "r"(l))

// TODO: experiment with different block / group sizes
// This kernel conducts a prefix scan of the 'keys' array. As each radix is
// read, it is immediately added to the corresponding accumulator. The
// resulting value is unique among keys in this block with the same radix. We
// will use this offset later (along with the more typical prefix sums) to
// insert the key into a shared memory array by radix, so that they can be
// written in fewer transactions to global memory (which is important given
// the larger radix sizes used here).
//
// Note that the indices generated here are unique but not necessarily
// monotonic, so using these directly leads to a mildly unstable sort.
__global__
void prefix_scan(
        int *offsets,
@@ -40,10 +51,10 @@ void prefix_scan(
    __syncthreads();
    int idx = tid + GRPSZ * blockIdx.x;

#pragma unroll
    for (int i = 0; i < GRP_BLK_FACTOR; i++) {
        // TODO: load 2 at once, compute, use a BFI to pack the two offsets
        //       into an int to halve storage / bandwidth
        // TODO: separate or integrated loop vars? unrolling?
        int key = keys[idx];
        int radix;
        get_radix(radix, key, lo_bit);
@@ -60,7 +71,108 @@ void prefix_scan(
    pfxs[tid + {{i}} + RDXSZ * blockIdx.x] = shr_pfxs[tid + {{i}}];
{{endfor}}
{{endif}}
}
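As a reading aid, here is a sequential CPU model (Python/NumPy; the name and signature are illustrative, not part of the module) of what one prefix_scan launch computes: a per-group radix histogram plus, for every key, an offset that is unique among same-radix keys in its group. The real kernel produces those offsets with atomicAdd, so they are unique but not in key order, which is the instability the comment above mentions.

import numpy as np

def prefix_scan_reference(keys, lo_bit, radix_bits=8, group_size=8192):
    # Illustrative CPU model of one prefix_scan launch, not the kernel itself.
    radix_size = 1 << radix_bits
    n_groups = len(keys) // group_size
    offsets = np.empty(len(keys), dtype=np.int32)
    pfxs = np.zeros((n_groups, radix_size), dtype=np.int32)
    for g in range(n_groups):
        counts = np.zeros(radix_size, dtype=np.int32)
        for i in range(group_size):
            idx = g * group_size + i
            radix = (keys[idx] >> lo_bit) & (radix_size - 1)
            offsets[idx] = counts[radix]   # unique per (group, radix)
            counts[radix] += 1
        pfxs[g] = counts                   # per-group radix histogram
    return offsets, pfxs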

// Populate 'indices' so that indices[i] contains the largest value 'x' such
// that keys[x] < i. If no value is smaller than i, indices[i] is 0.
__global__
void binary_search(
        int *indices,
        const unsigned int *keys,
        const int prev_lo_bit,
        const int mask,
        const int length
) {
    int tid_full = blockIdx.x * blockDim.x + threadIdx.x;
    int target = tid_full << prev_lo_bit;

    int lo = 0;

    // Length must be a power of two! (Guaranteed by runtime.)
    for (int i = length >> 1; i > 0; i >>= 1) {
        int mid = lo | i;
        if (target > (keys[mid] & mask)) lo = mid;
    }

    indices[tid_full] = lo;
}
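A CPU model of binary_search (illustrative Python/NumPy, assuming 'keys' is already sorted on the masked bits, as it is after the previous pass): for each previous-pass radix value i, find the largest index whose masked key is still below i's shifted value, defaulting to 0.

import numpy as np

def binary_search_reference(keys, prev_lo_bit, mask, n_indices):
    # Illustrative CPU model of the binary_search kernel above.
    masked = keys & mask                    # previous-pass radix, still shifted
    indices = np.zeros(n_indices, dtype=np.int32)
    for i in range(n_indices):
        target = i << prev_lo_bit
        # Largest x with masked[x] < target; 0 if no key is smaller.
        x = np.searchsorted(masked, target, side='left') - 1
        indices[i] = max(x, 0)
    return indices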

// When performing a sort by repeatedly applying smaller sorts, the non-stable
// nature of the prefix scan done above will cause errors in the output. These
// errors only affect the correctness of the sort inside those groups which
// cover a transition in the radix of the previous sort passes. We re-run
// those groups with a more careful algorithm here. This doesn't make the sort
// stable in general, but it's enough to make multi-pass sorts correct.
// (Or it should be, although it seems there's a bug either in this code or in
// my head.)
__global__
void prefix_scan_repair(
        int *offsets,
        int *pfxs,
        const unsigned int *keys,
        const unsigned int *trans_points,
        const int lo_bit,
        const int prev_lo_bit,
        const int mask
) {
    const int tid = threadIdx.x;
    const int blkid = blockIdx.y * gridDim.x + blockIdx.x;
    __shared__ int shr_pfxs[RDXSZ];
    __shared__ int blk_starts[GRP_BLK_FACTOR];
    __shared__ int blk_transitions[GRP_BLK_FACTOR];

    // Never need to repair the start of the array.
    if (blkid == 0) return;

    // Get the start index of the group to repair.
    int grp_start = trans_points[blkid] & ~(GRPSZ - 1);

    // If the largest prev_radix in this block is not equal to our blkid, it
    // means that either another thread block will also attend to this group
    // (in which case we cede to it), or the transition point happens to be on
    // a group boundary. In either case, we should bail.
    //
    // Note that prev_* holds a masked but not shifted value.
    int prev_max = keys[grp_start + (GRPSZ - 1)] & mask;
    if (prev_max != (blkid << prev_lo_bit)) return;

    int prev_incr = 1 << prev_lo_bit;

    // For each block of keys that this thread block will analyze, determine
    // how many transitions occur within that block.
    if (tid < GRP_BLK_FACTOR) {
        int prev_lo = keys[grp_start + tid * BLKSZ] & mask;
        int prev_hi = keys[grp_start + tid * BLKSZ + BLKSZ - 1] & mask;
        blk_starts[tid] = prev_lo;
        blk_transitions[tid] = (prev_hi - prev_lo) >> prev_lo_bit;
    }

{{if radix_size <= 512}}
    if (tid < RDXSZ) shr_pfxs[tid] = 0;
{{else}}
{{for i in range(0, radix_size, 512)}}
    shr_pfxs[tid+{{i}}] = 0;
{{endfor}}
{{endif}}
    __syncthreads();

    int idx = grp_start + tid;
    for (int i = 0; i < GRP_BLK_FACTOR; i++) {
        int key = keys[idx];
        int prev_radix = blk_starts[i];
        int this_prev_radix = key & mask;
        int radix;
        get_radix(radix, key, lo_bit);
        if (this_prev_radix == prev_radix)
            offsets[idx] = atomicAdd(shr_pfxs + radix, 1);
        for (int j = 0; j < blk_transitions[i]; j++) {
            __syncthreads();
            prev_radix += prev_incr;
            if (this_prev_radix == prev_radix)
                offsets[idx] = atomicAdd(shr_pfxs + radix, 1);
        }
        idx += BLKSZ;
    }
}
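The repair strategy is easier to see in a sequential sketch (illustrative Python; the helper name and signature are not part of the module): within a group that straddles previous-pass radix transitions, offsets are handed out segment by segment in increasing previous-radix order, mirroring the __syncthreads()-separated atomicAdd rounds in the kernel.

def repair_group_reference(keys, offsets, grp_start, group_size,
                           lo_bit, mask, radix_bits=8):
    # Illustrative CPU model of one prefix_scan_repair group.
    radix_size = 1 << radix_bits
    counts = [0] * radix_size
    idxs = range(grp_start, grp_start + group_size)
    # Visit previous-pass radixes in increasing order; within each segment,
    # hand out per-radix running counts, like the atomicAdd in the kernel.
    for prev in sorted(set(keys[i] & mask for i in idxs)):
        for i in idxs:
            if (keys[i] & mask) == prev:
                radix = (keys[i] >> lo_bit) & (radix_size - 1)
                offsets[i] = counts[radix]
                counts[radix] += 1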

// Calculate group-local exclusive prefix sums (the number of keys in the
@@ -265,6 +377,8 @@ class Sorter(object):
    group_size = 8192
    radix_bits = 8

    warn_issued = False

    @classmethod
    def init_mod(cls):
        if cls.__dict__.get('mod') is None:
@@ -276,20 +390,31 @@ class Sorter(object):
            with open('/tmp/sort_kern.cubin', 'wb') as fp:
                fp.write(cubin)
            for name in ['prefix_scan', 'prefix_sum_condense',
                         'prefix_sum_inner', 'prefix_sum_distribute']:
                         'prefix_sum_inner', 'prefix_sum_distribute',
                         'binary_search', 'prefix_scan_repair']:
                f = cls.mod.get_function(name)
                setattr(cls, name, f)
                f.set_cache_config(cuda.func_cache.PREFER_L1)
            cls.calc_local_pfxs = cls.mod.get_function('calc_local_pfxs')
            cls.radix_sort = cls.mod.get_function('radix_sort')

    def __init__(self, max_size):
    def __init__(self, max_size, offsets=None):
        """
        Create a sorter. The sorter will hold on to internal resources for as
        long as it is alive, including an 'offsets' array of size 4*max_size.
        To share this cost, you may pass in an array of at least this size to
        __init__ (to, for instance, share across different bit-widths in a
        multi-pass sort).
        """
        self.init_mod()
        self.max_size = max_size
        assert max_size % self.group_size == 0
        max_grids = max_size / self.group_size

        self.doffsets = cuda.mem_alloc(self.max_size * 4)
        if offsets is None:
            self.doffsets = cuda.mem_alloc(self.max_size * 4)
        else:
            self.doffsets = offsets
        self.dpfxs = cuda.mem_alloc(max_grids * self.radix_size * 4)
        self.dlocals = cuda.mem_alloc(max_grids * self.radix_size * 4)
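A possible way to use the new 'offsets' argument (buffer and variable names below are illustrative): allocate the scratch array once and hand it to several Sorter instances, so the 4*max_size allocation is only paid once.

# Illustrative sketch: share one offsets scratch buffer between sorters.
max_size = 1 << 25
shared_offsets = cuda.mem_alloc(max_size * 4)
sorter_a = Sorter(max_size, offsets=shared_offsets)
sorter_b = Sorter(max_size, offsets=shared_offsets)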

@@ -299,15 +424,28 @@ class Sorter(object):
        self.dcond = cuda.mem_alloc(self.radix_size * self.ncond * 4)
        self.dglobal = cuda.mem_alloc(self.radix_size * 4)

    def sort(self, dst, src, size, lo_bit=0, stream=None):
    def warn(self):
        if not self.warn_issued:
            warnings.warn('You know multi-pass is broken, right?',
                          RuntimeWarning, stacklevel=3)
            self.warn_issued = True

    def sort(self, dst, src, size, lo_bit=0,
             prev_lo_bit=None, prev_bits=None, stream=None):
        """
        Sort 'src' by the bits from lo_bit+radix_bits to lo_bit, where 0 is
        the LSB. Store the result in 'dst'.

        Note that this is *not* a stable sort! It won't jumble your data
        haphazardly, but one- or two-position swaps are very common. This will
        hopefully be resolved soon, but until then, it is unsuitable for
        implementing larger sorts from multiple passes of this sort.
        GIANT ENORMOUS WARNING. The single-pass sort is not quite stable, even
        with the multi-pass repair kernel. This means that multi-pass sort is
        bugged. Don't use it unless your application can handle non-monotonic
        values in your "sorted" array.

        To perform a multi-pass sort, pass 'prev_lo_bit' and 'prev_bits',
        indicating the lowest bit considered across the entire sort and the
        number of bits previously sorted. This uses 2^(prev_bits+2) bytes of
        memory and performs about group_size*2^(prev_bits+1) extra operations,
        so it's useful up to three passes but probably not four.
        """

        assert size <= self.max_size and size % self.group_size == 0
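To make the docstring's cost figures concrete, a hedged two-pass sketch (buffer names are illustrative, and per the warning above the result may not be fully sorted): with radix_bits=8, the second pass supplies prev_bits=8, so the repair path allocates 4 << 8 = 1024 bytes for its index and performs roughly group_size * 2^9 extra operations.

# Illustrative two-pass call sequence for a 16-bit key range.
sorter.sort(buf_a, src, size, lo_bit=0, stream=stream)      # bits [0, 8)
sorter.sort(buf_b, buf_a, size, lo_bit=8,                   # bits [8, 16)
            prev_lo_bit=0, prev_bits=8, stream=stream)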
@@ -316,6 +454,22 @@ class Sorter(object):
        self.prefix_scan(self.doffsets, self.dpfxs, src, np.int32(lo_bit),
            block=(512, 1, 1), grid=(grids, 1), stream=stream)

        # Intentionally ignore prev_bits=0
        if prev_bits:
            self.warn()
            assert not (size & (size - 1)), \
                    'Size must be a power of two, due to my enduring laziness'
            didx = cuda.mem_alloc(4 << prev_bits)
            mask = np.uint32(((1 << prev_bits) - 1) << prev_lo_bit)
            self.binary_search(didx, src, np.int32(prev_lo_bit),
                               mask, np.int32(size),
                               block=(128, 1, 1), grid=((1<<prev_bits)/128,1))

            grid=(1 << min(prev_bits, 15), 1 << max(0, prev_bits-15))
            self.prefix_scan_repair(self.doffsets, self.dpfxs, src, didx,
                    np.int32(lo_bit), np.int32(prev_lo_bit), mask,
                    block=(512, 1, 1), grid=grid, stream=stream)

        self.calc_local_pfxs(self.dlocals, self.dpfxs,
            block=(32, 1, 1), grid=(grids / 32, 1), stream=stream)
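The two-dimensional grid above exists because one repair block is launched per previous-radix value, 2^prev_bits blocks in total, and that count is split across grid.x and grid.y (the kernel reconstructs blkid from both); the 2^15-per-dimension split presumably keeps grid.x under the 65535 limit of the hardware this code targets. A small check of that arithmetic (illustrative, not part of the module):

# Illustrative check of the repair-grid arithmetic used above.
def repair_grid(prev_bits):
    return (1 << min(prev_bits, 15), 1 << max(0, prev_bits - 15))

assert repair_grid(8) == (256, 1)       # 2**8 blocks, one per prev radix
assert repair_grid(16) == (32768, 2)    # 2**16 blocks split across x and y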

@@ -333,59 +487,88 @@ class Sorter(object):
            self.doffsets, self.dpfxs, self.dlocals, np.int32(lo_bit),
            block=(self.group_size / 8, 1, 1), grid=(grids, 1), stream=stream)

    def multisort(self, scratch_a, scratch_b, src, size, lo_bit=0,
                  rounds=1, stream=None):
        """
        Sort 'src', using scratch buffers 'scratch_a' and 'scratch_b' to hold
        the output of intermediate stages. Return whichever of the scratch
        buffers holds the final sorted data.

        It is okay to pass the same array for 'src' and 'scratch_b'.
        Otherwise, 'src' won't be touched.
        """
        if rounds > 1:
            self.warn()
        for i in range(rounds):
            b = i * self.radix_bits
            self.sort(scratch_a, src, size, lo_bit + b, lo_bit, b, stream)
            scratch_a, scratch_b, src = scratch_b, scratch_a, scratch_a
        return src
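A hedged trace of the argument juggling above, for rounds=2, radix_bits=8 and lo_bit=0 (a reading of the code, not documented behaviour; buffer names are hypothetical):

# Illustrative usage of multisort.
result_buf = sorter.multisort(scratch_a, scratch_b, dkeys, size,
                              lo_bit=0, rounds=2, stream=stream)
# With rounds=2 and radix_bits=8 this expands to roughly:
#   round 0: sort(scratch_a, dkeys,     size, lo_bit=0, prev_lo_bit=0, prev_bits=0)
#   round 1: sort(scratch_b, scratch_a, size, lo_bit=8, prev_lo_bit=0, prev_bits=8)
# The three-way swap after each round leaves the sorted data in scratch_b,
# which is what the final 'return src' hands back.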

    @classmethod
    def test(cls, count, correctness=False):
        keys = np.uint32(np.random.randint(0, 1<<cls.radix_bits, size=count))
        dkeys = cuda.to_device(keys)
        dout = cuda.mem_alloc(count * 4)
        dout_a = cuda.mem_alloc(count * 4)
        dout_b = cuda.mem_alloc(count * 4)

        sorter = cls(count)
        stream = cuda.Stream()

        def test_stub(shift):
            for i in range(10):
                evt_a = cuda.Event().record(stream)
                sorter.sort(dout, dkeys, count, shift, stream=stream)
                evt_b = cuda.Event().record(stream)
                evt_b.synchronize()
                dur = evt_b.time_since(evt_a) / 1000.
        def test_stub(shift, trials=10, rounds=1):
            # Run once so that evt_a doesn't include initialization time
            sorter.multisort(dout_a, dout_b, dkeys, count, shift,
                             rounds, stream=stream)
            evt_a = cuda.Event().record(stream)
            for i in range(trials):
                buf = sorter.multisort(dout_a, dout_b, dkeys, count, shift,
                             rounds, stream=stream)
            evt_b = cuda.Event().record(stream)
            evt_b.synchronize()
            dur = evt_b.time_since(evt_a) / (rounds * trials)
            print '%6.1f,\t%4.0f,\t%4.0f' % (dur, count / (dur * 1000),
                    count * sorter.radix_bits / (dur * 32 * 1000))

                print ( '  Overall time: %g secs'
                        '\t%g %d-bit keys/sec\t%g 32-bit keys/sec') % (
                        dur, count/dur, sorter.radix_bits,
                        count * sorter.radix_bits / (dur * 32) )
            if shift == 0 and correctness:
                print '\nTesting correctness'
                out = cuda.from_device(buf, (count,), np.uint32)
                sort = np.sort(keys)
                if np.all(out == sort):
                    print 'Correct'
                else:
                    nz = np.nonzero(out != sort)[0]
                    print sorted(set(nz >> 13))
                    for i in nz:
                        print i, out[i-1:i+2], sort[i-1:i+2]
                    assert False, 'Oh no'
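For reference, the three columns test_stub prints can be read off the expression above (my interpretation, not documented output): milliseconds per pass, millions of keys sorted per second, and the same throughput rescaled by radix_bits/32 to compare against a hypothetical full 32-bit sort.

# Illustrative reconstruction of the printed columns.
def perf_columns(count, dur_ms, radix_bits=8):
    msec = dur_ms                                       # time per pass
    mkeys_per_sec = count / (dur_ms * 1000.0)           # 10^6 keys / second
    mkeys_norm = count * radix_bits / (dur_ms * 32 * 1000.0)
    return msec, mkeys_per_sec, mkeys_norm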

        print '\n\n%d bit sort' % cls.radix_bits
        print 'Testing speed'
        test_stub(0)

        if '-s' not in sys.argv:
            print '\nTesting correctness'
            out = cuda.from_device(dout, (count,), np.uint32)
            sort = np.sort(keys)
            if np.all(out == sort):
                print 'Correct'
            else:
                assert False, 'Oh no'

        print '\nTesting speed at shifts'
        for b in range(cls.radix_bits - 1):
            print 'Performance with %d sig bits' % (cls.radix_bits - b)
        for b in range(cls.radix_bits - 3):
            print '%2d (%2d sig bits),\t' % (cls.radix_bits, cls.radix_bits - b),
            test_stub(b)

        if not correctness:
            for r in range(2,3):
                keys[:] = np.uint32(
                        np.random.randint(0, 1<<(cls.radix_bits*r), count))
                cuda.memcpy_htod(dkeys, keys)
                print '%2d x %d,\t\t\t' % (cls.radix_bits, r),
                test_stub(0, rounds=r)
        print

if __name__ == "__main__":
    import sys
    import pycuda.autoinit

    np.set_printoptions(precision=5, edgeitems=20,
                        linewidth=100, threshold=90)
    count = 1 << 26
    np.set_printoptions(precision=5, edgeitems=200,
                        linewidth=95, threshold=9000)
    count = 1 << 25

    np.random.seed(42)

    correct = '-s' not in sys.argv
    correct = '-c' in sys.argv
    for g in (8192, 4096):
        print '\n\n== GROUP SIZE %d ==' % g
        print '\n\n== GROUP SIZE %d ==,\t  msec,\tMK/s,\tMK/s norm' % g
        Sorter.group_size = g
        for b in [7,8,9,10]:
            if g == 4096 and b == 10: continue