Register filters with a class decorator

This commit is contained in:
Steven Robertson 2015-10-10 15:58:13 -07:00
parent 227a6016c2
commit 698d9c2337

View File

@ -28,6 +28,9 @@ def mkdsc(dim, ch):
format=cuda.array_format.FLOAT)
class Filter(object):
filter_map = {}
name = ''
# Set to True if the filter requires a full 4-channel side buffer
full_side = False
def apply(self, fb, gprof, params, dim, tc, stream=None):
@ -39,6 +42,15 @@ class Filter(object):
"""
raise NotImplementedError()
@classmethod
def register(cls, name):
def register_(subcls):
cls.filter_map[name] = subcls
subcls.name = name
return register_
@Filter.register('bilateral')
class Bilateral(Filter, ClsMod):
lib = code.filters.bilaterallib
radius = 15
@ -79,6 +91,7 @@ class Bilateral(Filter, ClsMod):
texrefs=[tref, grad_tref])
fb.flip()
@Filter.register('logscale')
class Logscale(Filter, ClsMod):
lib = code.filters.logscalelib
def apply(self, fb, gprof, params, dim, tc, stream=None):
@ -91,6 +104,7 @@ class Logscale(Filter, ClsMod):
launch2('logscale', self.mod, stream, dim,
fb.d_front, fb.d_front, k1, k2)
@Filter.register('haloclip')
class HaloClip(Filter, ClsMod):
lib = code.filters.halocliplib
def apply(self, fb, gprof, params, dim, tc, stream=None):
@ -118,6 +132,7 @@ def calc_lingam(params, tc):
lingam = f32(lin ** (gam-1.0) if lin > 0 else 0)
return gam, lin, lingam
@Filter.register('smearclip')
class SmearClip(Filter, ClsMod):
full_side = True
lib = code.filters.smearcliplib
@ -144,6 +159,7 @@ class SmearClip(Filter, ClsMod):
launch2('smearclip', self.mod, stream, dim,
fb.d_front, fb.d_side, f32(gam-1), lin, lingam)
@Filter.register('colorclip')
class ColorClip(Filter, ClsMod):
lib = code.filters.colorcliplib
def apply(self, fb, gprof, params, dim, tc, stream=None):
@ -154,8 +170,5 @@ class ColorClip(Filter, ClsMod):
launch2('colorclip', self.mod, stream, dim,
fb.d_front, vib, hipow, gam, lin, lingam)
# Ungainly but practical.
filter_map = dict(bilateral=Bilateral, logscale=Logscale, haloclip=HaloClip,
colorclip=ColorClip, smearclip=SmearClip)
def create(gprof):
return [filter_map[f]() for f in gprof.filter_order]
return [Filter.filter_map[f]() for f in gprof.filter_order]