mirror of
https://github.com/stevenrobertson/cuburn.git
synced 2025-02-05 11:40:04 -05:00
Make store_per_thread reuse gtid in multiple calls when possible
This commit is contained in:
parent
943e92b80c
commit
e71a8422e5
@ -132,8 +132,8 @@ class IterThread(PTXEntryPoint):
|
|||||||
|
|
||||||
label('all_cps_done')
|
label('all_cps_done')
|
||||||
# TODO this is for testing, move it to a debug statement
|
# TODO this is for testing, move it to a debug statement
|
||||||
std.store_per_thread(g_num_rounds, num_rounds)
|
std.store_per_thread(g_num_rounds, num_rounds,
|
||||||
std.store_per_thread(g_num_writes, num_writes)
|
g_num_writes, num_writes)
|
||||||
|
|
||||||
@instmethod
|
@instmethod
|
||||||
def upload_cp_stream(self, ctx, cp_stream, num_cps):
|
def upload_cp_stream(self, ctx, cp_stream, num_cps):
|
||||||
|
@ -689,18 +689,22 @@ class _PTXStdLib(PTXFragment):
|
|||||||
op.mad.lo.u32(dst, cta, ncta, tid)
|
op.mad.lo.u32(dst, cta, ncta, tid)
|
||||||
|
|
||||||
@ptx_func
|
@ptx_func
|
||||||
def store_per_thread(self, base, val):
|
def store_per_thread(self, *args):
|
||||||
"""Store b32 at `base+gtid*4`. Super-common debug pattern."""
|
"""Store b32 at `base+gtid*4`. Super-common debug pattern."""
|
||||||
with block("Per-thread store of %s" % str(val)):
|
with block("Per-thread storing values"):
|
||||||
reg.u32('spt_base spt_offset')
|
reg.u32('spt_base spt_offset')
|
||||||
op.mov.u32(spt_base, base)
|
|
||||||
self.get_gtid(spt_offset)
|
self.get_gtid(spt_offset)
|
||||||
op.mad.lo.u32(spt_base, spt_offset, 4, spt_base)
|
op.mul.lo.u32(spt_offset, spt_offset, 4)
|
||||||
if isinstance(val, float):
|
for i in range(0, len(args), 2):
|
||||||
# Turn a constant float into the big-endian PTX binary float
|
base, val = args[i], args[i+1]
|
||||||
# representation, 0fXXXXXXXX (where XX is hex byte)
|
op.mov.u32(spt_base, base)
|
||||||
val = '0f%x%x%x%x' % reversed(map(ord, struct.pack('f', val)))
|
op.add.u32(spt_base, spt_base, spt_offset)
|
||||||
op.st.b32(addr(spt_base), val)
|
if isinstance(val, float):
|
||||||
|
# Turn a constant float into the big-endian PTX binary f32
|
||||||
|
# representation, 0fXXXXXXXX (where XX is hex byte)
|
||||||
|
val = '0f%x%x%x%x' % reversed(map(ord,
|
||||||
|
struct.pack('f', val)))
|
||||||
|
op.st.b32(addr(spt_base), val)
|
||||||
|
|
||||||
@ptx_func
|
@ptx_func
|
||||||
def set_is_first_thread(self, p_dst):
|
def set_is_first_thread(self, p_dst):
|
||||||
|
Loading…
Reference in New Issue
Block a user