mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 11:40:04 -05:00
Generalize the sort.
This commit is contained in:
parent
3147fd40d2
commit
13842196ea
@ -11,34 +11,56 @@ _CODE = tempita.Template(r"""
|
|||||||
#define GRP_RDX_FACTOR (GRPSZ / RDXSZ)
|
#define GRP_RDX_FACTOR (GRPSZ / RDXSZ)
|
||||||
#define GRP_BLK_FACTOR (GRPSZ / BLKSZ)
|
#define GRP_BLK_FACTOR (GRPSZ / BLKSZ)
|
||||||
#define GRPSZ {{group_size}}
|
#define GRPSZ {{group_size}}
|
||||||
|
#define RBITS {{radix_bits}}
|
||||||
#define RDXSZ {{radix_size}}
|
#define RDXSZ {{radix_size}}
|
||||||
#define BLKSZ 512
|
#define BLKSZ 512
|
||||||
|
|
||||||
|
#define get_radix(r, k, l) \
|
||||||
|
asm("bfe.u32 %0, %1, %2, {{radix_bits}};" : "=r"(r) : "r"(k), "r"(l))
|
||||||
|
|
||||||
// TODO: experiment with different block / group sizes
|
// TODO: experiment with different block / group sizes
|
||||||
__global__
|
__global__
|
||||||
void prefix_scan_8_0(
|
void prefix_scan(
|
||||||
int *offsets,
|
int *offsets,
|
||||||
int *pfxs,
|
int *pfxs,
|
||||||
const unsigned int *keys
|
const unsigned int *keys,
|
||||||
|
const int lo_bit
|
||||||
) {
|
) {
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
__shared__ int shr_pfxs[RDXSZ];
|
__shared__ int shr_pfxs[RDXSZ];
|
||||||
|
|
||||||
|
{{if radix_size <= 512}}
|
||||||
if (tid < RDXSZ) shr_pfxs[tid] = 0;
|
if (tid < RDXSZ) shr_pfxs[tid] = 0;
|
||||||
__syncthreads();
|
{{else}}
|
||||||
int i = tid + GRPSZ * blockIdx.x;
|
{{for i in range(0, radix_size, 512)}}
|
||||||
|
shr_pfxs[tid+{{i}}] = 0;
|
||||||
|
{{endfor}}
|
||||||
|
{{endif}}
|
||||||
|
|
||||||
for (int j = 0; j < GRP_BLK_FACTOR; j++) {
|
__syncthreads();
|
||||||
|
int idx = tid + GRPSZ * blockIdx.x;
|
||||||
|
|
||||||
|
for (int i = 0; i < GRP_BLK_FACTOR; i++) {
|
||||||
// TODO: load 2 at once, compute, use a BFI to pack the two offsets
|
// TODO: load 2 at once, compute, use a BFI to pack the two offsets
|
||||||
// into an int to halve storage / bandwidth
|
// into an int to halve storage / bandwidth
|
||||||
// TODO: separate or integrated loop vars? unrolling?
|
// TODO: separate or integrated loop vars? unrolling?
|
||||||
int radix = keys[i] & 0xff;
|
int key = keys[idx];
|
||||||
offsets[i] = atomicAdd(shr_pfxs + radix, 1);
|
int radix;
|
||||||
i += BLKSZ;
|
get_radix(radix, key, lo_bit);
|
||||||
|
offsets[idx] = atomicAdd(shr_pfxs + radix, 1);
|
||||||
|
idx += BLKSZ;
|
||||||
}
|
}
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
{{if radix_size <= 512}}
|
||||||
if (tid < RDXSZ) pfxs[tid + RDXSZ * blockIdx.x] = shr_pfxs[tid];
|
if (tid < RDXSZ) pfxs[tid + RDXSZ * blockIdx.x] = shr_pfxs[tid];
|
||||||
|
{{else}}
|
||||||
|
{{for i in range(0, radix_size, 512)}}
|
||||||
|
pfxs[tid + {{i}} + RDXSZ * blockIdx.x] = shr_pfxs[tid + {{i}}];
|
||||||
|
{{endfor}}
|
||||||
|
{{endif}}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate group-local exclusive prefix sums (the number of keys in the
|
// Calculate group-local exclusive prefix sums (the number of keys in the
|
||||||
@ -67,10 +89,10 @@ void calc_local_pfxs(
|
|||||||
// might be better to halve the chunk size and lose some coalescing
|
// might be better to halve the chunk size and lose some coalescing
|
||||||
// efficiency; need to benchmark. It's a relatively cheap step, though.
|
// efficiency; need to benchmark. It's a relatively cheap step, though.
|
||||||
|
|
||||||
for (int j = 0; j < 8; j++) {
|
for (int j = 0; j < RDXSZ / 32; j++) {
|
||||||
int jj = j << 5;
|
int jj = j << 5;
|
||||||
for (int i = 0; i < 32; i++) {
|
for (int i = 0; i < 32; i++) {
|
||||||
int base_offset = (i << 8) + jj + base + tid;
|
int base_offset = (i << RBITS) + jj + base + tid;
|
||||||
int swap_offset = (i << 5) + ((i + tid) & 0x1f);
|
int swap_offset = (i << 5) + ((i + tid) & 0x1f);
|
||||||
swap[swap_offset] = pfxs[base_offset];
|
swap[swap_offset] = pfxs[base_offset];
|
||||||
}
|
}
|
||||||
@ -84,7 +106,7 @@ void calc_local_pfxs(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < 32; i++) {
|
for (int i = 0; i < 32; i++) {
|
||||||
int base_offset = (i << 8) + jj + base + tid;
|
int base_offset = (i << RBITS) + jj + base + tid;
|
||||||
int swap_offset = (i << 5) + ((i + tid) & 0x1f);
|
int swap_offset = (i << 5) + ((i + tid) & 0x1f);
|
||||||
locals[base_offset] = swap[swap_offset];
|
locals[base_offset] = swap[swap_offset];
|
||||||
}
|
}
|
||||||
@ -194,14 +216,15 @@ void radix_sort_direct(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#undef BLKSZ
|
#undef BLKSZ
|
||||||
#define BLKSZ 1024
|
#define BLKSZ {{group_size / 8}}
|
||||||
__global__
|
__global__
|
||||||
void radix_sort(
|
void radix_sort(
|
||||||
int *sorted_keys,
|
int *sorted_keys,
|
||||||
const int *keys,
|
const int *keys,
|
||||||
const int *offsets,
|
const int *offsets,
|
||||||
const int *pfxs,
|
const int *pfxs,
|
||||||
const int *locals
|
const int *locals,
|
||||||
|
const int lo_bit
|
||||||
) {
|
) {
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
const int blk_offset = GRPSZ * blockIdx.x;
|
const int blk_offset = GRPSZ * blockIdx.x;
|
||||||
@ -214,7 +237,8 @@ void radix_sort(
|
|||||||
|
|
||||||
for (int i = tid; i < GRPSZ; i += BLKSZ) {
|
for (int i = tid; i < GRPSZ; i += BLKSZ) {
|
||||||
int key = keys[i+blk_offset];
|
int key = keys[i+blk_offset];
|
||||||
int radix = key & 0xff;
|
int radix;
|
||||||
|
get_radix(radix, key, lo_bit);
|
||||||
int offset = offsets[i+blk_offset] + shr_offs[radix];
|
int offset = offsets[i+blk_offset] + shr_offs[radix];
|
||||||
defer[offset] = key;
|
defer[offset] = key;
|
||||||
}
|
}
|
||||||
@ -227,7 +251,8 @@ void radix_sort(
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < GRP_BLK_FACTOR; j++) {
|
for (int j = 0; j < GRP_BLK_FACTOR; j++) {
|
||||||
int key = defer[i];
|
int key = defer[i];
|
||||||
int radix = key & 0xff;
|
int radix;
|
||||||
|
get_radix(radix, key, lo_bit);
|
||||||
int offset = shr_offs[radix] + i;
|
int offset = shr_offs[radix] + i;
|
||||||
sorted_keys[offset] = key;
|
sorted_keys[offset] = key;
|
||||||
i += BLKSZ;
|
i += BLKSZ;
|
||||||
@ -238,15 +263,19 @@ void radix_sort(
|
|||||||
class Sorter(object):
|
class Sorter(object):
|
||||||
mod = None
|
mod = None
|
||||||
group_size = 8192
|
group_size = 8192
|
||||||
radix_size = 256
|
radix_bits = 8
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_mod(cls):
|
def init_mod(cls):
|
||||||
if cls.mod is None:
|
if cls.__dict__.get('mod') is None:
|
||||||
|
cls.radix_size = 1 << cls.radix_bits
|
||||||
code = _CODE.substitute(group_size=cls.group_size,
|
code = _CODE.substitute(group_size=cls.group_size,
|
||||||
radix_size=cls.radix_size)
|
radix_bits=cls.radix_bits, radix_size=cls.radix_size)
|
||||||
cls.mod = pycuda.compiler.SourceModule(code)
|
cubin = pycuda.compiler.compile(code)
|
||||||
for name in ['prefix_scan_8_0', 'prefix_sum_condense',
|
cls.mod = cuda.module_from_buffer(cubin)
|
||||||
|
with open('/tmp/sort_kern.cubin', 'wb') as fp:
|
||||||
|
fp.write(cubin)
|
||||||
|
for name in ['prefix_scan', 'prefix_sum_condense',
|
||||||
'prefix_sum_inner', 'prefix_sum_distribute']:
|
'prefix_sum_inner', 'prefix_sum_distribute']:
|
||||||
f = cls.mod.get_function(name)
|
f = cls.mod.get_function(name)
|
||||||
setattr(cls, name, f)
|
setattr(cls, name, f)
|
||||||
@ -254,16 +283,15 @@ class Sorter(object):
|
|||||||
cls.calc_local_pfxs = cls.mod.get_function('calc_local_pfxs')
|
cls.calc_local_pfxs = cls.mod.get_function('calc_local_pfxs')
|
||||||
cls.radix_sort = cls.mod.get_function('radix_sort')
|
cls.radix_sort = cls.mod.get_function('radix_sort')
|
||||||
|
|
||||||
def __init__(self, size, dst=None):
|
def __init__(self, max_size):
|
||||||
self.init_mod()
|
self.init_mod()
|
||||||
assert size % self.group_size == 0, 'bad multiple'
|
self.max_size = max_size
|
||||||
if dst is None:
|
assert max_size % self.group_size == 0
|
||||||
dst = cuda.mem_alloc(size * 4)
|
max_grids = max_size / self.group_size
|
||||||
self.size, self.dst = size, dst
|
|
||||||
self.doffsets = cuda.mem_alloc(self.size * 4)
|
self.doffsets = cuda.mem_alloc(self.max_size * 4)
|
||||||
self.grids = self.size / self.group_size
|
self.dpfxs = cuda.mem_alloc(max_grids * self.radix_size * 4)
|
||||||
self.dpfxs = cuda.mem_alloc(self.grids * self.radix_size * 4)
|
self.dlocals = cuda.mem_alloc(max_grids * self.radix_size * 4)
|
||||||
self.dlocals = cuda.mem_alloc(self.grids * self.radix_size * 4)
|
|
||||||
|
|
||||||
# There are probably better ways to choose how many condensation
|
# There are probably better ways to choose how many condensation
|
||||||
# groups to launch. TODO: maybe pick one if I care
|
# groups to launch. TODO: maybe pick one if I care
|
||||||
@ -271,15 +299,28 @@ class Sorter(object):
|
|||||||
self.dcond = cuda.mem_alloc(self.radix_size * self.ncond * 4)
|
self.dcond = cuda.mem_alloc(self.radix_size * self.ncond * 4)
|
||||||
self.dglobal = cuda.mem_alloc(self.radix_size * 4)
|
self.dglobal = cuda.mem_alloc(self.radix_size * 4)
|
||||||
|
|
||||||
def sort(self, src, stream=None):
|
def sort(self, dst, src, size, lo_bit=0, stream=None):
|
||||||
self.prefix_scan_8_0(self.doffsets, self.dpfxs, src,
|
"""
|
||||||
block=(512, 1, 1), grid=(self.grids, 1), stream=stream)
|
Sort 'src' by the bits from lo_bit+radix_bits to lo_bit, where 0 is
|
||||||
|
the LSB. Store the result in 'dst'.
|
||||||
|
|
||||||
|
Note that this is *not* a stable sort! It won't jumble your data
|
||||||
|
haphazardly, but one- or two-position swaps are very common. This will
|
||||||
|
hopefully be resolved soon, but until then, it is unsuitable for
|
||||||
|
implementing larger sorts from multiple passes of this sort.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert size <= self.max_size and size % self.group_size == 0
|
||||||
|
grids = size / self.group_size
|
||||||
|
|
||||||
|
self.prefix_scan(self.doffsets, self.dpfxs, src, np.int32(lo_bit),
|
||||||
|
block=(512, 1, 1), grid=(grids, 1), stream=stream)
|
||||||
|
|
||||||
self.calc_local_pfxs(self.dlocals, self.dpfxs,
|
self.calc_local_pfxs(self.dlocals, self.dpfxs,
|
||||||
block=(32, 1, 1), grid=(self.grids / 32, 1), stream=stream)
|
block=(32, 1, 1), grid=(grids / 32, 1), stream=stream)
|
||||||
|
|
||||||
ngrps = np.int32(self.grids)
|
ngrps = np.int32(grids)
|
||||||
grpwidth = np.int32(np.ceil(float(self.grids) / self.ncond))
|
grpwidth = np.int32(np.ceil(float(grids) / self.ncond))
|
||||||
|
|
||||||
self.prefix_sum_condense(self.dcond, self.dpfxs, ngrps, grpwidth,
|
self.prefix_sum_condense(self.dcond, self.dpfxs, ngrps, grpwidth,
|
||||||
block=(self.radix_size, 1, 1), grid=(self.ncond, 1), stream=stream)
|
block=(self.radix_size, 1, 1), grid=(self.ncond, 1), stream=stream)
|
||||||
@ -288,35 +329,67 @@ class Sorter(object):
|
|||||||
self.prefix_sum_distribute(self.dpfxs, self.dcond, ngrps, grpwidth,
|
self.prefix_sum_distribute(self.dpfxs, self.dcond, ngrps, grpwidth,
|
||||||
block=(self.radix_size, 1, 1), grid=(self.ncond, 1), stream=stream)
|
block=(self.radix_size, 1, 1), grid=(self.ncond, 1), stream=stream)
|
||||||
|
|
||||||
self.radix_sort(self.dst, src, self.doffsets, self.dpfxs, self.dlocals,
|
self.radix_sort(dst, src,
|
||||||
block=(1024, 1, 1), grid=(self.grids, 1), stream=stream)
|
self.doffsets, self.dpfxs, self.dlocals, np.int32(lo_bit),
|
||||||
|
block=(self.group_size / 8, 1, 1), grid=(grids, 1), stream=stream)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def test(cls, count, correctness=False):
|
||||||
|
keys = np.uint32(np.random.randint(0, 1<<cls.radix_bits, size=count))
|
||||||
|
dkeys = cuda.to_device(keys)
|
||||||
|
dout = cuda.mem_alloc(count * 4)
|
||||||
|
|
||||||
|
sorter = cls(count)
|
||||||
|
stream = cuda.Stream()
|
||||||
|
|
||||||
|
def test_stub(shift):
|
||||||
|
for i in range(10):
|
||||||
|
evt_a = cuda.Event().record(stream)
|
||||||
|
sorter.sort(dout, dkeys, count, shift, stream=stream)
|
||||||
|
evt_b = cuda.Event().record(stream)
|
||||||
|
evt_b.synchronize()
|
||||||
|
dur = evt_b.time_since(evt_a) / 1000.
|
||||||
|
|
||||||
|
print ( ' Overall time: %g secs'
|
||||||
|
'\t%g %d-bit keys/sec\t%g 32-bit keys/sec') % (
|
||||||
|
dur, count/dur, sorter.radix_bits,
|
||||||
|
count * sorter.radix_bits / (dur * 32) )
|
||||||
|
|
||||||
|
print '\n\n%d bit sort' % cls.radix_bits
|
||||||
|
print 'Testing speed'
|
||||||
|
test_stub(0)
|
||||||
|
|
||||||
|
if '-s' not in sys.argv:
|
||||||
|
print '\nTesting correctness'
|
||||||
|
out = cuda.from_device(dout, (count,), np.uint32)
|
||||||
|
sort = np.sort(keys)
|
||||||
|
if np.all(out == sort):
|
||||||
|
print 'Correct'
|
||||||
|
else:
|
||||||
|
assert False, 'Oh no'
|
||||||
|
|
||||||
|
print '\nTesting speed at shifts'
|
||||||
|
for b in range(cls.radix_bits - 1):
|
||||||
|
print 'Performance with %d sig bits' % (cls.radix_bits - b)
|
||||||
|
test_stub(b)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
import pycuda.autoinit
|
import pycuda.autoinit
|
||||||
|
|
||||||
np.set_printoptions(precision=5, edgeitems=20,
|
np.set_printoptions(precision=5, edgeitems=20,
|
||||||
linewidth=100, threshold=90)
|
linewidth=100, threshold=90)
|
||||||
count = 1 << 26
|
count = 1 << 26
|
||||||
|
|
||||||
keys = np.uint32(np.fromstring(np.random.bytes(count), dtype=np.uint8))
|
np.random.seed(42)
|
||||||
dkeys = cuda.to_device(keys)
|
|
||||||
|
|
||||||
sorter = Sorter(count)
|
correct = '-s' not in sys.argv
|
||||||
|
for g in (8192, 4096):
|
||||||
print 'Testing speed'
|
print '\n\n== GROUP SIZE %d ==' % g
|
||||||
stream = cuda.Stream()
|
Sorter.group_size = g
|
||||||
for i in range(10):
|
for b in [7,8,9,10]:
|
||||||
evt_a = cuda.Event().record(stream)
|
if g == 4096 and b == 10: continue
|
||||||
sorter.sort(dkeys, stream)
|
Sorter.radix_bits = b
|
||||||
evt_b = cuda.Event().record(stream)
|
Sorter.test(count, correct)
|
||||||
evt_b.synchronize()
|
del Sorter.mod
|
||||||
dur = evt_b.time_since(evt_a)
|
|
||||||
print 'Overall time: %g secs (%g 8-bit keys/sec)' % (
|
|
||||||
dur / 1000., 1000 * count / dur)
|
|
||||||
|
|
||||||
|
|
||||||
print 'Testing correctness'
|
|
||||||
out = cuda.from_device(sorter.dst, (count,), np.uint32)
|
|
||||||
sort = np.sort(keys)
|
|
||||||
print 'Sorted correctly?', np.all(out == sort)
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user