Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
05de8ba2
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
05de8ba2
编写于
12月 12, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/traced_module): fix NormElemwisePass
GitOrigin-RevId: a92d19a013aba55fce4d5fd19798fc46d123fe97
上级
683b3e30
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
30 addition
and
18 deletion
+30
-18
imperative/python/megengine/traced_module/_passes/const_pass.py
...tive/python/megengine/traced_module/_passes/const_pass.py
+28
-16
imperative/python/megengine/traced_module/_passes/matcher.py
imperative/python/megengine/traced_module/_passes/matcher.py
+1
-1
imperative/python/test/unit/traced_module/test_passes.py
imperative/python/test/unit/traced_module/test_passes.py
+1
-1
未找到文件。
imperative/python/megengine/traced_module/_passes/const_pass.py
浏览文件 @
05de8ba2
...
@@ -128,10 +128,14 @@ class NormElemWise(BackwardPass):
...
@@ -128,10 +128,14 @@ class NormElemWise(BackwardPass):
cofee
,
left_node
,
right_node
=
1
,
None
,
None
cofee
,
left_node
,
right_node
=
1
,
None
,
None
if
len
(
expr
.
inputs
)
==
1
and
target
not
in
[
"__add__"
,
"__mul__"
]:
if
len
(
expr
.
inputs
)
==
1
and
target
not
in
[
"__add__"
,
"__mul__"
]:
left_node
=
expr
.
inputs
[
0
]
left_node
=
expr
.
inputs
[
0
]
right_node
=
expr
.
const_val
[
0
][
-
1
]
named_args
=
(
expr
.
named_args
).
values
()
for
v
in
named_args
:
if
not
isinstance
(
v
,
TensorNode
):
right_node
=
v
break
if
target
in
[
"__rsub__"
,
"__rtruediv__"
]:
if
target
in
[
"__rsub__"
,
"__rtruediv__"
]:
cofee
=
-
1
cofee
=
-
1
if
target
in
[
F
.
sub
,
F
.
div
]
and
left_node
is
not
expr
.
kw
args
[
"x"
]:
if
target
in
[
F
.
sub
,
F
.
div
]
and
left_node
is
not
expr
.
named_
args
[
"x"
]:
cofee
=
-
1
cofee
=
-
1
elif
len
(
expr
.
inputs
)
==
2
and
(
elif
len
(
expr
.
inputs
)
==
2
and
(
target
not
in
[
"__add__"
,
"__mul__"
]
or
is_constant
(
expr
.
inputs
[
0
].
expr
)
target
not
in
[
"__add__"
,
"__mul__"
]
or
is_constant
(
expr
.
inputs
[
0
].
expr
)
...
@@ -139,7 +143,7 @@ class NormElemWise(BackwardPass):
...
@@ -139,7 +143,7 @@ class NormElemWise(BackwardPass):
left_node
,
right_node
=
expr
.
inputs
left_node
,
right_node
=
expr
.
inputs
if
target
in
[
"__rsub__"
,
"__rtruediv__"
]:
if
target
in
[
"__rsub__"
,
"__rtruediv__"
]:
left_node
,
right_node
=
right_node
,
left_node
left_node
,
right_node
=
right_node
,
left_node
if
target
in
[
F
.
sub
,
F
.
div
]
and
left_node
is
not
expr
.
kw
args
[
"x"
]:
if
target
in
[
F
.
sub
,
F
.
div
]
and
left_node
is
not
expr
.
named_
args
[
"x"
]:
left_node
,
right_node
=
right_node
,
left_node
left_node
,
right_node
=
right_node
,
left_node
if
is_constant
(
left_node
.
expr
):
if
is_constant
(
left_node
.
expr
):
left_node
,
right_node
=
right_node
,
left_node
left_node
,
right_node
=
right_node
,
left_node
...
@@ -152,30 +156,38 @@ class NormElemWise(BackwardPass):
...
@@ -152,30 +156,38 @@ class NormElemWise(BackwardPass):
right_node
=
get_const_value
(
right_node
.
expr
,
right_node
)
right_node
=
get_const_value
(
right_node
.
expr
,
right_node
)
graph
=
expr
.
top_graph
graph
=
expr
.
top_graph
mul_f
,
add_f
,
sub_f
,
div_f
=
F
.
mul
,
F
.
add
,
F
.
sub
,
F
.
div
def
map_f
(
value
,
func
):
if
isinstance
(
value
,
(
list
,
tuple
)):
return
[
func
(
v
)
for
v
in
value
]
return
func
(
value
)
with
graph
.
insert_exprs
():
with
graph
.
insert_exprs
():
if
target
in
[
"__mul__"
,
"__imul__"
,
"__rmul__"
,
F
.
mul
]:
if
target
in
[
"__mul__"
,
"__imul__"
,
"__rmul__"
,
mul_f
]:
out_node
=
left_node
*
right_node
out_node
=
left_node
*
right_node
elif
target
in
[
"__add__"
,
"__iadd__"
,
"__radd__"
,
F
.
add
]:
elif
target
in
[
"__add__"
,
"__iadd__"
,
"__radd__"
,
add_f
]:
out_node
=
left_node
+
right_node
out_node
=
left_node
+
right_node
elif
target
in
[
"__sub__"
,
"__isub__"
,
"__rsub__"
,
F
.
sub
]:
elif
target
in
[
"__sub__"
,
"__isub__"
,
"__rsub__"
,
sub_f
]:
f_l
,
f_r
=
lambda
v
:
v
,
lambda
v
:
v
if
cofee
==
-
1
:
if
cofee
==
-
1
:
left_node
=
F
.
neg
(
left_node
)
f_l
=
lambda
v
:
F
.
neg
(
v
)
else
:
else
:
if
isinstance
(
right_node
,
TensorNode
):
if
isinstance
(
right_node
,
TensorNode
):
right_node
=
F
.
neg
(
right_node
)
f_r
=
lambda
v
:
F
.
neg
(
v
)
else
:
else
:
right_node
=
-
1
*
right_node
f_r
=
lambda
v
:
-
1
*
v
out_node
=
left_node
+
right_node
out_node
=
map_f
(
left_node
,
f_l
)
+
map_f
(
right_node
,
f_r
)
elif
target
in
[
"__truediv__"
,
"__itruediv__"
,
"__rtruediv__"
,
F
.
div
]:
elif
target
in
[
"__truediv__"
,
"__itruediv__"
,
"__rtruediv__"
,
div_f
]:
f_l
,
f_r
=
lambda
v
:
v
,
lambda
v
:
v
if
cofee
==
-
1
:
if
cofee
==
-
1
:
left_node
=
F
.
pow
(
left_node
,
-
1
)
f_l
=
lambda
v
:
F
.
pow
(
v
,
-
1
)
else
:
else
:
if
isinstance
(
right_node
,
TensorNode
):
if
isinstance
(
right_node
,
TensorNode
):
right_node
=
F
.
pow
(
right_node
,
-
1
)
f_r
=
lambda
v
:
F
.
pow
(
v
,
-
1
)
else
:
else
:
right_node
=
1
/
right_node
f_r
=
lambda
v
:
1
/
v
out_node
=
left_node
*
right_node
out_node
=
map_f
(
left_node
,
f_l
)
*
map_f
(
right_node
,
f_r
)
graph
.
replace_node
({
expr
.
outputs
[
0
]:
out_node
})
graph
.
replace_node
({
expr
.
outputs
[
0
]:
out_node
})
graph
.
compile
()
graph
.
compile
()
return
out_node
.
expr
return
out_node
.
expr
imperative/python/megengine/traced_module/_passes/matcher.py
浏览文件 @
05de8ba2
...
@@ -145,7 +145,7 @@ class PatternMatcher:
...
@@ -145,7 +145,7 @@ class PatternMatcher:
def
_visit_function_pattern
(
self
,
pattern
:
FunctionPattern
,
expr
:
Expr
)
->
bool
:
def
_visit_function_pattern
(
self
,
pattern
:
FunctionPattern
,
expr
:
Expr
)
->
bool
:
if
not
is_call_function
(
expr
,
pattern
.
target
):
if
not
is_call_function
(
expr
,
pattern
.
target
):
return
False
return
False
kwargs
=
expr
.
kw
args
kwargs
=
expr
.
named_
args
for
key
,
target
in
pattern
.
params
.
items
():
for
key
,
target
in
pattern
.
params
.
items
():
value
=
kwargs
.
get
(
key
,
None
)
value
=
kwargs
.
get
(
key
,
None
)
if
target
!=
value
:
if
target
!=
value
:
...
...
imperative/python/test/unit/traced_module/test_passes.py
浏览文件 @
05de8ba2
...
@@ -36,7 +36,7 @@ class MyBlock(M.Module):
...
@@ -36,7 +36,7 @@ class MyBlock(M.Module):
x2
=
F
.
relu
(
x2
)
x2
=
F
.
relu
(
x2
)
x2
=
x2
*
self
.
scale
[
1
]
x2
=
x2
*
self
.
scale
[
1
]
y
=
x1
+
x2
y
=
x1
+
x2
y
=
y
+
4
y
=
F
.
add
(
y
,
4
)
y
=
self
.
scale
[
0
]
+
y
y
=
self
.
scale
[
0
]
+
y
y
=
F
.
relu
(
y
)
*
3
y
=
F
.
relu
(
y
)
*
3
return
y
return
y
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录