Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
fc212042
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
fc212042
编写于
1月 27, 2022
作者:
M
Megvii Engine Team
提交者:
wenjuan
2月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(traced_module): fix Module compatible issue and traced module getattr check
GitOrigin-RevId: 62eb3bfb10e8fda942c84a6ce69acaebc85228dc
上级
275b6311
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
52 addition
and
10 deletion
+52
-10
imperative/python/megengine/module/module.py
imperative/python/megengine/module/module.py
+7
-5
imperative/python/megengine/traced_module/serialization.py
imperative/python/megengine/traced_module/serialization.py
+1
-1
imperative/python/megengine/traced_module/traced_module.py
imperative/python/megengine/traced_module/traced_module.py
+6
-4
imperative/python/test/unit/core/test_serialization.py
imperative/python/test/unit/core/test_serialization.py
+14
-0
imperative/python/test/unit/module/test_module.py
imperative/python/test/unit/module/test_module.py
+24
-0
未找到文件。
imperative/python/megengine/module/module.py
浏览文件 @
fc212042
...
@@ -138,11 +138,7 @@ class Module(metaclass=ABCMeta):
...
@@ -138,11 +138,7 @@ class Module(metaclass=ABCMeta):
return
HookHandler
(
self
.
_forward_hooks
,
hook
)
return
HookHandler
(
self
.
_forward_hooks
,
hook
)
def
__call__
(
self
,
*
inputs
,
**
kwargs
):
def
__call__
(
self
,
*
inputs
,
**
kwargs
):
AutoNaming
.
push_scope
(
AutoNaming
.
push_scope
(
self
.
name
if
self
.
name
is
not
None
else
self
.
_short_name
)
self
.
name
if
self
.
name
is
not
None
else
(
self
.
_short_name
if
hasattr
(
self
,
"_short_name"
)
else
self
.
_name
)
)
for
hook
in
self
.
_forward_pre_hooks
.
values
():
for
hook
in
self
.
_forward_pre_hooks
.
values
():
modified_inputs
=
hook
(
self
,
inputs
)
modified_inputs
=
hook
(
self
,
inputs
)
if
modified_inputs
is
not
None
:
if
modified_inputs
is
not
None
:
...
@@ -685,6 +681,12 @@ class Module(metaclass=ABCMeta):
...
@@ -685,6 +681,12 @@ class Module(metaclass=ABCMeta):
set_name
(
self
,
prefix
,
k
,
v
)
set_name
(
self
,
prefix
,
k
,
v
)
super
().
__setattr__
(
name
,
value
)
super
().
__setattr__
(
name
,
value
)
def
__setstate__
(
self
,
state
):
if
"_short_name"
not
in
state
:
state
[
"_short_name"
]
=
state
[
"_name"
]
state
[
"_name"
]
=
None
self
.
__dict__
.
update
(
state
)
def
__delattr__
(
self
,
name
:
str
):
def
__delattr__
(
self
,
name
:
str
):
if
name
in
self
.
__dict__
and
_is_module
(
self
.
__dict__
[
name
]):
if
name
in
self
.
__dict__
and
_is_module
(
self
.
__dict__
[
name
]):
modules
=
self
.
__dict__
.
get
(
"_modules"
)
modules
=
self
.
__dict__
.
get
(
"_modules"
)
...
...
imperative/python/megengine/traced_module/serialization.py
浏览文件 @
fc212042
...
@@ -50,7 +50,7 @@ class _ModuleState:
...
@@ -50,7 +50,7 @@ class _ModuleState:
if
self
.
obj
is
None
:
if
self
.
obj
is
None
:
typem
=
getattr
(
import_module
(
self
.
module
[
0
]),
self
.
module
[
1
])
typem
=
getattr
(
import_module
(
self
.
module
[
0
]),
self
.
module
[
1
])
m_obj
=
typem
.
__new__
(
typem
)
m_obj
=
typem
.
__new__
(
typem
)
m_obj
.
__
dict__
.
update
(
self
.
state
)
m_obj
.
__
setstate__
(
self
.
state
)
self
.
obj
=
m_obj
self
.
obj
=
m_obj
return
self
.
obj
return
self
.
obj
...
...
imperative/python/megengine/traced_module/traced_module.py
浏览文件 @
fc212042
...
@@ -1681,11 +1681,13 @@ class TracedModuleBuilder(NodeMixin):
...
@@ -1681,11 +1681,13 @@ class TracedModuleBuilder(NodeMixin):
if
isinstance
(
wrapped
,
TracedModuleBuilder
):
if
isinstance
(
wrapped
,
TracedModuleBuilder
):
if
not
isinstance
(
mod_attr
,
(
List
,
Dict
,
QATModule
)):
if
not
isinstance
(
mod_attr
,
(
List
,
Dict
,
QATModule
)):
assert
mod_attr
is
wrapped
.
_mod
assert
(
else
:
mod_attr
is
wrapped
.
_mod
),
"TracedModule do not support modify module attributes, please check your code."
if
isinstance
(
wrapped
,
RawTensor
):
assert
(
assert
(
mod_attr
is
wrapped
mod_attr
is
wrapped
),
"TracedModule do not support modify attributes, please check your code."
),
"TracedModule do not support modify
tensor
attributes, please check your code."
if
isinstance
(
wrapped
,
(
NodeMixin
,
RawTensor
)):
if
isinstance
(
wrapped
,
(
NodeMixin
,
RawTensor
)):
NodeMixin
.
wrap
(
NodeMixin
.
wrap
(
...
@@ -2296,7 +2298,7 @@ class TracedModule(Module):
...
@@ -2296,7 +2298,7 @@ class TracedModule(Module):
for
k
,
v
in
state
.
items
():
for
k
,
v
in
state
.
items
():
if
isinstance
(
v
,
_ModuleState
):
if
isinstance
(
v
,
_ModuleState
):
state
[
k
]
=
v
.
to_module
()
state
[
k
]
=
v
.
to_module
()
s
elf
.
__dict__
.
update
(
state
)
s
uper
().
__setstate__
(
state
)
self
.
_update_ref
()
self
.
_update_ref
()
for
_
,
graph
in
self
.
argdef_graph_map
.
items
():
for
_
,
graph
in
self
.
argdef_graph_map
.
items
():
...
...
imperative/python/test/unit/core/test_serialization.py
浏览文件 @
fc212042
...
@@ -87,3 +87,17 @@ def test_compatibility():
...
@@ -87,3 +87,17 @@ def test_compatibility():
test_old_tensor
(
"tensor_v1_1.mge"
)
test_old_tensor
(
"tensor_v1_1.mge"
)
test_old_tensor
(
"tensor_v1_2.mge"
)
test_old_tensor
(
"tensor_v1_2.mge"
)
t
=
mge
.
tensor
([
1
])
getattr
(
t
,
"qparams"
)
new_args
=
t
.
__getnewargs__
()
assert
(
len
(
new_args
)
==
3
and
isinstance
(
new_args
[
0
],
np
.
ndarray
)
and
new_args
[
1
]
==
np
.
int32
and
isinstance
(
new_args
[
2
],
str
)
),
"Modify Tensor __getnewargs__ may break pickle serialization compatible"
state
=
t
.
__getstate__
()
assert
set
(
state
.
keys
())
==
set
(
[
"qparams"
]
),
"Modify Tensor __getstate__ may break pickle serialization compatible"
imperative/python/test/unit/module/test_module.py
浏览文件 @
fc212042
...
@@ -681,3 +681,27 @@ def test_repr_module_reset_attr():
...
@@ -681,3 +681,27 @@ def test_repr_module_reset_attr():
m1
=
ResetAttrModule
(
False
)
m1
=
ResetAttrModule
(
False
)
output
=
[
m0
.
__repr__
(),
m1
.
__repr__
()]
output
=
[
m0
.
__repr__
(),
m1
.
__repr__
()]
assert
output
==
ground_truth
assert
output
==
ground_truth
def
test_module_compatible
():
class
Empty
(
Module
):
def
forward
(
self
):
pass
empty_module
=
Empty
()
old_attributes
=
set
(
[
"_modules"
,
"name"
,
"training"
,
"quantize_disabled"
,
"_forward_pre_hooks"
,
"_forward_hooks"
,
"_name"
,
"_short_name"
,
]
)
current_attributes
=
set
(
empty_module
.
__dict__
.
keys
())
assert
(
old_attributes
==
current_attributes
),
"Add or delete attributes in Module class may break compatibility of pickle serialization"
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录