diff --git a/cuburnlib/ptx.py b/cuburnlib/ptx.py index 5a18b2c..0ebaa00 100644 --- a/cuburnlib/ptx.py +++ b/cuburnlib/ptx.py @@ -605,6 +605,7 @@ class PTXEntryPoint(PTXFragment): entry_name = "" # List of (type, name) pairs for entry params, e.g. [('u32', 'thing')] entry_params = [] + maxnreg = None def entry(self): """ @@ -817,9 +818,11 @@ class PTXModule(object): # This is kind of hackish compared to everything else params = [Reg('.param.' + str(type), name) for (type, name) in ent.entry_params] - _block.code(op='.entry %s ' % ent.entry_name, semi=False, + _block.code(op='.entry %s' % ent.entry_name, semi=False, vars=['(', ', '.join(['%s %s' % (r.type, r.name) for r in params]), ')']) + if ent.maxnreg: + _block.code(op='.maxnreg %d' % ent.maxnreg, semi=False) with Block(_block): [_block.inject(r.name, r) for r in params] for dep in insts: