diff --git a/cuburn/filters.py b/cuburn/filters.py index f55bfbd..5a2dd52 100644 --- a/cuburn/filters.py +++ b/cuburn/filters.py @@ -28,6 +28,9 @@ def mkdsc(dim, ch): format=cuda.array_format.FLOAT) class Filter(object): + filter_map = {} + name = '' + # Set to True if the filter requires a full 4-channel side buffer full_side = False def apply(self, fb, gprof, params, dim, tc, stream=None): @@ -39,6 +42,15 @@ class Filter(object): """ raise NotImplementedError() + @classmethod + def register(cls, name): + def register_(subcls): + cls.filter_map[name] = subcls + subcls.name = name + return register_ + + +@Filter.register('bilateral') class Bilateral(Filter, ClsMod): lib = code.filters.bilaterallib radius = 15 @@ -79,6 +91,7 @@ class Bilateral(Filter, ClsMod): texrefs=[tref, grad_tref]) fb.flip() +@Filter.register('logscale') class Logscale(Filter, ClsMod): lib = code.filters.logscalelib def apply(self, fb, gprof, params, dim, tc, stream=None): @@ -91,6 +104,7 @@ class Logscale(Filter, ClsMod): launch2('logscale', self.mod, stream, dim, fb.d_front, fb.d_front, k1, k2) +@Filter.register('haloclip') class HaloClip(Filter, ClsMod): lib = code.filters.halocliplib def apply(self, fb, gprof, params, dim, tc, stream=None): @@ -118,6 +132,7 @@ def calc_lingam(params, tc): lingam = f32(lin ** (gam-1.0) if lin > 0 else 0) return gam, lin, lingam +@Filter.register('smearclip') class SmearClip(Filter, ClsMod): full_side = True lib = code.filters.smearcliplib @@ -144,6 +159,7 @@ class SmearClip(Filter, ClsMod): launch2('smearclip', self.mod, stream, dim, fb.d_front, fb.d_side, f32(gam-1), lin, lingam) +@Filter.register('colorclip') class ColorClip(Filter, ClsMod): lib = code.filters.colorcliplib def apply(self, fb, gprof, params, dim, tc, stream=None): @@ -154,8 +170,5 @@ class ColorClip(Filter, ClsMod): launch2('colorclip', self.mod, stream, dim, fb.d_front, vib, hipow, gam, lin, lingam) -# Ungainly but practical. -filter_map = dict(bilateral=Bilateral, logscale=Logscale, haloclip=HaloClip, - colorclip=ColorClip, smearclip=SmearClip) def create(gprof): - return [filter_map[f]() for f in gprof.filter_order] + return [Filter.filter_map[f]() for f in gprof.filter_order]