Experiments with multi-pass sort (still has bugs)

Steven Robertson 2011-11-10 10:49:35 -05:00
parent 13842196ea
commit 54f411878b


@@ -1,4 +1,6 @@
+import warnings
+
 import numpy as np
 import pycuda.driver as cuda
 import pycuda.compiler
@@ -18,7 +20,16 @@ _CODE = tempita.Template(r"""
 #define get_radix(r, k, l) \
         asm("bfe.u32 %0, %1, %2, {{radix_bits}};" : "=r"(r) : "r"(k), "r"(l))

-// TODO: experiment with different block / group sizes
+// This kernel conducts a prefix scan of the 'keys' array. As each radix is
+// read, it is immediately added to the corresponding accumulator. The
+// resulting value is unique among keys in this block with the same radix. We
+// will use this offset later (along with the more typical prefix sums) to
+// insert the key into a shared memory array by radix, so that they can be
+// written in fewer transactions to global memory (which is important given
+// the larger radix sizes used here).
+//
+// Note that the indices generated here are unique but not necessarily
+// monotonic, so using these directly leads to a mildly unstable sort.
 __global__
 void prefix_scan(
     int *offsets,
@@ -40,10 +51,10 @@ void prefix_scan(
     __syncthreads();

     int idx = tid + GRPSZ * blockIdx.x;
-#pragma unroll
     for (int i = 0; i < GRP_BLK_FACTOR; i++) {
         // TODO: load 2 at once, compute, use a BFI to pack the two offsets
         //       into an int to halve storage / bandwidth
-        // TODO: separate or integrated loop vars? unrolling?
         int key = keys[idx];
         int radix;
         get_radix(radix, key, lo_bit);
@@ -60,7 +71,108 @@ void prefix_scan(
         pfxs[tid + {{i}} + RDXSZ * blockIdx.x] = shr_pfxs[tid + {{i}}];
     {{endfor}}
 {{endif}}
+}
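
A rough host-side model of the offset generation above, in plain Python (illustrative only, not part of the commit): each key receives a per-block sequence number among keys sharing its radix. On the GPU the atomicAdd increments race, so equal-radix keys get unique but arbitrarily ordered numbers; this serial version is the stable ideal that the hardware only approximates.

    def scan_offsets(keys, lo_bit, radix_bits=8):
        # Models get_radix (the bfe.u32 asm) plus the shr_pfxs accumulators.
        counts = {}
        offsets = []
        for key in keys:
            radix = (key >> lo_bit) & ((1 << radix_bits) - 1)
            offsets.append(counts.get(radix, 0))  # this key's slot among its radix
            counts[radix] = counts.get(radix, 0) + 1
        return offsets, counts  # 'counts' becomes this block's row of 'pfxs'
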
+
+// Populate 'indices' so that indices[i] contains the largest index 'x' such
+// that (keys[x] & mask) < (i << prev_lo_bit). If no masked key is smaller
+// than that target, indices[i] is 0.
+__global__
+void binary_search(
+    int *indices,
+    const unsigned int *keys,
+    const int prev_lo_bit,
+    const int mask,
+    const int length
+) {
+    int tid_full = blockIdx.x * blockDim.x + threadIdx.x;
+    int target = tid_full << prev_lo_bit;
+    int lo = 0;
+
+    // Length must be a power of two! (Guaranteed by runtime.)
+    for (int i = length >> 1; i > 0; i >>= 1) {
+        int mid = lo | i;
+        if (target > (keys[mid] & mask)) lo = mid;
+    }
+    indices[tid_full] = lo;
+}
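
A plain-Python sketch of the same search (hypothetical, for illustration): because 'length' is a power of two, the answer can be built one bit at a time, with no explicit upper bound to track.

    def binary_search_host(keys, i, prev_lo_bit, mask, length):
        # One GPU thread per i; returns the largest index lo with
        # (keys[lo] & mask) < (i << prev_lo_bit), or 0 if none exists.
        target = i << prev_lo_bit
        lo = 0
        bit = length >> 1
        while bit:
            mid = lo | bit
            if target > (keys[mid] & mask):
                lo = mid
            bit >>= 1
        return lo
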
+
+// When performing a sort by repeatedly applying smaller sorts, the non-stable
+// nature of the prefix scan done above will cause errors in the output. These
+// errors only affect the correctness of the sort inside those groups which
+// cover a transition in the radix of the previous sort passes. We re-run
+// those groups with a more careful algorithm here. This doesn't make the sort
+// stable in general, but it's enough to make multi-pass sorts correct.
+// (Or it should be, although it seems there's a bug either in this code or in
+// my head.)
+__global__
+void prefix_scan_repair(
+    int *offsets,
+    int *pfxs,
+    const unsigned int *keys,
+    const unsigned int *trans_points,
+    const int lo_bit,
+    const int prev_lo_bit,
+    const int mask
+) {
+    const int tid = threadIdx.x;
+    const int blkid = blockIdx.y * gridDim.x + blockIdx.x;
+    __shared__ int shr_pfxs[RDXSZ];
+    __shared__ int blk_starts[GRP_BLK_FACTOR];
+    __shared__ int blk_transitions[GRP_BLK_FACTOR];
+
+    // Never need to repair the start of the array.
+    if (blkid == 0) return;
+
+    // Get the start index of the group to repair.
+    int grp_start = trans_points[blkid] & ~(GRPSZ - 1);
+
+    // If the largest prev_radix in this group is not equal to our blkid, it
+    // means either that another thread block will also attend to this group
+    // (in which case we cede to it), or that the transition point happens to
+    // fall on a group boundary. In either case, we should bail.
+    //
+    // Note that prev_* holds a masked but not shifted value.
+    int prev_max = keys[grp_start + (GRPSZ - 1)] & mask;
+    if (prev_max != (blkid << prev_lo_bit)) return;
+    int prev_incr = 1 << prev_lo_bit;
+
+    // For each block of keys that this thread block will analyze, determine
+    // how many transitions occur within that block.
+    if (tid < GRP_BLK_FACTOR) {
+        int prev_lo = keys[grp_start + tid * BLKSZ] & mask;
+        int prev_hi = keys[grp_start + tid * BLKSZ + BLKSZ - 1] & mask;
+        blk_starts[tid] = prev_lo;
+        blk_transitions[tid] = (prev_hi - prev_lo) >> prev_lo_bit;
+    }
+
+{{if radix_size <= 512}}
+    if (tid < RDXSZ) shr_pfxs[tid] = 0;
+{{else}}
+{{for i in range(0, radix_size, 512)}}
+    shr_pfxs[tid+{{i}}] = 0;
+{{endfor}}
+{{endif}}
+    __syncthreads();
+
+    int idx = grp_start + tid;
+    for (int i = 0; i < GRP_BLK_FACTOR; i++) {
+        int key = keys[idx];
+        int prev_radix = blk_starts[i];
+        int this_prev_radix = key & mask;
+        int radix;
+        get_radix(radix, key, lo_bit);
+        if (this_prev_radix == prev_radix)
+            offsets[idx] = atomicAdd(shr_pfxs + radix, 1);
+        for (int j = 0; j < blk_transitions[i]; j++) {
+            __syncthreads();
+            prev_radix += prev_incr;
+            if (this_prev_radix == prev_radix)
+                offsets[idx] = atomicAdd(shr_pfxs + radix, 1);
+        }
+        idx += BLKSZ;
+    }
 }
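
Roughly, the repair recount makes keys from earlier previous-pass radix values claim lower per-radix sequence numbers within the affected group. A serial sketch of that intent in plain Python (hypothetical; the kernel's round-by-round __syncthreads() schedule is what enforces this ordering on the GPU, and within one round the atomics still race):

    def repair_group_offsets(group_keys, lo_bit, mask, radix_bits=8):
        # Walk previous-pass radix values in ascending order; within one
        # value, walk the group in index order, handing out per-radix
        # sequence numbers exactly as the shr_pfxs counters do.
        counters = {}
        offsets = [0] * len(group_keys)
        for prev in sorted(set(k & mask for k in group_keys)):
            for i, key in enumerate(group_keys):
                if (key & mask) != prev:
                    continue
                radix = (key >> lo_bit) & ((1 << radix_bits) - 1)
                offsets[i] = counters.get(radix, 0)
                counters[radix] = offsets[i] + 1
        return offsets
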

 // Calculate group-local exclusive prefix sums (the number of keys in the
@@ -265,6 +377,8 @@ class Sorter(object):
     group_size = 8192
     radix_bits = 8

+    warn_issued = False
+
     @classmethod
     def init_mod(cls):
         if cls.__dict__.get('mod') is None:
@@ -276,20 +390,31 @@ class Sorter(object):
             with open('/tmp/sort_kern.cubin', 'wb') as fp:
                 fp.write(cubin)
             for name in ['prefix_scan', 'prefix_sum_condense',
-                         'prefix_sum_inner', 'prefix_sum_distribute']:
+                         'prefix_sum_inner', 'prefix_sum_distribute',
+                         'binary_search', 'prefix_scan_repair']:
                 f = cls.mod.get_function(name)
                 setattr(cls, name, f)
                 f.set_cache_config(cuda.func_cache.PREFER_L1)
             cls.calc_local_pfxs = cls.mod.get_function('calc_local_pfxs')
             cls.radix_sort = cls.mod.get_function('radix_sort')
-    def __init__(self, max_size):
+    def __init__(self, max_size, offsets=None):
+        """
+        Create a sorter. The sorter will hold on to internal resources for as
+        long as it is alive, including an 'offsets' array of 4*max_size
+        bytes. To share this cost, you may pass an array of at least this
+        size to __init__ (for instance, to share it across sorters with
+        different bit widths in a multi-pass sort).
+        """
         self.init_mod()
         self.max_size = max_size
         assert max_size % self.group_size == 0
         max_grids = max_size / self.group_size
-        self.doffsets = cuda.mem_alloc(self.max_size * 4)
+        if offsets is None:
+            self.doffsets = cuda.mem_alloc(self.max_size * 4)
+        else:
+            self.doffsets = offsets
         self.dpfxs = cuda.mem_alloc(max_grids * self.radix_size * 4)
         self.dlocals = cuda.mem_alloc(max_grids * self.radix_size * 4)
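
A hypothetical usage of the new 'offsets' parameter (names and sizes invented for illustration; assumes a CUDA context is already active):

    max_size = 1 << 20
    shared = cuda.mem_alloc(max_size * 4)        # one 4*max_size-byte buffer...
    sorter_a = Sorter(max_size, offsets=shared)  # ...shared by both sorters
    sorter_b = Sorter(max_size, offsets=shared)
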
@@ -299,15 +424,28 @@ class Sorter(object):
         self.dcond = cuda.mem_alloc(self.radix_size * self.ncond * 4)
         self.dglobal = cuda.mem_alloc(self.radix_size * 4)

+    def warn(self):
+        if not self.warn_issued:
+            warnings.warn('You know multi-pass is broken, right?',
+                          RuntimeWarning, stacklevel=3)
+            self.warn_issued = True
+
-    def sort(self, dst, src, size, lo_bit=0, stream=None):
+    def sort(self, dst, src, size, lo_bit=0,
+             prev_lo_bit=None, prev_bits=None, stream=None):
         """
         Sort 'src' by the bits from lo_bit+radix_bits to lo_bit, where 0 is
         the LSB. Store the result in 'dst'.

-        Note that this is *not* a stable sort! It won't jumble your data
-        haphazardly, but one- or two-position swaps are very common. This will
-        hopefully be resolved soon, but until then, it is unsuitable for
-        implementing larger sorts from multiple passes of this sort.
+        GIANT ENORMOUS WARNING: the single-pass sort is not quite stable, even
+        with the multi-pass repair kernel. This means that multi-pass sort is
+        bugged. Don't use it unless your application can handle non-monotonic
+        values in your "sorted" array.
+
+        To perform a multi-pass sort, pass 'prev_lo_bit' and 'prev_bits',
+        indicating the lowest bit considered across the entire sort and the
+        number of bits previously sorted. This uses 2^(prev_bits+2) bytes of
+        memory and performs about group_size*2^(prev_bits+1) extra operations,
+        so it's useful up to three passes but probably not four.
         """

         assert size <= self.max_size and size % self.group_size == 0
@@ -316,6 +454,22 @@ class Sorter(object):
         self.prefix_scan(self.doffsets, self.dpfxs, src, np.int32(lo_bit),
                          block=(512, 1, 1), grid=(grids, 1), stream=stream)

+        # Intentionally ignore prev_bits=0
+        if prev_bits:
+            self.warn()
+            assert not (size & (size - 1)), \
+                'Size must be a power of two, due to my enduring laziness'
+            didx = cuda.mem_alloc(4 << prev_bits)
+            mask = np.uint32(((1 << prev_bits) - 1) << prev_lo_bit)
+            self.binary_search(didx, src, np.int32(prev_lo_bit),
+                               mask, np.int32(size),
+                               block=(128, 1, 1), grid=((1 << prev_bits) / 128, 1))
+            # Grid dimensions are capped at 65535 blocks, so spill the block
+            # count into the y dimension once prev_bits exceeds 15.
+            grid = (1 << min(prev_bits, 15), 1 << max(0, prev_bits - 15))
+            self.prefix_scan_repair(self.doffsets, self.dpfxs, src, didx,
+                                    np.int32(lo_bit), np.int32(prev_lo_bit), mask,
+                                    block=(512, 1, 1), grid=grid, stream=stream)

         self.calc_local_pfxs(self.dlocals, self.dpfxs,
                              block=(32, 1, 1), grid=(grids / 32, 1), stream=stream)
@@ -333,59 +487,88 @@ class Sorter(object):
             self.doffsets, self.dpfxs, self.dlocals, np.int32(lo_bit),
             block=(self.group_size / 8, 1, 1), grid=(grids, 1), stream=stream)

+    def multisort(self, scratch_a, scratch_b, src, size, lo_bit=0,
+                  rounds=1, stream=None):
+        """
+        Sort 'src', using scratch buffers 'scratch_a' and 'scratch_b' to hold
+        the output of intermediate stages. Return whichever of the scratch
+        buffers holds the final sorted data.
+
+        It is okay to pass the same array for 'src' and 'scratch_b'.
+        Otherwise, 'src' won't be touched.
+        """
+        if rounds > 1:
+            self.warn()
+        for i in range(rounds):
+            b = i * self.radix_bits
+            self.sort(scratch_a, src, size, lo_bit + b, lo_bit, b, stream)
+            scratch_a, scratch_b, src = scratch_b, scratch_a, scratch_a
+        return src
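
A hypothetical call ('scratch_a', 'scratch_b', 'dkeys', and 'count' invented for illustration), plus the buffer rotation traced in plain Python to show why the result must be taken from the return value:

    buf = sorter.multisort(scratch_a, scratch_b, dkeys, count, rounds=2)
    out = cuda.from_device(buf, (count,), np.uint32)

    # The rotation, traced symbolically: after each pass the just-written
    # scratch buffer becomes the next input, so an odd number of rounds
    # ends in scratch_a and an even number in scratch_b.
    a, b, s = 'a', 'b', 'src'
    for i in range(2):
        a, b, s = b, a, a
    assert s == 'b'  # two rounds leave the sorted data in scratch_b
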

     @classmethod
     def test(cls, count, correctness=False):
         keys = np.uint32(np.random.randint(0, 1<<cls.radix_bits, size=count))
         dkeys = cuda.to_device(keys)
-        dout = cuda.mem_alloc(count * 4)
+        dout_a = cuda.mem_alloc(count * 4)
+        dout_b = cuda.mem_alloc(count * 4)
         sorter = cls(count)
         stream = cuda.Stream()

-        def test_stub(shift):
-            for i in range(10):
+        def test_stub(shift, trials=10, rounds=1):
+            # Run once so that evt_a doesn't include initialization time
+            sorter.multisort(dout_a, dout_b, dkeys, count, shift,
+                             rounds, stream=stream)
             evt_a = cuda.Event().record(stream)
-            sorter.sort(dout, dkeys, count, shift, stream=stream)
+            for i in range(trials):
+                buf = sorter.multisort(dout_a, dout_b, dkeys, count, shift,
+                                       rounds, stream=stream)
             evt_b = cuda.Event().record(stream)
             evt_b.synchronize()
-            dur = evt_b.time_since(evt_a) / 1000.
-            print ( ' Overall time: %g secs'
-                    '\t%g %d-bit keys/sec\t%g 32-bit keys/sec') % (
-                    dur, count/dur, sorter.radix_bits,
-                    count * sorter.radix_bits / (dur * 32) )
+            dur = evt_b.time_since(evt_a) / (rounds * trials)
+            print '%6.1f,\t%4.0f,\t%4.0f' % (dur, count / (dur * 1000),
+                count * sorter.radix_bits / (dur * 32 * 1000))

-        print '\n\n%d bit sort' % cls.radix_bits
-        print 'Testing speed'
-        test_stub(0)
-        if '-s' not in sys.argv:
+            if shift == 0 and correctness:
                 print '\nTesting correctness'
-                out = cuda.from_device(dout, (count,), np.uint32)
+                out = cuda.from_device(buf, (count,), np.uint32)
                 sort = np.sort(keys)
                 if np.all(out == sort):
                     print 'Correct'
                 else:
+                    nz = np.nonzero(out != sort)[0]
+                    print sorted(set(nz >> 13))
+                    for i in nz:
+                        print i, out[i-1:i+2], sort[i-1:i+2]
                     assert False, 'Oh no'

-        print '\nTesting speed at shifts'
-        for b in range(cls.radix_bits - 1):
-            print 'Performance with %d sig bits' % (cls.radix_bits - b)
+        for b in range(cls.radix_bits - 3):
+            print '%2d (%2d sig bits),\t' % (cls.radix_bits, cls.radix_bits - b),
             test_stub(b)

+        if not correctness:
+            for r in range(2, 3):
+                keys[:] = np.uint32(
+                    np.random.randint(0, 1 << (cls.radix_bits * r), count))
+                cuda.memcpy_htod(dkeys, keys)
+                print '%2d x %d,\t\t\t' % (cls.radix_bits, r),
+                test_stub(0, rounds=r)
+        print

 if __name__ == "__main__":
     import sys
     import pycuda.autoinit
-    np.set_printoptions(precision=5, edgeitems=20,
-                        linewidth=100, threshold=90)
+    np.set_printoptions(precision=5, edgeitems=200,
+                        linewidth=95, threshold=9000)
-    count = 1 << 26
+    count = 1 << 25
     np.random.seed(42)
-    correct = '-s' not in sys.argv
+    correct = '-c' in sys.argv

     for g in (8192, 4096):
-        print '\n\n== GROUP SIZE %d ==' % g
+        print '\n\n== GROUP SIZE %d ==,\t msec,\tMK/s,\tMK/s norm' % g
         Sorter.group_size = g
         for b in [7,8,9,10]:
             if g == 4096 and b == 10: continue