Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
fe5649e4
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
fe5649e4
编写于
9月 22, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/imperative): remove duplicated opr
GitOrigin-RevId: 7d49785fad3674d18fdf00ca4c25cc2d923d1ea2
上级
add3a1bc
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
71 addition
and
21 deletion
+71
-21
imperative/python/tools/gen_ops.py
imperative/python/tools/gen_ops.py
+56
-20
src/opr/impl/dnn/dnn.oprdecl
src/opr/impl/dnn/dnn.oprdecl
+15
-1
未找到文件。
imperative/python/tools/gen_ops.py
浏览文件 @
fe5649e4
...
...
@@ -14,7 +14,6 @@ import os
import
textwrap
import
inspect
def
camel2underscore
(
name
,
*
,
first_cap_re
=
re
.
compile
(
'([A-Z])([A-Z][a-z]+)'
),
...
...
@@ -50,9 +49,9 @@ class Context:
def
__init__
(
self
):
self
.
fout
=
StringIO
()
self
.
indent
=
0
self
.
generated
=
[]
self
.
skipped
=
[]
self
.
generated_signature
=
set
()
self
.
generated_opr
=
dict
()
def
write
(
self
,
text
,
*
fmt
,
indent
=
0
):
text
=
textwrap
.
dedent
(
text
)
...
...
@@ -181,6 +180,15 @@ class Context:
:param outputs: the indices of output vars to be selected from raw opr
result
"""
class
OprItem
:
def
__init__
(
self
,
inputs
,
desc
,
params
,
version
,
has_out_dtype
):
self
.
inputs
=
inputs
self
.
desc
=
desc
self
.
params
=
params
self
.
version
=
version
self
.
has_out_dtype
=
has_out_dtype
if
body
:
self
.
skipped
.
append
(
name
)
return
...
...
@@ -197,29 +205,56 @@ class Context:
params
=
[(
'param'
,
params
)]
assert
params
self
.
write
(
'# %s'
,
caller_lineno
())
self
.
write
(
'class %s(PodOpVisitor):'
,
name
)
self
.
indent
+=
1
if
name
in
self
.
generated_opr
:
org_opr
=
self
.
generated_opr
[
name
]
if
version
>
org_opr
.
version
:
def
compare_doc
(
a
,
b
):
if
isinstance
(
a
,
str
):
return
a
==
b
else
:
assert
isinstance
(
a
,
Doc
)
return
a
.
doc
==
b
.
doc
assert
compare_doc
(
desc
,
org_opr
.
desc
)
assert
len
(
inputs
)
==
len
(
org_opr
.
inputs
)
for
i
,
j
in
zip
(
inputs
,
org_opr
.
inputs
):
assert
compare_doc
(
i
,
j
)
self
.
generated_opr
[
name
]
=
OprItem
(
inputs
,
desc
,
params
,
version
,
has_out_dtype
)
else
:
self
.
generated_opr
[
name
]
=
OprItem
(
inputs
,
desc
,
params
,
version
,
has_out_dtype
)
def
write_generated_oprs
(
self
):
for
opr
,
opr_item
in
self
.
generated_opr
.
items
():
name
=
opr
params
=
opr_item
.
params
version
=
opr_item
.
version
has_out_dtype
=
opr_item
.
has_out_dtype
self
.
write
(
'# %s'
,
caller_lineno
())
self
.
write
(
'class %s(PodOpVisitor):'
,
name
)
self
.
indent
+=
1
param_names
,
_
=
zip
(
*
params
)
self
.
write
(
'param_names = (%s,)'
,
', '
.
join
(
map
(
'"{}"'
.
format
,
param_names
)))
self
.
write
(
'name = "%s"'
,
'{}V{}'
.
format
(
name
,
version
)
if
version
else
name
)
self
.
write
(
'
\n
'
)
param_names
,
_
=
zip
(
*
params
)
self
.
write
(
'param_names = (%s,)'
,
', '
.
join
(
map
(
'"{}"'
.
format
,
param_names
)))
self
.
write
(
'name = "%s"'
,
'{}V{}'
.
format
(
name
,
version
)
if
version
else
name
)
self
.
write
(
'
\n
'
)
self
.
write
(
'def __init__(%s):'
,
self
.
_gen_signature
(
params
,
has_out_dtype
=
has_out_dtype
))
self
.
indent
+=
1
self
.
write
(
'def __init__(%s):'
,
self
.
_gen_signature
(
params
,
has_out_dtype
=
has_out_dtype
))
self
.
indent
+=
1
self
.
_write_gen_config
(
has_out_dtype
=
has_out_dtype
)
self
.
write
(
'
\n
'
)
self
.
_write_gen_config
(
has_out_dtype
=
has_out_dtype
)
self
.
write
(
'
\n
'
)
self
.
_write_make_params
(
params
)
self
.
_write_make_params
(
params
)
self
.
write
(
'
\n
'
)
self
.
indent
-=
2
self
.
write
(
'
\n
'
)
self
.
indent
-=
2
self
.
generated
.
append
(
name
)
def
decl_raw_opr
(
self
,
name
,
*
,
inputs
,
inputs_cvt
=
[],
body
=
None
,
desc
=
None
,
local_defs
=
[],
have_config
=
True
):
...
...
@@ -232,7 +267,7 @@ class Context:
buf
=
StringIO
()
print
(
'['
,
*
(
' "%s",'
%
i
for
i
in
self
.
generated
),
*
(
' "%s",'
%
i
for
i
in
self
.
generated
_opr
),
']'
,
sep
=
'
\n
'
,
file
=
buf
...
...
@@ -259,6 +294,7 @@ def main():
with
open
(
i
)
as
fin
:
exec
(
compile
(
fin
.
read
(),
i
,
'exec'
),
exec_globals
)
gen
.
write_generated_oprs
()
try
:
git_commit
=
subprocess
.
check_output
(
[
'git'
,
'rev-parse'
,
'HEAD'
],
universal_newlines
=
True
,
...
...
src/opr/impl/dnn/dnn.oprdecl
浏览文件 @
fe5649e4
...
...
@@ -95,6 +95,7 @@ r"""
"""
))
decl_opr
(
'Local'
,
pyname
=
'local'
,
inputs
=
[
Doc
(
'src'
,
'input image in (batch, channel, row, col) format'
),
Doc
(
'filter'
,
...
...
@@ -105,6 +106,19 @@ decl_opr('Local',
desc
=
'batched convolution on channeled 2D images, but kernels are '
'not shared across different output positions'
)
decl_opr
(
'Local'
,
pyname
=
'local_v1'
,
inputs
=
[
Doc
(
'src'
,
'input image in (batch, channel, row, col) format'
),
Doc
(
'filter'
,
'convolution kernel in '
'(out row, out col, in channel, '
'kern row, kern col, out channel) format'
)],
params
=
'Convolution'
,
desc
=
'batched convolution on channeled 2D images, but kernels are '
'not shared across different output positions'
,
version
=
1
)
decl_opr
(
'GroupLocal'
,
inputs
=
[
Doc
(
'src'
,
'input image in (batch, channel, row, col) format'
),
...
...
@@ -113,7 +127,7 @@ decl_opr('GroupLocal',
'(group, out row, out col, in channel / group, '
'kern row, kern col, out channel / group) format'
)],
params
=
[(
'param'
,
'Convolution'
)],
desc
=
'batched convolution on groupped channeled 2D images, but '
desc
=
'batched convolution on groupped channeled 2D images, but '
'kernels are not shared across different output positions'
,
version
=
1
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录