def get_per_thread(self, name, dtype, shaped=False):
    """
    Convenience function to get the contents of the global memory variable
    ``name`` from the device as a numpy array of type ``dtype``, as might
    be stored by _PTXStdLib.store_per_thread.

    If ``shaped`` is True, the array will be 3D, as
    (cta_no, warp_no, lane_no), i.e. of shape
    ``(self.nctas, self.warps_per_cta, 32)``. Otherwise the shape passed to
    the driver is simply ``self.nthreads``.

    Raises ValueError if the requested shape and ``dtype`` would need more
    bytes than the module reports for the global ``name``; copying in that
    case would read past the end of the variable.
    """
    # Local import keeps this method self-contained regardless of how the
    # enclosing module spells its numpy import.
    import numpy as np

    if shaped:
        shape = (self.nctas, self.warps_per_cta, 32)
    else:
        shape = self.nthreads

    # get_global yields the device address and the size of the global in
    # bytes; the size is what lets us validate the read below.
    dev_ptr, size_in_bytes = self.mod.get_global(name)

    bytes_needed = int(np.prod(shape)) * np.dtype(dtype).itemsize
    if bytes_needed > size_in_bytes:
        raise ValueError(
            "global %r holds %d bytes, but reading shape %r as dtype %s "
            "needs %d bytes" % (name, size_in_bytes, shape,
                                np.dtype(dtype), bytes_needed))

    return cuda.from_device(dev_ptr, shape, dtype)