Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
6c692b26
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
6c692b26
编写于
11月 26, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/traced_module): add some fuse passes
GitOrigin-RevId: 065f9df32eaead53544989c826910f8c326ba738
上级
b28ad4e8
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
1050 addition
and
23 deletion
+1050
-23
imperative/python/megengine/traced_module/_passes/const_pass.py
...tive/python/megengine/traced_module/_passes/const_pass.py
+183
-0
imperative/python/megengine/traced_module/_passes/fold_scale_pass.py
...python/megengine/traced_module/_passes/fold_scale_pass.py
+298
-0
imperative/python/megengine/traced_module/_passes/fuse_pass.py
...ative/python/megengine/traced_module/_passes/fuse_pass.py
+248
-0
imperative/python/megengine/traced_module/_passes/pass_base.py
...ative/python/megengine/traced_module/_passes/pass_base.py
+190
-0
imperative/python/megengine/traced_module/expr.py
imperative/python/megengine/traced_module/expr.py
+19
-9
imperative/python/megengine/traced_module/node.py
imperative/python/megengine/traced_module/node.py
+1
-0
imperative/python/megengine/traced_module/utils.py
imperative/python/megengine/traced_module/utils.py
+24
-0
imperative/python/megengine/utils/bn_fusion.py
imperative/python/megengine/utils/bn_fusion.py
+86
-0
imperative/python/test/unit/traced_module/test_qat_module.py
imperative/python/test/unit/traced_module/test_qat_module.py
+1
-14
未找到文件。
imperative/python/megengine/traced_module/_passes/const_pass.py
0 → 100644
浏览文件 @
6c692b26
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
...
import
functional
as
F
from
...
import
module
as
M
from
...core.ops.builtin
import
GetVarShape
from
...logger
import
get_logger
from
...tensor
import
Tensor
from
..expr
import
Constant
,
Expr
,
is_apply_def
,
is_constant
,
is_getattr
from
..node
import
Node
,
TensorNode
from
.matcher
import
PatternMatcher
from
.pass_base
import
BackwardPass
,
ForwardPass
,
register_pass
from
.pattern
import
is_op
from
.utils
import
get_const_value
logger
=
get_logger
(
__name__
)
@register_pass("AttrToConstant")
class AttrToConstant(BackwardPass):
    r"""Convert :class:`~.GetAttr` to :class:`~.Constant` expr.

    Only ``GetAttr`` exprs whose first output is a :class:`TensorNode` are
    rewritten; every other expr is handed back untouched.
    """

    name = "AttrToConstant"
    run_once = True

    def run_transform(self, expr: Expr):
        """Replace a tensor ``GetAttr`` expr by an equivalent ``Constant`` expr.

        Args:
            expr: the expr currently visited by the pass.

        Returns:
            ``expr`` itself when it is not a tensor ``GetAttr``; otherwise the
            expr of the newly created constant node.
        """
        if not (is_getattr(expr) and isinstance(expr.outputs[0], TensorNode)):
            return expr
        graph = expr.top_graph
        value = get_const_value(expr)
        orig_node = expr.outputs[0]
        # The constant is created under the name of the node it replaces.
        name = orig_node.name
        with graph.insert_exprs(expr):
            const_node = Constant.make(value, name=name)
        graph.replace_node({orig_node: const_node})
        graph.compile()
        # The original re-assigned the local ``name`` here without using it;
        # that dead store has been dropped.
        return const_node.expr
@register_pass("FixInputShape")
class FixInputShape(BackwardPass):
    """Replace ``GetVarShape`` apply exprs by a constant shape tensor.

    The constant is an ``int32`` tensor built from the ``shape`` attribute of
    the first input node of the matched expr.
    """

    name = "FixInputShape"
    run_once = True

    def run_transform(self, expr: Expr):
        """Return the constant expr replacing ``expr``, or ``expr`` if unmatched."""
        is_shape_query = is_apply_def(expr, GetVarShape)
        if not is_shape_query:
            return expr

        static_shape = Tensor(expr.inputs[0].shape, dtype="int32")
        top_graph = expr.top_graph
        with top_graph.insert_exprs(expr):
            shape_const = Constant.make(static_shape)

        replacement = {expr.outputs[0]: shape_const}
        top_graph.replace_node(replacement)
        top_graph.compile()
        # The new node takes over the name of the node it replaced.
        shape_const.name = expr.outputs[0].name
        return shape_const.expr
@register_pass("FlodConstant")
class FlodConstant(ForwardPass):
    r"""Constant folding.

    An expr whose inputs are all produced by constant exprs is interpreted
    once and replaced by a single :class:`~.Constant` holding the result.
    """

    name = "FlodConstant"
    required_pass = ["AttrToConstant"]
    run_once = False

    def run_transform(self, expr: Expr):
        """Fold ``expr`` when it has inputs and all of them are constant."""
        foldable = len(expr.inputs) != 0 and all(
            is_constant(inp.expr) for inp in expr.inputs
        )
        if not foldable:
            return expr

        operands = [get_const_value(inp.expr) for inp in expr.inputs]
        # Only the first interpreted result is used for the replacement.
        folded_value = expr.interpret(*operands)[0]

        top_graph = expr.top_graph
        with top_graph.insert_exprs(expr):
            folded_node = Constant.make(folded_value)
        top_graph.replace_node({expr.outputs[0]: folded_node})
        top_graph.compile()
        folded_node.name = expr.outputs[0].name
        return folded_node.expr
@register_pass("NormElemWise")
class NormElemWise(BackwardPass):
    r"""Transform add/sub or mul/div expr to add-only or mul-only chains.

    For example, the following code

    .. code-block::

        b = 1 - a
        c = 2 * b
        d = 1 / c

    will be changed to

    .. code-block::

        a1 = F.neg(a)
        b = a1 + 1
        c = b * 2
        d = F.pow(c, -1)
    """
    name = "NormElemWise"
    required_pass = ["FlodConstant"]
    run_once = False

    def __init__(self,):
        super().__init__()
        # One alternation pattern covering the functional forms and every
        # plain / in-place / reflected tensor-method form of + - * /.
        self.pattern = is_op(F.add)
        for op in [F.sub, F.mul, F.div]:
            self.pattern |= is_op(op)
        for op in ["__add__", "__iadd__", "__radd__"]:
            self.pattern |= is_op(op)
        for op in ["__sub__", "__isub__", "__rsub__"]:
            self.pattern |= is_op(op)
        for op in ["__mul__", "__imul__", "__rmul__"]:
            self.pattern |= is_op(op)
        for op in ["__truediv__", "__itruediv__", "__rtruediv__"]:
            self.pattern |= is_op(op)

    def run_transform(self, expr: Expr):
        """Rewrite a matched elementwise expr into an add-only / mul-only form.

        Returns ``expr`` unchanged when it does not match ``self.pattern`` or
        when neither operand layout below applies; otherwise returns the expr
        producing the replacement output node.
        """
        matcher = PatternMatcher()
        if not matcher.match(self.pattern, expr):
            return expr
        pattern = matcher.matched_patterns[0]
        target = pattern.target
        # ``cofee`` == -1 marks that ``left_node`` (rather than ``right_node``)
        # is the operand that must be negated / inverted further down.
        cofee, left_node, right_node = 1, None, None
        if len(expr.inputs) == 1 and target not in ["__add__", "__mul__"]:
            # Single node input: the other operand is read from ``const_val``.
            # NOTE(review): assumes const_val[0][-1] is that operand's value -- confirm.
            left_node = expr.inputs[0]
            right_node = expr.const_val[0][-1]
            if target in ["__rsub__", "__rtruediv__"]:
                cofee = -1
            if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]:
                cofee = -1
        elif len(expr.inputs) == 2 and (
            target not in ["__add__", "__mul__"] or is_constant(expr.inputs[0].expr)
        ):
            left_node, right_node = expr.inputs
            # Reflected methods receive their operands in reverse order.
            if target in ["__rsub__", "__rtruediv__"]:
                left_node, right_node = right_node, left_node
            # For the functional forms, make ``left_node`` the ``x`` argument.
            if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]:
                left_node, right_node = right_node, left_node
            # Keep the constant operand on the right-hand side.
            if is_constant(left_node.expr):
                left_node, right_node = right_node, left_node
                cofee = -1
        if left_node is None:
            # Neither layout applied (e.g. a plain ``__add__`` / ``__mul__``
            # excluded by the conditions above): leave the expr as it is.
            return expr
        if isinstance(right_node, TensorNode):
            # Falls back to the node itself as the second argument of
            # ``get_const_value``; presumably returned when no constant value
            # can be extracted -- confirm against ``.utils.get_const_value``.
            right_node = get_const_value(right_node.expr, right_node)
        graph = expr.top_graph
        with graph.insert_exprs():
            if target in ["__mul__", "__imul__", "__rmul__", F.mul]:
                out_node = left_node * right_node
            elif target in ["__add__", "__iadd__", "__radd__", F.add]:
                out_node = left_node + right_node
            elif target in ["__sub__", "__isub__", "__rsub__", F.sub]:
                # Subtraction becomes an addition with one operand negated.
                if cofee == -1:
                    left_node = F.neg(left_node)
                else:
                    if isinstance(right_node, TensorNode):
                        right_node = F.neg(right_node)
                    else:
                        right_node = -1 * right_node
                out_node = left_node + right_node
            elif target in ["__truediv__", "__itruediv__", "__rtruediv__", F.div]:
                # Division becomes a multiplication with one operand inverted.
                if cofee == -1:
                    left_node = F.pow(left_node, -1)
                else:
                    if isinstance(right_node, TensorNode):
                        right_node = F.pow(right_node, -1)
                    else:
                        right_node = 1 / right_node
                out_node = left_node * right_node
        graph.replace_node({expr.outputs[0]: out_node})
        graph.compile()
        return out_node.expr
imperative/python/megengine/traced_module/_passes/fold_scale_pass.py
0 → 100644
浏览文件 @
6c692b26
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
collections
import
OrderedDict
,
defaultdict
from
copy
import
deepcopy
from
typing
import
Any
,
Dict
,
List
,
Set
from
...
import
functional
as
F
from
...
import
module
as
M
from
...core.ops.builtin
import
GetVarShape
from
...logger
import
get_logger
from
...tensor
import
Parameter
,
Tensor
from
..expr
import
(
Expr
,
is_apply_def
,
is_call_function
,
is_call_module
,
is_call_tensor_method
,
is_constant
,
is_getattr
,
)
from
..traced_module
import
InternalGraph
from
..utils
import
assign_attr
,
get_subattr
from
.matcher
import
PatternMatcher
from
.pass_base
import
BackwardPass
,
register_pass
from
.pattern
import
is_const
,
is_op
,
is_var
from
.utils
import
get_const_value
logger
=
get_logger
(
__name__
)
@register_pass("BackwardFoldScale")
class BackwardFoldScale(BackwardPass):
    r"""Backward fold const scaling into weights of conv2d.

    For example, the following code

    .. code-block::

        x = conv(x, w, b)
        x = relu(x)
        x1 = x + 3
        x2 = x + 4
        y = (x1 + x2) * 3

    will be changed to

    .. code-block::

        x = conv(x, w * 3, b * 3)
        x = relu(x)
        x1 = x + 9
        x2 = x + 12
        y = x1 + x2
    """
    name = "BackwardFoldScale"
    required_pass = ["AttrToConstant", "NormElemWise"]
    run_once = True

    def __init__(self):
        super().__init__()
        # todo : support more axis
        # expr -> scale to fold into it (``None`` means "do not fold").
        # Filled by ``before_visit_graph`` and consumed by ``run_transform``.
        self.scale_message = OrderedDict()
        # Per-qualname counter used to derive names for duplicated conv modules.
        self.used_names = defaultdict(int)

    def run_transform(self, expr: Expr) -> Expr:
        """Dispatch ``expr`` to the matching fold routine.

        Exprs that were not recorded by ``before_visit_graph`` or that match
        none of the conv / add-const / mul-const patterns are returned as is.
        """
        if expr not in self.scale_message:
            return expr
        var = is_var().check_users(False)
        mul_const_pattern = var * is_const() | var * "*" | is_op(F.neg)
        add_const_pattern = var + is_const() | var + "*"
        conv_pattern = is_op(F.conv2d) | is_op(M.Conv2d)
        pattern = conv_pattern | add_const_pattern | mul_const_pattern
        matcher = PatternMatcher()
        if not matcher.match(pattern, expr):
            return expr
        matched_exprs = matcher.matched_exprs
        if conv_pattern in matched_exprs:
            return self.fold_conv_mul(expr)
        if mul_const_pattern in matched_exprs:
            return self.fold_mul(expr)
        if add_const_pattern in matched_exprs:
            return self.fold_add_mul(expr)
        return expr

    def fold_add_mul(self, expr: Expr):
        """Rescale the constant addend of ``expr`` by the recorded scale."""
        if self.scale_message[expr] is None:
            return expr
        scale = self.scale_message[expr]
        if len(expr.inputs) == 1:
            # NOTE(review): assumes const_val[0][-1] holds the constant
            # operand when it is not a graph node -- confirm.
            const = expr.const_val[0][-1]
        else:
            const = get_const_value(expr.inputs[1])
        const = const * scale
        inp_node = expr.inputs[0]
        graph = expr.top_graph
        with graph.insert_exprs():
            add_node = inp_node + const
        graph.replace_node({expr.outputs[0]: add_node})
        graph.compile()
        add_node.name = expr.outputs[0].name
        return add_node.expr

    def fold_mul(self, expr: Expr):
        """Drop a scaling expr by rewiring its users to its first input."""
        if self.scale_message[expr] is None:
            return expr
        graph = expr.top_graph
        graph.replace_node({expr.outputs[0]: expr.inputs[0]})
        graph.compile()
        return expr

    def fold_conv_mul(self, expr: Expr):
        """Multiply the weight (and bias, if any) of a conv by the recorded scale.

        Handles both ``F.conv2d`` calls and ``M.Conv2d`` module calls. A module
        used by more than one expr is deep-copied and registered under a new
        attribute name so the other users keep the unscaled parameters.
        """
        graph = expr.top_graph
        scale = self.scale_message[expr]
        if scale is None:
            return expr
        if is_call_function(expr, F.conv2d):
            named_args = expr.named_args
            weight = get_const_value(named_args["weight"], named_args["weight"]) * scale
            bias = get_const_value(named_args["bias"], named_args["bias"])
            # The pattern accepted in ``before_visit_graph`` also covers
            # ``conv2d(var, const)`` without a bias, so only scale a real bias
            # (mirrors the ``bias is not None`` guard of the module branch).
            if bias is not None:
                bias = bias * scale
            named_args["weight"] = weight
            named_args["bias"] = bias
            with graph.insert_exprs():
                out_node = F.conv2d(**named_args)
            graph.replace_node({expr.outputs[0]: out_node})
            graph.compile()
            out_node.name = expr.outputs[0].name
            return out_node.expr
        else:
            mnode = expr.inputs[0]
            attr_name = expr.inputs[0].expr.name
            shared = len(mnode.users) > 1
            if shared:
                self.used_names[mnode.qualname] += 1
                attr_name = "{}_{}".format(attr_name, self.used_names[mnode.qualname])
                logger.warning(
                    "{} is used {} times and its name will be reset to {}.{}".format(
                        mnode.qualname, len(mnode.users), graph.qualname, attr_name
                    )
                )
            conv_module = mnode.owner
            if shared:
                # Scale a private copy so other users of the module are unaffected.
                conv_module = deepcopy(conv_module)
                conv_module._name = None
            conv_module.weight = Parameter(conv_module.weight * scale)
            if conv_module.bias is not None:
                conv_module.bias = Parameter(conv_module.bias * scale)
            if shared:
                self_node = mnode.expr.inputs[0]
                assign_attr(conv_module, self_node.owner, attr_name)
                with graph.insert_exprs(mnode.expr):
                    new_conv_node = get_subattr(self_node, attr_name)
                expr.replace_inputs({mnode: new_conv_node})
            return expr

    def reset_expr_message_to_none(
        self, expr: Expr, scale_message: Dict[Expr, Any], skip_exprs: Set[Expr],
    ):
        """Mark ``expr`` and, recursively, its recorded users as "do not fold".

        The recursion stops at conv exprs and ignores exprs in ``skip_exprs``.
        """
        if expr in skip_exprs:
            return
        scale_message[expr] = None
        if is_call_function(expr, F.conv2d) or is_call_module(expr, M.Conv2d):
            return
        for out_node in expr.outputs:
            for user in out_node.users:
                if user in scale_message:
                    self.reset_expr_message_to_none(user, scale_message, skip_exprs)

    def before_visit_graph(self, graph: InternalGraph):
        """Walk ``graph`` in reverse and record which scale reaches each expr.

        The result is merged into ``self.scale_message``; convolutions are
        recorded separately and merged last.
        """
        var = is_var().check_users(False)
        mul_const_pattern = var * is_const() | var * "*" | is_op(F.neg)
        relu_pattern = (
            is_op(F.relu)
            | is_op(M.ReLU)
            | is_op(F.leaky_relu)
            | is_op(M.LeakyReLU)
        )

        # The param of conv must be const, not support dynamic conv
        conv_pattern = (
            is_op(F.conv2d)(var, is_const(), is_const())
            | is_op(F.conv2d)(var, is_const())
            | is_op(M.Conv2d)
        )

        pattern = mul_const_pattern | relu_pattern | conv_pattern
        # Ops through which a scale is propagated. The tensor-method name was
        # previously misspelled "tranpose" and therefore could never match.
        for op in [
            "__add__",
            F.reshape,
            "reshape",
            F.transpose,
            "transpose",
            F.min,
            "min",
            F.max,
            "max",
            F.max_pool2d,
            M.MaxPool2d,
            F.avg_pool2d,
            M.AvgPool2d,
            F.adaptive_avg_pool2d,
            M.AdaptiveAvgPool2d,
            F.adaptive_max_pool2d,
            M.AdaptiveMaxPool2d,
            F.expand_dims,
            F.concat,
            "__getitem__",
        ]:
            pattern |= is_op(op)

        matcher = PatternMatcher()
        scale_message = OrderedDict()
        mem_conv_scale_message = OrderedDict()
        skip_exprs = self.init_skip_exprs(graph)
        for expr in reversed(graph._exprs):
            if expr in skip_exprs:
                continue

            if len(expr.outputs) > 1 or not matcher.match(pattern, expr):
                # Unsupported expr: nothing downstream of it may be folded.
                self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
                if is_call_function(expr, F.conv2d):
                    for user in expr.outputs[0].users:
                        self.reset_expr_message_to_none(user, scale_message, skip_exprs)
                continue

            matched_exprs = matcher.matched_exprs
            const = None
            if mul_const_pattern in matched_exprs:
                if is_call_function(expr, F.neg):
                    const = -1
                elif len(expr.inputs) == 1:
                    const = expr.const_val[0][-1]
                else:
                    const = get_const_value(expr.inputs[1])

                # Only tensors of shape ``(1,)`` or ``()`` are accepted as scale.
                if isinstance(const, Tensor) and const._tuple_shape not in [(1,), tuple()]:
                    self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
                    continue

            users_const = [
                scale_message[e] for e in expr.outputs[0].users if e not in skip_exprs
            ]

            if len(users_const) == 0:
                scale_message[expr] = const
                continue

            # Every user must carry the same, known scale.
            if any(c is None or c != users_const[0] for c in users_const):
                self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
                scale_message[expr] = const
                continue

            const = 1 if const is None else const
            const = const * users_const[0]
            # A negative accumulated scale is not moved across a (leaky) relu.
            if relu_pattern in matched_exprs and const < 0:
                self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
                continue

            if conv_pattern in matched_exprs:
                self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
                mem_conv_scale_message[expr] = const
                continue

            scale_message[expr] = const

        self.scale_message.update(scale_message)
        self.scale_message.update(mem_conv_scale_message)

    def init_skip_exprs(self, graph: InternalGraph):
        """Collect exprs that take no part in scale propagation.

        These are ``GetVarShape`` applies, ``GetAttr`` and ``Constant`` exprs,
        and any expr all of whose inputs come from already skipped exprs.
        """
        skip_exprs = set()
        for expr in graph._exprs:
            if is_apply_def(expr, GetVarShape):
                skip_exprs.add(expr)
            # NOTE(review): ``expr in skip_exprs`` looks like it can never hold
            # here, because ``expr`` has not been added yet; possibly the
            # producer of the indexed node was meant. Left unchanged -- confirm.
            elif is_call_tensor_method(expr, "__getitem__") and expr in skip_exprs:
                skip_exprs.add(expr)
            elif is_getattr(expr):
                skip_exprs.add(expr)
            elif is_constant(expr):
                skip_exprs.add(expr)
            elif all(n.expr in skip_exprs for n in expr.inputs):
                skip_exprs.add(expr)
        return skip_exprs
imperative/python/megengine/traced_module/_passes/fuse_pass.py
0 → 100644
浏览文件 @
6c692b26
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
operator
from
collections
import
defaultdict
from
typing
import
Any
,
Callable
,
List
from
...
import
functional
as
F
from
...
import
module
as
M
from
...logger
import
get_logger
from
...tensor
import
Parameter
,
Tensor
from
...utils.bn_fusion
import
fold_weight_bias
from
..expr
import
Expr
,
is_call_function
from
..utils
import
assign_attr
,
get_subattr
from
.matcher
import
PatternMatcher
from
.pass_base
import
BackwardPass
,
register_pass
from
.pattern
import
ExprPattern
,
any_node
,
is_const
,
is_op
,
is_var
from
.utils
import
get_const_value
,
register_obj
logger
=
get_logger
(
__name__
)
@register_pass("FuseAddMul")
class FuseAddMul(BackwardPass):
    """Fold adjacent const add or mul binary operations.

    For example, the following code

    .. code-block::

        x = x + 1
        x = 2 + x
        x = x * 4
        x = x * 0.25

    will be changed to

    .. code-block::

        x = x + 3
    """

    name = "FuseAddMul"
    required_pass = ["NormElemWise"]
    run_once = False

    def __init__(self,):
        super().__init__()

        def _make_pattern(op_0, op_1) -> ExprPattern:
            """Build the pattern ``op_1(op_0(x, const), const)``."""
            x = is_var().check_users(False)
            # ``operator.add`` / ``operator.mul`` are applied to patterns
            # directly; anything else is wrapped with ``is_op``.
            if op_0 not in [operator.add, operator.mul]:
                op_0 = is_op(op_0)
            if op_1 not in [operator.add, operator.mul]:
                op_1 = is_op(op_1)
            pattern = op_0(x, is_const()) | op_0(x, "*")
            pattern = op_1(pattern, is_const()) | op_1(pattern, "*")
            return pattern

        # pattern -> bound fold method handling an expr matching that pattern.
        self.pattern_dict = {}

        for op, func in zip([operator.add, F.pow], [self.fold_add, self.fold_pow],):
            self.pattern_dict[_make_pattern(op, op)] = func

        for op_0 in [F.neg, operator.mul]:
            for op_1 in [F.neg, operator.mul]:
                self.pattern_dict[_make_pattern(op_0, op_1)] = self.fold_mul

    def run_transform(self, expr: Expr):
        """Apply the fold method of the first pattern matching ``expr``.

        Returns ``expr`` unchanged when no pattern matches.
        """
        matcher = PatternMatcher()
        for pattern, func in self.pattern_dict.items():
            if matcher.match(pattern, expr):
                return func(expr)
        return expr

    def _fold_helper(self, expr: Expr, op_c: Callable, op_t: Callable):
        """Merge ``expr`` with the expr producing its first input.

        Args:
            expr: the outer of the two chained exprs.
            op_c: combines the two constants into one.
            op_t: rebuilds the single remaining operation on the inner input.

        Returns:
            ``expr`` unchanged when a constant is a tensor whose shape is not
            ``(1,)`` or ``()``, or when the merged operation is an identity
            (the graph is rewired in that case); otherwise the new expr.
        """
        const_0 = self.get_const_value(expr)
        # todo: support more shape
        if isinstance(const_0, Tensor) and const_0._tuple_shape not in [(1,), tuple()]:
            return expr

        const_1 = self.get_const_value(expr.inputs[0].expr)
        if isinstance(const_1, Tensor) and const_1._tuple_shape not in [(1,), tuple()]:
            return expr

        inp_node = expr.inputs[0].expr.inputs[0]
        const = op_c(const_0, const_1)
        graph = expr.top_graph

        # x * 1, x ** 1 and x + 0 are identities: drop the operation entirely.
        # ``fold_pow`` passes ``F.pow`` as ``op_t``, so it has to be listed
        # here next to ``operator.pow`` for the power case to be recognised.
        if (const == 1 and op_t in [operator.pow, F.pow, operator.mul]) or (
            const == 0 and op_t in [operator.add]
        ):
            graph.replace_node({expr.outputs[0]: inp_node})
            graph.compile()
            return expr

        with graph.insert_exprs():
            out_node = op_t(inp_node, const)
        graph.replace_node({expr.outputs[0]: out_node})
        graph.compile()
        return out_node.expr

    def fold_add(self, expr: Expr):
        """``(x + a) + b`` -> ``x + (a + b)``."""
        return self._fold_helper(expr, operator.add, operator.add)

    def fold_mul(self, expr):
        """``(x * a) * b`` -> ``x * (a * b)``; ``F.neg`` counts as ``* -1``."""
        return self._fold_helper(expr, operator.mul, operator.mul)

    def fold_pow(self, expr):
        """``pow(pow(x, a), b)`` -> ``pow(x, a * b)``."""
        return self._fold_helper(expr, operator.mul, F.pow)

    def get_const_value(self, expr: Expr):
        """Return the constant operand of ``expr`` (``-1`` for ``F.neg``)."""
        if is_call_function(expr, F.neg):
            return -1
        if len(expr.inputs) == 2:
            value = get_const_value(expr.inputs[1].expr, None)
            assert value is not None, "the second input of {} is not a constant".format(
                expr
            )
            return value
        # NOTE(review): assumes const_val[0][-1] holds the constant operand
        # when it is not a graph node -- confirm.
        value = expr.const_val[0][-1]
        return value
@register_pass("FuseConvBn")
class FuseConvBn(BackwardPass):
    r"""Fuse BN layers into conv2d.

    Both module calls (``M.Conv2d`` / ``M.BatchNorm2d``) and functional calls
    (``F.conv2d`` / ``F.batch_norm``) are matched.
    """

    name = "FuseConvBn"
    required_pass = ["AttrToConstant"]
    run_once = True

    def __init__(self):
        super().__init__()
        # Per-qualname counter used to name conv modules that are shared.
        self.used_name = defaultdict(int)

    def run_transform(self, expr: Expr):
        """Fuse ``expr`` into its producing conv when it is a matching BN."""
        module_conv = is_op(M.Conv2d)
        func_conv = is_op(F.conv2d)
        module_bn = is_op(M.BatchNorm2d)(module_conv | func_conv)
        func_bn = is_op(F.batch_norm)
        # inp, running_mean, running_var, weight, bias
        func_bn_args = (
            module_conv | func_conv,
            is_const(),
            is_const(),
            is_const(),
            is_const(),
        )
        # The functional BN may be called with 3, 4 or all 5 of these inputs.
        bn_pat = (
            (func_bn(*func_bn_args[:3]))
            | (func_bn(*func_bn_args[:4]))
            | (func_bn(*func_bn_args))
            | module_bn
        )

        matcher = PatternMatcher()
        if not matcher.match(bn_pat, expr):
            return expr

        found = matcher.matched_exprs
        if module_conv in found:
            return self.fold_convm_bn(found[module_conv], found[bn_pat])
        return self.fold_convf_bn(found[func_conv], found[bn_pat])

    def fold_convm_bn(self, conv: Expr, bn: Expr):
        """Fuse ``bn`` into a conv issued through an ``M.Conv2d`` module call.

        A new ``M.Conv2d`` carrying the folded parameters is attached to the
        owner module and called in place of the conv + BN pair.
        """
        mnode, inp_node = conv.inputs[:2]
        self_node = mnode.expr.inputs[0]
        attr_name = conv.inputs[0].expr.name
        graph = conv.top_graph
        user_count = len(mnode.users)
        if user_count > 1:
            # The conv module is shared: register the fused copy under a
            # numbered name instead of overwriting the original attribute.
            self.used_name[mnode.qualname] += 1
            attr_name = "{}_{}".format(attr_name, self.used_name[mnode.qualname])
            logger.warning(
                "{} is used {} times and its name will be reset to {}.{}".format(
                    mnode.qualname, len(mnode.users), graph.qualname, attr_name
                )
            )

        old_conv = mnode.owner
        mean, var, gamma, beta, eps = self.get_bn_params(bn)
        fused_weight, fused_bias = fold_weight_bias(
            old_conv.weight, old_conv.bias, gamma, beta, mean, var, eps
        )

        fused_conv = M.Conv2d(
            in_channels=old_conv.in_channels,
            out_channels=old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            dilation=old_conv.dilation,
            groups=old_conv.groups,
            bias=old_conv.bias is not None,
            conv_mode=old_conv.conv_mode,
            compute_mode=old_conv.compute_mode,
            name=old_conv.name,
        )
        fused_conv.weight = Parameter(fused_weight)
        fused_conv.bias = Parameter(fused_bias)
        fused_conv.training = old_conv.training

        assign_attr(fused_conv, self_node.owner, attr_name)
        with graph.insert_exprs(mnode.expr):
            fused_out = get_subattr(self_node, attr_name)(inp_node)
        graph.replace_node({bn.outputs[0]: fused_out})
        graph.compile()
        fused_out.name = conv.outputs[0].name
        return fused_out.expr

    def fold_convf_bn(self, conv: Expr, bn: Expr):
        """Fuse ``bn`` into a conv issued through an ``F.conv2d`` call."""
        conv_args = conv.named_args
        raw_weight = get_const_value(conv_args["weight"], conv_args["weight"])
        raw_bias = get_const_value(conv_args["bias"], conv_args["bias"])
        mean, var, gamma, beta, eps = self.get_bn_params(bn)
        fused_weight, fused_bias = fold_weight_bias(
            raw_weight, raw_bias, gamma, beta, mean, var, eps
        )
        conv_args["weight"] = fused_weight
        conv_args["bias"] = fused_bias

        graph = conv.top_graph
        with graph.insert_exprs():
            fused_out = F.conv2d(**conv_args)
        graph.replace_node({bn.outputs[0]: fused_out})
        graph.compile()
        fused_out.name = conv.outputs[0].name
        return fused_out.expr

    def get_bn_params(self, bn: Expr):
        """Return ``(mean, var, gamma, beta, eps)`` of the BN expr ``bn``.

        For a function call the values come from the call arguments; otherwise
        they are read from the module owning the first input of ``bn``.
        """
        if not is_call_function(bn):
            bn_module = bn.inputs[0].owner
            return (
                bn_module.running_mean,
                bn_module.running_var,
                bn_module.weight,
                bn_module.bias,
                bn_module.eps,
            )

        bn_args = bn.named_args
        mean = get_const_value(bn_args["running_mean"], bn_args["running_mean"])
        var = get_const_value(bn_args["running_var"], bn_args["running_var"])
        gamma = get_const_value(bn_args["weight"], bn_args["weight"])
        beta = get_const_value(bn_args["bias"], bn_args["bias"])
        eps = bn_args["eps"]
        return mean, var, gamma, beta, eps
imperative/python/megengine/traced_module/_passes/pass_base.py
0 → 100644
浏览文件 @
6c692b26
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
copy
from
abc
import
abstractmethod
from
collections
import
OrderedDict
,
namedtuple
from
functools
import
partial
from
re
import
T
from
typing
import
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Union
from
...logger
import
get_logger
from
..expr
import
Expr
from
..traced_module
import
InternalGraph
,
TracedModule
from
.utils
import
register_obj
logger
=
get_logger
(
__name__
)
class PassContext:
    """Options shared by the passes of one optimization run.

    Args:
        disabled_pass: a pass name or an iterable of pass names to disable.
        pass_config: arbitrary configuration, stored as given.
    """

    def __init__(
        self, disabled_pass: Iterable[str] = None, pass_config: Dict[str, Any] = None
    ):
        self._disabled_pass = set()
        self._config = pass_config
        self._handle = None
        if disabled_pass:
            # Goes through the historical name so subclasses overriding it
            # keep being honoured.
            self.add_diabled_pass(disabled_pass)

    def add_disabled_pass(self, passes: Iterable[str]):
        """Disable ``passes``, given as one name or an iterable of names."""
        if isinstance(passes, str):
            # A bare string would otherwise be iterated character by character.
            passes = [passes]
        self._disabled_pass.update(passes)

    # Original, misspelled public name; kept so existing callers still work.
    add_diabled_pass = add_disabled_pass

    def pass_enabled(self, pas: Union["BasePass", str]):
        """Return ``True`` unless ``pas`` (a pass or a pass name) is disabled."""
        # Passes are identified by their ``name`` attribute; objects without
        # one are looked up as they are.
        pass_name = pas if isinstance(pas, str) else getattr(pas, "name", pas)
        return pass_name not in self._disabled_pass
# Module-level context created once at import time, with no pass disabled.
_default_context = PassContext()


def get_default_pass_context():
    """Return the module-level :class:`PassContext` shared by all callers."""
    return _default_context
# Registry queried by ``get_registered_pass``; presumably keyed by the name
# handed to ``register_pass`` -- confirm against ``.utils.register_obj``.
_pass_dict = OrderedDict()
# ``register_obj`` with its ``_dict`` argument pre-bound to the registry above.
register_pass = partial(register_obj, _dict=_pass_dict)
def get_registered_pass(pass_name: str):
    """Look up the object registered under ``pass_name``.

    Raises:
        AssertionError: if nothing is registered under ``pass_name``.
    """
    registered = _pass_dict.get(pass_name)
    if registered is None:
        raise AssertionError(
            "{} is not found, please call `register_pass` to register it".format(
                pass_name
            )
        )
    return registered
class BasePass:
    """Base class of graph passes applied to a :class:`TracedModule`.

    Subclasses provide ``visit_graph`` (the traversal) and usually override
    ``run_transform`` (the per-expr rewrite).
    """

    run_once = True  # bool
    required_pass = []  # Iterable[str]
    name = ""  # str

    def __init__(self):
        super().__init__()

    def __call__(
        self, mod: TracedModule, pass_ctx: PassContext = get_default_pass_context()
    ) -> TracedModule:
        """Run the pass on ``mod`` under ``pass_ctx`` and return the result."""
        assert isinstance(pass_ctx, PassContext)
        return self.apply_optimization(mod, pass_ctx)

    def apply_optimization(
        self, mod: TracedModule, pass_ctx: PassContext
    ) -> TracedModule:
        """Run the required passes, then this pass, on ``mod``.

        If this pass or any pass in ``required_pass`` is disabled in
        ``pass_ctx``, a warning is logged and ``mod`` is returned untouched.
        When ``run_once`` is false the graph is revisited until a visit
        reports no change.

        Raises:
            AssertionError: if the graph still changes after 100 visits.
        """
        new_mod = mod
        for pass_name in self.required_pass + [self.name]:
            if not pass_ctx.pass_enabled(pass_name):
                logger.warning(
                    "Since {} is disabled, {} will skipped".format(pass_name, self.name)
                )
                return mod

        for pass_name in self.required_pass:
            pass_func = get_registered_pass(pass_name)()
            new_mod = pass_func(new_mod, pass_ctx)

        iter_num = 1
        graph_changed = self.visit_graph(new_mod.graph)
        while not self.run_once and graph_changed:
            graph_changed = self.visit_graph(new_mod.graph)
            iter_num += 1
            if iter_num == 100:
                break
        # The message used to be emitted with its ``{}`` placeholder unfilled;
        # it now names the offending pass.
        assert (
            iter_num < 100
        ), "{} was run 100 times, plase check for pass conflict.".format(self.name)
        return new_mod

    @abstractmethod
    def visit_graph(self, graph: InternalGraph):
        """Traverse ``graph``; implementations return whether it was changed."""
        raise NotImplementedError

    def before_visit_graph(self, graph: InternalGraph):
        """Hook for subclasses; does nothing by default."""
        pass

    def run_transform(self, expr: Expr) -> Expr:
        """Rewrite ``expr``; the default returns it unchanged."""
        return expr

    def __repr__(self) -> str:
        return self.name
class ForwardPass(BasePass):
    """Pass that transforms an expr only after the exprs feeding it."""

    def visit_graph(self, graph: InternalGraph):
        """Visit ``graph`` with an explicit stack, inputs before consumers.

        Returns ``True`` if ``run_transform`` replaced any expr, or if visiting
        a nested ``expr.graph`` reported a change.
        """

        class Item:
            # Stack entry: an expr plus a flag telling whether its input exprs
            # have already been pushed.
            def __init__(self, expr: Expr, child_expanded: bool = False):
                self.expr = expr
                self.child_expanded = child_expanded

        self.before_visit_graph(graph)
        graph_changed = False
        queue = [Item(n.expr) for n in graph.outputs]
        visited_expr, visited_graph = set(), set()
        while queue:
            # Peek only: an entry stays on the stack until it is finished.
            item = queue[-1]
            if item.expr in visited_expr:
                queue.pop()
            elif item.child_expanded:
                # Second encounter: the inputs were pushed earlier.
                if item.expr not in graph._exprs:
                    # The expr is no longer part of the graph; drop it.
                    queue.pop()
                    continue
                new_expr = self.run_transform(item.expr)
                if new_expr is not item.expr:
                    graph_changed = True
                    assert new_expr not in visited_expr
                    # Schedule the replacement expr as well.
                    queue.append(Item(new_expr))
                    continue
                if (
                    hasattr(item.expr, "graph")
                    and item.expr.graph is not None
                    and item.expr.graph not in visited_graph
                ):
                    # Recurse into the graph attached to this expr, once.
                    graph_changed |= self.visit_graph(item.expr.graph)
                    visited_graph.add(item.expr.graph)
                visited_expr.add(item.expr)
            else:
                # First encounter: push the producers of all inputs.
                item.child_expanded = True
                for i in item.expr.inputs:
                    expr = i.expr
                    # NOTE(review): ``queue`` holds ``Item`` wrappers, so unless
                    # ``Expr`` defines a custom ``__eq__`` the first test looks
                    # always true. Re-pushing puts the input above its consumer
                    # on the stack, so confirm before "fixing" it.
                    if expr not in queue and expr not in visited_expr:
                        queue.append(Item(expr))
        return graph_changed
class BackwardPass(BasePass):
    """Pass that walks from the graph outputs towards the inputs."""

    def visit_graph(self, graph: InternalGraph):
        """Visit ``graph`` starting at its outputs.

        Returns ``True`` if ``run_transform`` replaced any expr, or if visiting
        a nested ``expr.graph`` reported a change.
        """
        self.before_visit_graph(graph)
        changed = False
        pending = [out.expr for out in graph.outputs]
        seen_exprs, seen_graphs = set(), set()
        while pending:
            current = pending.pop()
            if current not in graph._exprs:
                continue

            replacement = self.run_transform(current)
            if replacement is not current:
                # Visit the replacement next instead of descending further.
                changed = True
                pending.append(replacement)
                continue
            seen_exprs.add(current)

            if (
                hasattr(current, "graph")
                and current.graph is not None
                and current.graph not in seen_graphs
            ):
                # Recurse into the graph attached to this expr, once.
                changed |= self.visit_graph(current.graph)
                seen_graphs.add(current.graph)

            for inp in current.inputs:
                producer = inp.expr
                if producer not in pending and producer not in seen_exprs:
                    pending.append(producer)
        return changed
imperative/python/megengine/traced_module/expr.py
浏览文件 @
6c692b26
...
@@ -13,7 +13,7 @@ import inspect
...
@@ -13,7 +13,7 @@ import inspect
import
re
import
re
import
weakref
import
weakref
from
importlib
import
import_module
from
importlib
import
import_module
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Tupl
e
,
Union
from
typing
import
Callable
,
Dict
,
Iterable
,
List
,
Optional
,
Sequenc
e
,
Union
from
..core._imperative_rt
import
OpDef
from
..core._imperative_rt
import
OpDef
from
..core._imperative_rt.core2
import
Tensor
as
RawTensor
from
..core._imperative_rt.core2
import
Tensor
as
RawTensor
...
@@ -50,20 +50,30 @@ def get_suffix_name(prefix: str, name: str):
...
@@ -50,20 +50,30 @@ def get_suffix_name(prefix: str, name: str):
return
matchd
.
group
(
1
)
return
matchd
.
group
(
1
)
def
is_call_module
(
expr
):
def
is_call_module
(
expr
,
module_cls
:
Module
=
None
):
return
(
return
(
isinstance
(
expr
,
CallMethod
)
isinstance
(
expr
,
CallMethod
)
and
isinstance
(
expr
.
inputs
[
0
],
ModuleNode
)
and
isinstance
(
expr
.
inputs
[
0
],
ModuleNode
)
and
expr
.
method
==
"__call__"
and
expr
.
method
==
"__call__"
)
)
and
(
module_cls
is
None
or
isinstance
(
expr
.
inputs
[
0
].
owner
,
module_cls
))
def
is_call_tensor_method
(
expr
):
def
is_call_tensor_method
(
expr
,
method
:
Iterable
[
str
]
=
None
):
return
isinstance
(
expr
,
CallMethod
)
and
not
is_call_module
(
expr
)
if
method
and
isinstance
(
method
,
str
):
method
=
(
method
,)
return
(
isinstance
(
expr
,
CallMethod
)
and
not
is_call_module
(
expr
)
and
(
method
is
None
or
any
(
expr
.
method
==
f
for
f
in
method
))
)
def
is_call_function
(
expr
):
def
is_call_function
(
expr
,
func
:
Iterable
[
Callable
]
=
None
):
return
isinstance
(
expr
,
CallFunction
)
if
func
and
not
isinstance
(
func
,
Iterable
):
func
=
(
func
,)
return
isinstance
(
expr
,
CallFunction
)
and
(
func
is
None
or
any
(
expr
.
func
==
f
for
f
in
func
)
)
def
is_constant
(
expr
):
def
is_constant
(
expr
):
...
@@ -74,8 +84,8 @@ def is_getattr(expr):
...
@@ -74,8 +84,8 @@ def is_getattr(expr):
return
isinstance
(
expr
,
GetAttr
)
return
isinstance
(
expr
,
GetAttr
)
def
is_apply_def
(
expr
):
def
is_apply_def
(
expr
,
opdef
=
None
):
return
isinstance
(
expr
,
Apply
)
return
isinstance
(
expr
,
Apply
)
and
(
opdef
is
None
or
isinstance
(
expr
.
opdef
,
opdef
))
def
is_input
(
expr
):
def
is_input
(
expr
):
...
...
imperative/python/megengine/traced_module/node.py
浏览文件 @
6c692b26
...
@@ -78,6 +78,7 @@ class Node:
...
@@ -78,6 +78,7 @@ class Node:
"The name(%s) is already in use. Please try a different one again."
"The name(%s) is already in use. Please try a different one again."
%
(
new_name
)
%
(
new_name
)
)
)
graph
.
_namespace
.
unassociate_name_with_obj
(
self
)
self
.
_name
=
graph
.
_namespace
.
create_unique_name
(
new_name
,
self
)
self
.
_name
=
graph
.
_namespace
.
create_unique_name
(
new_name
,
self
)
@
property
@
property
...
...
imperative/python/megengine/traced_module/utils.py
浏览文件 @
6c692b26
...
@@ -14,6 +14,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Uni
...
@@ -14,6 +14,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Uni
from
..
import
get_logger
from
..
import
get_logger
from
..module
import
Module
from
..module
import
Module
from
..tensor
import
Parameter
,
Tensor
logger
=
get_logger
(
__name__
)
logger
=
get_logger
(
__name__
)
...
@@ -301,3 +302,26 @@ class _ModuleDict(Module, MutableMapping):
...
@@ -301,3 +302,26 @@ class _ModuleDict(Module, MutableMapping):
def
forward
(
self
):
def
forward
(
self
):
raise
RuntimeError
(
"ModuleList is not callable"
)
raise
RuntimeError
(
"ModuleList is not callable"
)
def assign_attr(obj: Union[Module, Tensor], module: Module, target: str):
    """Set ``obj`` as the attribute addressed by the dotted path ``target``.

    Every component of ``target`` except the last is resolved with
    ``getattr`` starting from ``module``; the last component is the
    attribute name that receives ``obj``.

    Raises:
        AttributeError: if an intermediate component does not resolve to a
            ``Module`` (or does not exist at all).
    """
    path = target.split(".")
    owner = module
    # Walk down to the object that owns the final attribute.
    for item in path[:-1]:
        owner = getattr(owner, item)
        if not isinstance(owner, Module):
            raise AttributeError("`{}` is not an Module".format(item))
    setattr(owner, path[-1], obj)
def get_subattr(module: Module, target: str):
    """Return the attribute addressed by the dotted path ``target``.

    An empty ``target`` returns ``module`` itself. Otherwise every component
    except the last is resolved with ``getattr`` and must be a ``Module`` or
    a ``ModuleNode``; the last component is then looked up on the result.

    Raises:
        AttributeError: if an intermediate component is neither a ``Module``
            nor a ``ModuleNode`` (or does not exist at all).
    """
    # todo : remove this import
    from .node import ModuleNode

    if target == "":
        return module

    path = target.split(".")
    owner = module
    # Descend through the parent components of the path.
    for item in path[:-1]:
        owner = getattr(owner, item)
        if not isinstance(owner, (Module, ModuleNode)):
            raise AttributeError("`{}` is not an Module".format(item))
    return getattr(owner, path[-1])
imperative/python/megengine/utils/bn_fusion.py
0 → 100644
浏览文件 @
6c692b26
from
copy
import
deepcopy
from
..functional
import
ones
,
sqrt
,
zeros
from
..module
import
BatchNorm2d
,
Conv2d
,
ConvBn2d
,
ConvBnRelu2d
,
ConvRelu2d
,
ReLU
from
..tensor
import
Parameter
# Lookup table from a fusion pattern to the module class that replaces it.
# The key is the tuple of module types present in the pattern; patterns that
# contain a BatchNorm2d carry one extra trailing bool, which
# fuse_conv_bn_relu_module fills with ``conv.training``.
_MAP_TO_FUSED_MODULE = {
    # Conv + BN + ReLU: eval mode maps to ConvRelu2d, training to ConvBnRelu2d.
    (Conv2d, BatchNorm2d, ReLU, False): ConvRelu2d,
    (Conv2d, BatchNorm2d, ReLU, True): ConvBnRelu2d,
    # Conv + BN: eval mode maps to a plain Conv2d, training to ConvBn2d.
    (Conv2d, BatchNorm2d, False): Conv2d,
    (Conv2d, BatchNorm2d, True): ConvBn2d,
    # Conv + ReLU has no BN, hence no training flag in the key.
    (Conv2d, ReLU): ConvRelu2d,
}
def fold_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5):
    """Fold batch-norm parameters into a convolution's weight and bias.

    Computes::

        w_fold = weight * gamma / sqrt(bn_var + eps)
        b_fold = beta + gamma * (bias - bn_mean) / sqrt(bn_var + eps)

    Args:
        weight: convolution kernel. A 5-D kernel is treated as grouped, with
            axis 0 the number of groups and axis 1 the output channels per
            group; any other rank is treated as a dense kernel whose axis 0
            is the number of output channels.
        bias: convolution bias broadcastable to ``(1, C, 1, 1)``, or ``None``
            for zero.
        gamma: batch-norm scale with ``C`` elements, or ``None`` for one.
        beta: batch-norm shift with ``C`` elements, or ``None`` for zero.
        bn_mean: running mean broadcastable to ``(1, C, 1, 1)``, or ``None``
            for zero.
        bn_var: running variance broadcastable to ``(1, C, 1, 1)``, or
            ``None`` for one.
        eps: value added to the variance before taking the square root.

    Returns:
        Tuple ``(w_fold, b_fold)``; ``w_fold`` has the shape of ``weight``.
    """
    kernel_shape = weight.shape
    if len(kernel_shape) == 5:
        groups, channels_per_group = kernel_shape[0], kernel_shape[1]
    else:
        groups, channels_per_group = 1, kernel_shape[0]
    # Default statistics must cover every output channel, i.e. all groups.
    # Sizing them by the per-group count breaks broadcasting for groups > 1.
    out_channels = groups * channels_per_group

    if gamma is None:
        gamma = ones((out_channels,), dtype="float32")
    gamma = gamma.reshape(1, -1, 1, 1)
    if beta is None:
        beta = zeros((out_channels,), dtype="float32")
    beta = beta.reshape(1, -1, 1, 1)

    if bn_mean is None:
        bn_mean = zeros((1, out_channels, 1, 1), dtype="float32")
    if bn_var is None:
        bn_var = ones((1, out_channels, 1, 1), dtype="float32")

    if bias is None:
        bias = zeros((1, out_channels, 1, 1), dtype="float32")

    bn_istd = 1.0 / sqrt(bn_var + eps)
    scale_factor = gamma * bn_istd
    # Lay the per-channel scale out along the kernel's output-channel axes.
    if groups == 1:
        w_fold = weight * scale_factor.reshape(-1, 1, 1, 1)
    else:
        w_fold = weight * scale_factor.reshape(groups, -1, 1, 1, 1)
    b_fold = beta + gamma * (bias - bn_mean) * bn_istd
    return w_fold, b_fold
def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
    """Build the fused module that replaces a conv / bn / relu pattern.

    The replacement class is looked up in ``_MAP_TO_FUSED_MODULE`` using the
    types of the truthy arguments, plus ``conv.training`` when ``bn`` is
    given, and is constructed with the hyper-parameters of ``conv``.

    * ``bn`` given, eval mode: the BN parameters and running statistics are
      folded into the conv weight and bias via ``fold_weight_bias``.
    * ``bn`` given, training mode: the conv parameters are placed on the
      fused module's ``conv`` sub-module and a deep copy of ``bn`` becomes
      its ``bn`` attribute.
    * ``bn`` absent: the conv parameters are copied onto the new module.

    Returns:
        The newly created fused module.
    """
    present_types = tuple(type(m) for m in (conv, bn, relu) if m)
    if bn:
        assert (
            conv.training == bn.training
        ), "Conv and BN both must be in the same mode (train or eval)."
        assert (
            bn.num_features == conv.out_channels
        ), "Output channel of Conv2d must match num_features of BatchNorm2d"
        present_types += (conv.training,)

    fused_cls = _MAP_TO_FUSED_MODULE[present_types]
    module = fused_cls(
        in_channels=conv.in_channels,
        out_channels=conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=conv.bias is not None,
        conv_mode=conv.conv_mode,
        compute_mode=conv.compute_mode,
        name=conv.name,
    )

    # Choose the object that holds the conv parameters: the fused module
    # itself, or its ``conv`` child when BN is kept in training mode.
    if bn is None or not conv.training:
        new_conv = module
    else:
        new_conv = module.conv

    weight = conv.weight
    bias = conv.bias
    if not conv.training and bn is not None:
        # Eval mode: absorb the BN affine transform and statistics.
        weight, bias = fold_weight_bias(
            weight,
            bias,
            bn.weight,
            bn.bias,
            bn.running_mean,
            bn.running_var,
            bn.eps,
        )

    new_conv.weight = Parameter(weight)
    if bias is not None:
        new_conv.bias = Parameter(bias)

    if bn is not None and conv.training:
        # Training mode keeps a separate BN; copy it so the source is untouched.
        module.bn = deepcopy(bn)

    new_conv.training = conv.training
    return module
imperative/python/test/unit/traced_module/test_qat_module.py
浏览文件 @
6c692b26
...
@@ -13,20 +13,7 @@ import megengine.quantization as Q
...
@@ -13,20 +13,7 @@ import megengine.quantization as Q
from
megengine
import
Tensor
from
megengine
import
Tensor
from
megengine.module.qat.module
import
QATModule
from
megengine.module.qat.module
import
QATModule
from
megengine.traced_module
import
TracedModule
,
trace_module
from
megengine.traced_module
import
TracedModule
,
trace_module
from
megengine.traced_module.utils
import
get_subattr
def get_subattr(self: M.Module, name: str):
    """Return the attribute of ``self`` addressed by the dotted path ``name``.

    An empty ``name`` returns ``self``. For a dotted path, every component
    before the last must resolve to an ``M.Module``; the last component is
    then looked up on the module reached.

    Raises:
        AttributeError: if an intermediate component is not an ``M.Module``
            (or does not exist at all).
    """
    if name == "":
        return self

    module_path, _, leaf = name.rpartition(".")
    owner = self
    if module_path != "":
        # Descend through the parent modules named in the path.
        for item in module_path.split("."):
            owner = getattr(owner, item)
            if not isinstance(owner, M.Module):
                raise AttributeError("`{}` is not an Module".format(item))
    return getattr(owner, leaf)
class
MyConvBnRelu2d
(
M
.
ConvBnRelu2d
):
class
MyConvBnRelu2d
(
M
.
ConvBnRelu2d
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录