Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
e6dcfbe8
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
e6dcfbe8
编写于
6月 10, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(traced_module): fix traced module compatible issues
GitOrigin-RevId: 67e68ef5eae78d93a167d8d32ac78837932f3b45
上级
18f83a25
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
30 addition
and
12 deletion
+30
-12
imperative/python/megengine/traced_module/compat.py
imperative/python/megengine/traced_module/compat.py
+27
-10
imperative/python/megengine/utils/deprecation.py
imperative/python/megengine/utils/deprecation.py
+3
-2
未找到文件。
imperative/python/megengine/traced_module/compat.py
浏览文件 @
e6dcfbe8
...
@@ -99,11 +99,10 @@ def add_loader(expr):
...
@@ -99,11 +99,10 @@ def add_loader(expr):
(
"megengine.module.batchnorm"
,
"SyncBatchNorm"
),
(
"megengine.module.batchnorm"
,
"SyncBatchNorm"
),
)
)
def
bn2d_module_loader
(
expr
):
def
bn2d_module_loader
(
expr
):
# mge 1.6
module
=
expr
.
inputs
[
0
].
owner
if
not
hasattr
(
expr
,
"version"
):
if
hasattr
(
module
,
"param_dim"
):
module
=
expr
.
inputs
[
0
].
owner
assert
module
.
param_dim
==
"dim_1c11"
if
not
hasattr
(
module
,
"param_dim"
):
delattr
(
module
,
"param_dim"
)
module
.
param_dim
=
"dim_1c11"
@
register_module_loader
(
@
register_module_loader
(
...
@@ -113,12 +112,10 @@ def bn2d_module_loader(expr):
...
@@ -113,12 +112,10 @@ def bn2d_module_loader(expr):
(
"megengine.module.qat.conv_bn"
,
"ConvBnRelu2d"
),
(
"megengine.module.qat.conv_bn"
,
"ConvBnRelu2d"
),
)
)
def
convbn2d_module_loader
(
expr
):
def
convbn2d_module_loader
(
expr
):
# mge 1.6
if
not
hasattr
(
expr
,
"version"
):
module
=
expr
.
inputs
[
0
].
owner
if
not
hasattr
(
module
.
bn
,
"param_dim"
):
module
.
bn
.
param_dim
=
"dim_1c11"
module
=
expr
.
inputs
[
0
].
owner
module
=
expr
.
inputs
[
0
].
owner
if
hasattr
(
module
.
bn
,
"param_dim"
):
assert
module
.
bn
.
param_dim
==
"dim_1c11"
delattr
(
module
.
bn
,
"param_dim"
)
if
not
hasattr
(
module
.
conv
,
"padding_mode"
):
if
not
hasattr
(
module
.
conv
,
"padding_mode"
):
module
.
conv
.
padding_mode
=
"zeros"
module
.
conv
.
padding_mode
=
"zeros"
...
@@ -167,6 +164,26 @@ def pad_func_loader(expr):
...
@@ -167,6 +164,26 @@ def pad_func_loader(expr):
expr
.
set_args_kwargs
(
*
expr
.
args
,
**
kwargs
)
expr
.
set_args_kwargs
(
*
expr
.
args
,
**
kwargs
)
@
register_functional_loader
((
"megengine.functional.nn"
,
"batch_norm"
))
def
bn_func_loader
(
expr
):
kwargs
=
expr
.
kwargs
if
"compute_mode"
in
kwargs
:
assert
kwargs
[
"compute_mode"
]
==
"default"
kwargs
.
pop
(
"compute_mode"
)
if
"param_dim"
in
kwargs
:
assert
kwargs
[
"param_dim"
]
==
"dim_1c11"
kwargs
.
pop
(
"param_dim"
)
expr
.
set_args_kwargs
(
*
expr
.
args
,
**
kwargs
)
@
register_functional_loader
((
"megengine.functional.math"
,
"matmul"
))
def
matmul_func_loader
(
expr
):
args
=
expr
.
args
if
len
(
args
)
==
6
:
assert
args
[
5
]
==
"default"
expr
.
set_args_kwargs
(
*
args
[
0
:
5
])
@
register_module_loader
(
@
register_module_loader
(
(
"megengine.module.conv"
,
"Conv1d"
),
(
"megengine.module.conv"
,
"Conv1d"
),
(
"megengine.module.conv"
,
"Conv2d"
),
(
"megengine.module.conv"
,
"Conv2d"
),
...
...
imperative/python/megengine/utils/deprecation.py
浏览文件 @
e6dcfbe8
...
@@ -17,11 +17,12 @@ def deprecated_func(version, origin, name, tbd):
...
@@ -17,11 +17,12 @@ def deprecated_func(version, origin, name, tbd):
tbd: to be discussed, if true, ignore warnings
tbd: to be discussed, if true, ignore warnings
"""
"""
should_warning
=
not
tbd
should_warning
=
not
tbd
module
=
importlib
.
import_module
(
origin
)
func
=
module
.
__getattribute__
(
name
)
@
wraps
(
func
)
def
wrapper
(
*
args
,
**
kwargs
):
def
wrapper
(
*
args
,
**
kwargs
):
nonlocal
should_warning
nonlocal
should_warning
module
=
importlib
.
import_module
(
origin
)
func
=
module
.
__getattribute__
(
name
)
if
should_warning
:
if
should_warning
:
warnings
.
warn
(
warnings
.
warn
(
"Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}."
.
format
(
"Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}."
.
format
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录