提交 fe5649e4 编写于 作者: M Megvii Engine Team

fix(mge/imperative): remove duplicated opr

GitOrigin-RevId: 7d49785fad3674d18fdf00ca4c25cc2d923d1ea2
上级 add3a1bc
...@@ -14,7 +14,6 @@ import os ...@@ -14,7 +14,6 @@ import os
import textwrap import textwrap
import inspect import inspect
def camel2underscore( def camel2underscore(
name, *, name, *,
first_cap_re=re.compile('([A-Z])([A-Z][a-z]+)'), first_cap_re=re.compile('([A-Z])([A-Z][a-z]+)'),
...@@ -50,9 +49,9 @@ class Context: ...@@ -50,9 +49,9 @@ class Context:
def __init__(self): def __init__(self):
self.fout = StringIO() self.fout = StringIO()
self.indent = 0 self.indent = 0
self.generated = []
self.skipped = [] self.skipped = []
self.generated_signature = set() self.generated_signature = set()
self.generated_opr = dict()
def write(self, text, *fmt, indent=0): def write(self, text, *fmt, indent=0):
text = textwrap.dedent(text) text = textwrap.dedent(text)
...@@ -181,6 +180,15 @@ class Context: ...@@ -181,6 +180,15 @@ class Context:
:param outputs: the indices of output vars to be selected from raw opr :param outputs: the indices of output vars to be selected from raw opr
result result
""" """
class OprItem:
def __init__(self, inputs, desc, params, version, has_out_dtype):
self.inputs = inputs
self.desc = desc
self.params = params
self.version = version
self.has_out_dtype = has_out_dtype
if body: if body:
self.skipped.append(name) self.skipped.append(name)
return return
...@@ -197,29 +205,56 @@ class Context: ...@@ -197,29 +205,56 @@ class Context:
params = [('param', params)] params = [('param', params)]
assert params assert params
self.write('# %s', caller_lineno()) if name in self.generated_opr:
self.write('class %s(PodOpVisitor):', name) org_opr = self.generated_opr[name]
self.indent += 1 if version > org_opr.version:
def compare_doc(a, b):
if isinstance(a, str):
return a == b
else:
assert isinstance(a, Doc)
return a.doc == b.doc
assert compare_doc(desc, org_opr.desc)
assert len(inputs) == len(org_opr.inputs)
for i, j in zip(inputs, org_opr.inputs):
assert compare_doc(i, j)
self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype)
else:
self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype)
def write_generated_oprs(self):
for opr, opr_item in self.generated_opr.items():
name = opr
params = opr_item.params
version = opr_item.version
has_out_dtype = opr_item.has_out_dtype
self.write('# %s', caller_lineno())
self.write('class %s(PodOpVisitor):', name)
self.indent += 1
param_names, _ = zip(*params) param_names, _ = zip(*params)
self.write('param_names = (%s,)', ', '.join(map('"{}"'.format, param_names))) self.write('param_names = (%s,)', ', '.join(map('"{}"'.format, param_names)))
self.write('name = "%s"', '{}V{}'.format(name, version) if version else name) self.write('name = "%s"', '{}V{}'.format(name, version) if version else name)
self.write('\n') self.write('\n')
self.write('def __init__(%s):', self.write('def __init__(%s):',
self._gen_signature(params, self._gen_signature(params,
has_out_dtype=has_out_dtype)) has_out_dtype=has_out_dtype))
self.indent += 1 self.indent += 1
self._write_gen_config(has_out_dtype=has_out_dtype) self._write_gen_config(has_out_dtype=has_out_dtype)
self.write('\n') self.write('\n')
self._write_make_params(params) self._write_make_params(params)
self.write('\n') self.write('\n')
self.indent -= 2 self.indent -= 2
self.generated.append(name)
def decl_raw_opr(self, name, *, inputs, inputs_cvt=[], body=None, def decl_raw_opr(self, name, *, inputs, inputs_cvt=[], body=None,
desc=None, local_defs=[], have_config=True): desc=None, local_defs=[], have_config=True):
...@@ -232,7 +267,7 @@ class Context: ...@@ -232,7 +267,7 @@ class Context:
buf = StringIO() buf = StringIO()
print( print(
'[', '[',
*(' "%s",' % i for i in self.generated), *(' "%s",' % i for i in self.generated_opr),
']', ']',
sep='\n', sep='\n',
file=buf file=buf
...@@ -259,6 +294,7 @@ def main(): ...@@ -259,6 +294,7 @@ def main():
with open(i) as fin: with open(i) as fin:
exec(compile(fin.read(), i, 'exec'), exec_globals) exec(compile(fin.read(), i, 'exec'), exec_globals)
gen.write_generated_oprs()
try: try:
git_commit = subprocess.check_output( git_commit = subprocess.check_output(
['git', 'rev-parse', 'HEAD'], universal_newlines=True, ['git', 'rev-parse', 'HEAD'], universal_newlines=True,
......
...@@ -95,6 +95,7 @@ r""" ...@@ -95,6 +95,7 @@ r"""
""")) """))
decl_opr('Local', decl_opr('Local',
pyname='local',
inputs=[Doc('src', inputs=[Doc('src',
'input image in (batch, channel, row, col) format'), 'input image in (batch, channel, row, col) format'),
Doc('filter', Doc('filter',
...@@ -105,6 +106,19 @@ decl_opr('Local', ...@@ -105,6 +106,19 @@ decl_opr('Local',
desc='batched convolution on channeled 2D images, but kernels are ' desc='batched convolution on channeled 2D images, but kernels are '
'not shared across different output positions') 'not shared across different output positions')
decl_opr('Local',
pyname='local_v1',
inputs=[Doc('src',
'input image in (batch, channel, row, col) format'),
Doc('filter',
'convolution kernel in '
'(out row, out col, in channel, '
'kern row, kern col, out channel) format')],
params='Convolution',
desc='batched convolution on channeled 2D images, but kernels are '
'not shared across different output positions',
version=1)
decl_opr('GroupLocal', decl_opr('GroupLocal',
inputs=[Doc('src', inputs=[Doc('src',
'input image in (batch, channel, row, col) format'), 'input image in (batch, channel, row, col) format'),
...@@ -113,7 +127,7 @@ decl_opr('GroupLocal', ...@@ -113,7 +127,7 @@ decl_opr('GroupLocal',
'(group, out row, out col, in channel / group, ' '(group, out row, out col, in channel / group, '
'kern row, kern col, out channel / group) format')], 'kern row, kern col, out channel / group) format')],
params=[('param', 'Convolution')], params=[('param', 'Convolution')],
desc='batched convolution on groupped channeled 2D images, but ' desc='batched convolution on groupped channeled 2D images, but '
'kernels are not shared across different output positions', 'kernels are not shared across different output positions',
version=1) version=1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册