mirror of
				https://github.com/stevenrobertson/cuburn.git
				synced 2025-11-03 18:00:55 -05:00 
			
		
		
		
	Make store_per_thread reuse gtid in multiple calls when possible
This commit is contained in:
		@ -132,8 +132,8 @@ class IterThread(PTXEntryPoint):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        label('all_cps_done')
 | 
					        label('all_cps_done')
 | 
				
			||||||
        # TODO this is for testing, move it to a debug statement
 | 
					        # TODO this is for testing, move it to a debug statement
 | 
				
			||||||
        std.store_per_thread(g_num_rounds, num_rounds)
 | 
					        std.store_per_thread(g_num_rounds, num_rounds,
 | 
				
			||||||
        std.store_per_thread(g_num_writes, num_writes)
 | 
					                             g_num_writes, num_writes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @instmethod
 | 
					    @instmethod
 | 
				
			||||||
    def upload_cp_stream(self, ctx, cp_stream, num_cps):
 | 
					    def upload_cp_stream(self, ctx, cp_stream, num_cps):
 | 
				
			||||||
 | 
				
			|||||||
@ -689,17 +689,21 @@ class _PTXStdLib(PTXFragment):
 | 
				
			|||||||
            op.mad.lo.u32(dst, cta, ncta, tid)
 | 
					            op.mad.lo.u32(dst, cta, ncta, tid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @ptx_func
 | 
					    @ptx_func
 | 
				
			||||||
    def store_per_thread(self, base, val):
 | 
					    def store_per_thread(self, *args):
 | 
				
			||||||
        """Store b32 at `base+gtid*4`. Super-common debug pattern."""
 | 
					        """Store b32 at `base+gtid*4`. Super-common debug pattern."""
 | 
				
			||||||
        with block("Per-thread store of %s" % str(val)):
 | 
					        with block("Per-thread storing values"):
 | 
				
			||||||
            reg.u32('spt_base spt_offset')
 | 
					            reg.u32('spt_base spt_offset')
 | 
				
			||||||
            op.mov.u32(spt_base, base)
 | 
					 | 
				
			||||||
            self.get_gtid(spt_offset)
 | 
					            self.get_gtid(spt_offset)
 | 
				
			||||||
            op.mad.lo.u32(spt_base, spt_offset, 4, spt_base)
 | 
					            op.mul.lo.u32(spt_offset, spt_offset, 4)
 | 
				
			||||||
 | 
					            for i in range(0, len(args), 2):
 | 
				
			||||||
 | 
					                base, val = args[i], args[i+1]
 | 
				
			||||||
 | 
					                op.mov.u32(spt_base, base)
 | 
				
			||||||
 | 
					                op.add.u32(spt_base, spt_base, spt_offset)
 | 
				
			||||||
                if isinstance(val, float):
 | 
					                if isinstance(val, float):
 | 
				
			||||||
                # Turn a constant float into the big-endian PTX binary float
 | 
					                    # Turn a constant float into the big-endian PTX binary f32
 | 
				
			||||||
                    # representation, 0fXXXXXXXX (where XX is hex byte)
 | 
					                    # representation, 0fXXXXXXXX (where XX is hex byte)
 | 
				
			||||||
                val = '0f%x%x%x%x' % reversed(map(ord, struct.pack('f', val)))
 | 
					                    val = '0f%x%x%x%x' % reversed(map(ord,
 | 
				
			||||||
 | 
					                                                      struct.pack('f', val)))
 | 
				
			||||||
                op.st.b32(addr(spt_base), val)
 | 
					                op.st.b32(addr(spt_base), val)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @ptx_func
 | 
					    @ptx_func
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user