diff --git a/imperative/python/tools/gen_ops.py b/imperative/python/tools/gen_ops.py index 7fcdb422a8aed5931ef9b28484dae1de7f9d550d..dde93e983423ec477276d6bdeddba853637c6778 100755 --- a/imperative/python/tools/gen_ops.py +++ b/imperative/python/tools/gen_ops.py @@ -14,7 +14,6 @@ import os import textwrap import inspect - def camel2underscore( name, *, first_cap_re=re.compile('([A-Z])([A-Z][a-z]+)'), @@ -50,9 +49,9 @@ class Context: def __init__(self): self.fout = StringIO() self.indent = 0 - self.generated = [] self.skipped = [] self.generated_signature = set() + self.generated_opr = dict() def write(self, text, *fmt, indent=0): text = textwrap.dedent(text) @@ -181,6 +180,15 @@ class Context: :param outputs: the indices of output vars to be selected from raw opr result """ + + class OprItem: + def __init__(self, inputs, desc, params, version, has_out_dtype): + self.inputs = inputs + self.desc = desc + self.params = params + self.version = version + self.has_out_dtype = has_out_dtype + if body: self.skipped.append(name) return @@ -197,29 +205,56 @@ class Context: params = [('param', params)] assert params - self.write('# %s', caller_lineno()) - self.write('class %s(PodOpVisitor):', name) - self.indent += 1 + if name in self.generated_opr: + org_opr = self.generated_opr[name] + if version > org_opr.version: + def compare_doc(a, b): + if isinstance(a, str): + return a == b + else: + assert isinstance(a, Doc) + return a.doc == b.doc + + assert compare_doc(desc, org_opr.desc) + assert len(inputs) == len(org_opr.inputs) + for i, j in zip(inputs, org_opr.inputs): + assert compare_doc(i, j) + + self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype) + else: + self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype) + + def write_generated_oprs(self): + + for opr, opr_item in self.generated_opr.items(): + + name = opr + params = opr_item.params + version = opr_item.version + has_out_dtype = opr_item.has_out_dtype + + self.write('# %s', caller_lineno()) + self.write('class %s(PodOpVisitor):', name) + self.indent += 1 - param_names, _ = zip(*params) - self.write('param_names = (%s,)', ', '.join(map('"{}"'.format, param_names))) - self.write('name = "%s"', '{}V{}'.format(name, version) if version else name) - self.write('\n') + param_names, _ = zip(*params) + self.write('param_names = (%s,)', ', '.join(map('"{}"'.format, param_names))) + self.write('name = "%s"', '{}V{}'.format(name, version) if version else name) + self.write('\n') - self.write('def __init__(%s):', - self._gen_signature(params, - has_out_dtype=has_out_dtype)) - self.indent += 1 + self.write('def __init__(%s):', + self._gen_signature(params, + has_out_dtype=has_out_dtype)) + self.indent += 1 - self._write_gen_config(has_out_dtype=has_out_dtype) - self.write('\n') + self._write_gen_config(has_out_dtype=has_out_dtype) + self.write('\n') - self._write_make_params(params) + self._write_make_params(params) - self.write('\n') - self.indent -= 2 + self.write('\n') + self.indent -= 2 - self.generated.append(name) def decl_raw_opr(self, name, *, inputs, inputs_cvt=[], body=None, desc=None, local_defs=[], have_config=True): @@ -232,7 +267,7 @@ class Context: buf = StringIO() print( '[', - *(' "%s",' % i for i in self.generated), + *(' "%s",' % i for i in self.generated_opr), ']', sep='\n', file=buf @@ -259,6 +294,7 @@ def main(): with open(i) as fin: exec(compile(fin.read(), i, 'exec'), exec_globals) + gen.write_generated_oprs() try: git_commit = subprocess.check_output( ['git', 'rev-parse', 'HEAD'], universal_newlines=True, diff --git a/src/opr/impl/dnn/dnn.oprdecl b/src/opr/impl/dnn/dnn.oprdecl index cadc8da0098112f83df5bd88717c36115c6f8c41..af790c39e646b9247882911ed1f371f200674099 100644 --- a/src/opr/impl/dnn/dnn.oprdecl +++ b/src/opr/impl/dnn/dnn.oprdecl @@ -95,6 +95,7 @@ r""" """)) decl_opr('Local', + pyname='local', inputs=[Doc('src', 'input image in (batch, channel, row, col) format'), Doc('filter', @@ -105,6 +106,19 @@ decl_opr('Local', desc='batched convolution on channeled 2D images, but kernels are ' 'not shared across different output positions') +decl_opr('Local', + pyname='local_v1', + inputs=[Doc('src', + 'input image in (batch, channel, row, col) format'), + Doc('filter', + 'convolution kernel in ' + '(out row, out col, in channel, ' + 'kern row, kern col, out channel) format')], + params='Convolution', + desc='batched convolution on channeled 2D images, but kernels are ' + 'not shared across different output positions', + version=1) + decl_opr('GroupLocal', inputs=[Doc('src', 'input image in (batch, channel, row, col) format'), @@ -113,7 +127,7 @@ decl_opr('GroupLocal', '(group, out row, out col, in channel / group, ' 'kern row, kern col, out channel / group) format')], params=[('param', 'Convolution')], - desc='batched convolution on groupped channeled 2D images, but ' + desc='batched convolution on groupped channeled 2D images, but ' 'kernels are not shared across different output positions', version=1)