提交 fe5649e4 编写于 作者: M Megvii Engine Team

fix(mge/imperative): remove duplicated opr

GitOrigin-RevId: 7d49785fad3674d18fdf00ca4c25cc2d923d1ea2
上级 add3a1bc
......@@ -14,7 +14,6 @@ import os
import textwrap
import inspect
def camel2underscore(
name, *,
first_cap_re=re.compile('([A-Z])([A-Z][a-z]+)'),
......@@ -50,9 +49,9 @@ class Context:
def __init__(self):
self.fout = StringIO()
self.indent = 0
self.generated = []
self.skipped = []
self.generated_signature = set()
self.generated_opr = dict()
def write(self, text, *fmt, indent=0):
text = textwrap.dedent(text)
......@@ -181,6 +180,15 @@ class Context:
:param outputs: the indices of output vars to be selected from raw opr
result
"""
class OprItem:
def __init__(self, inputs, desc, params, version, has_out_dtype):
self.inputs = inputs
self.desc = desc
self.params = params
self.version = version
self.has_out_dtype = has_out_dtype
if body:
self.skipped.append(name)
return
......@@ -197,29 +205,56 @@ class Context:
params = [('param', params)]
assert params
self.write('# %s', caller_lineno())
self.write('class %s(PodOpVisitor):', name)
self.indent += 1
if name in self.generated_opr:
org_opr = self.generated_opr[name]
if version > org_opr.version:
def compare_doc(a, b):
if isinstance(a, str):
return a == b
else:
assert isinstance(a, Doc)
return a.doc == b.doc
assert compare_doc(desc, org_opr.desc)
assert len(inputs) == len(org_opr.inputs)
for i, j in zip(inputs, org_opr.inputs):
assert compare_doc(i, j)
self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype)
else:
self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype)
def write_generated_oprs(self):
for opr, opr_item in self.generated_opr.items():
name = opr
params = opr_item.params
version = opr_item.version
has_out_dtype = opr_item.has_out_dtype
self.write('# %s', caller_lineno())
self.write('class %s(PodOpVisitor):', name)
self.indent += 1
param_names, _ = zip(*params)
self.write('param_names = (%s,)', ', '.join(map('"{}"'.format, param_names)))
self.write('name = "%s"', '{}V{}'.format(name, version) if version else name)
self.write('\n')
param_names, _ = zip(*params)
self.write('param_names = (%s,)', ', '.join(map('"{}"'.format, param_names)))
self.write('name = "%s"', '{}V{}'.format(name, version) if version else name)
self.write('\n')
self.write('def __init__(%s):',
self._gen_signature(params,
has_out_dtype=has_out_dtype))
self.indent += 1
self.write('def __init__(%s):',
self._gen_signature(params,
has_out_dtype=has_out_dtype))
self.indent += 1
self._write_gen_config(has_out_dtype=has_out_dtype)
self.write('\n')
self._write_gen_config(has_out_dtype=has_out_dtype)
self.write('\n')
self._write_make_params(params)
self._write_make_params(params)
self.write('\n')
self.indent -= 2
self.write('\n')
self.indent -= 2
self.generated.append(name)
def decl_raw_opr(self, name, *, inputs, inputs_cvt=[], body=None,
desc=None, local_defs=[], have_config=True):
......@@ -232,7 +267,7 @@ class Context:
buf = StringIO()
print(
'[',
*(' "%s",' % i for i in self.generated),
*(' "%s",' % i for i in self.generated_opr),
']',
sep='\n',
file=buf
......@@ -259,6 +294,7 @@ def main():
with open(i) as fin:
exec(compile(fin.read(), i, 'exec'), exec_globals)
gen.write_generated_oprs()
try:
git_commit = subprocess.check_output(
['git', 'rev-parse', 'HEAD'], universal_newlines=True,
......
......@@ -95,6 +95,7 @@ r"""
"""))
decl_opr('Local',
pyname='local',
inputs=[Doc('src',
'input image in (batch, channel, row, col) format'),
Doc('filter',
......@@ -105,6 +106,19 @@ decl_opr('Local',
desc='batched convolution on channeled 2D images, but kernels are '
'not shared across different output positions')
decl_opr('Local',
pyname='local_v1',
inputs=[Doc('src',
'input image in (batch, channel, row, col) format'),
Doc('filter',
'convolution kernel in '
'(out row, out col, in channel, '
'kern row, kern col, out channel) format')],
params='Convolution',
desc='batched convolution on channeled 2D images, but kernels are '
'not shared across different output positions',
version=1)
decl_opr('GroupLocal',
inputs=[Doc('src',
'input image in (batch, channel, row, col) format'),
......@@ -113,7 +127,7 @@ decl_opr('GroupLocal',
'(group, out row, out col, in channel / group, '
'kern row, kern col, out channel / group) format')],
params=[('param', 'Convolution')],
desc='batched convolution on groupped channeled 2D images, but '
desc='batched convolution on groupped channeled 2D images, but '
'kernels are not shared across different output positions',
version=1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册