Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
5a38ad39
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
5a38ad39
编写于
3月 11, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/utils): add get/set_expand_structure to deal with complex key
GitOrigin-RevId: 4d1b952068ffda21189f315ad70888dee80bc65f
上级
fad5bc74
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
87 addition
and
176 deletion
+87
-176
imperative/python/megengine/module/module.py
imperative/python/megengine/module/module.py
+23
-3
imperative/python/megengine/module/sequential.py
imperative/python/megengine/module/sequential.py
+2
-2
imperative/python/megengine/quantization/quantize.py
imperative/python/megengine/quantization/quantize.py
+4
-11
imperative/python/megengine/utils/module_utils.py
imperative/python/megengine/utils/module_utils.py
+43
-0
imperative/python/test/unit/module/test_module.py
imperative/python/test/unit/module/test_module.py
+15
-160
未找到文件。
imperative/python/megengine/module/module.py
浏览文件 @
5a38ad39
...
...
@@ -21,9 +21,9 @@ from ..utils.naming import auto_naming
logger
=
get_logger
(
__name__
)
def
_expand_structure
(
key
,
obj
):
def
_expand_structure
(
prefix
,
obj
):
if
isinstance
(
obj
,
(
Tensor
,
Module
)):
return
[(
key
,
obj
)]
return
[(
prefix
,
obj
)]
elif
isinstance
(
obj
,
(
list
,
tuple
,
dict
)):
ret
=
[]
if
isinstance
(
obj
,
dict
):
...
...
@@ -37,12 +37,32 @@ def _expand_structure(key, obj):
"keys for Tensor and Module must be str, error key: {}"
.
format
(
k
)
)
for
kt
,
vt
in
sub_ret
:
ret
.
extend
([(
key
+
"."
+
kt
,
vt
)])
ret
.
extend
([(
prefix
+
"."
+
kt
,
vt
)])
return
ret
else
:
return
[]
def
_access_structure
(
obj
,
key
,
callback
=
None
):
key_list
=
key
.
split
(
"."
)
cur
=
obj
parent
=
None
for
k
in
key_list
:
parent
=
cur
if
isinstance
(
cur
,
(
Tensor
,
Module
)):
cur
=
getattr
(
cur
,
k
)
elif
isinstance
(
cur
,
(
list
,
tuple
)):
k
=
int
(
k
)
cur
=
cur
[
k
]
elif
isinstance
(
cur
,
dict
):
cur
=
cur
[
k
]
else
:
raise
ValueError
(
"Unsupport value type {} to access attribute"
.
format
(
type
(
cur
))
)
return
callback
(
parent
,
k
,
cur
)
def
_is_parameter
(
obj
):
return
isinstance
(
obj
,
Parameter
)
...
...
imperative/python/megengine/module/sequential.py
浏览文件 @
5a38ad39
...
...
@@ -18,9 +18,9 @@ class Sequential(Module):
Alternatively, an ordered dict of modules can also be passed in.
To make it easier to understand, here is a small example:
Examples:
.. testcode::
import numpy as np
...
...
imperative/python/megengine/quantization/quantize.py
浏览文件 @
5a38ad39
...
...
@@ -7,7 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
copy
import
copy
,
deepcopy
from
functools
import
partial
from
typing
import
Callable
,
Dict
,
Tuple
from
typing
import
Callable
import
numpy
as
np
...
...
@@ -19,6 +19,7 @@ from ..module import quantized as Quantized
from
..module.qat
import
QATModule
from
..module.quantized
import
QuantizedModule
from
..tensor
import
Tensor
from
..utils.module_utils
import
set_expand_structure
from
.qconfig
import
QConfig
,
ema_fakequant_qconfig
...
...
@@ -79,11 +80,7 @@ def quantize(module: Module, inplace: bool = True, mapping: dict = None):
module
.
_flatten
(
with_key
=
True
,
with_parent
=
True
,
predicate
=
is_qat
)
):
new_mod
=
convert_dict
[
type
(
submodule
)].
from_qat_module
(
submodule
)
if
isinstance
(
parent
,
Float
.
Sequential
):
# cannnot use setattr to be compatible with Sequential's ``__setitem__``
parent
[
int
(
key
.
split
(
"."
)[
-
1
])]
=
new_mod
else
:
setattr
(
parent
,
key
.
split
(
"."
)[
-
1
],
new_mod
)
set_expand_structure
(
parent
,
key
,
new_mod
)
return
module
...
...
@@ -126,11 +123,7 @@ def quantize_qat(
continue
new_mod
=
convert_dict
[
type
(
submodule
)].
from_float_module
(
submodule
)
if
isinstance
(
parent
,
Float
.
Sequential
):
# cannnot use setattr to be compatible with Sequential's ``__setitem__``
parent
[
int
(
key
.
split
(
"."
)[
-
1
])]
=
new_mod
else
:
setattr
(
parent
,
key
.
split
(
"."
)[
-
1
],
new_mod
)
set_expand_structure
(
parent
,
key
,
new_mod
)
propagate_qconfig
(
module
,
qconfig
)
return
module
...
...
imperative/python/megengine/utils/module_utils.py
0 → 100644
浏览文件 @
5a38ad39
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
collections
import
Iterable
from
..module
import
Sequential
from
..module.module
import
Module
,
_access_structure
from
..tensor
import
Tensor
def
get_expand_structure
(
obj
:
Module
,
key
:
str
):
"""
Gets Module's attribute compatible with complex key from Module's :meth:`~.named_children`.
Supports handling structure containing list or dict.
"""
def
f
(
_
,
__
,
cur
):
return
cur
return
_access_structure
(
obj
,
key
,
callback
=
f
)
def
set_expand_structure
(
obj
:
Module
,
key
:
str
,
value
):
"""
Sets Module's attribute compatible with complex key from Module's :meth:`~.named_children`.
Supports handling structure containing list or dict.
"""
def
f
(
parent
,
key
,
cur
):
if
isinstance
(
parent
,
(
Tensor
,
Module
)):
# cannnot use setattr to be compatible with Sequential's ``__setitem__``
if
isinstance
(
cur
,
Sequential
):
parent
[
int
(
key
)]
=
value
else
:
setattr
(
parent
,
key
,
value
)
else
:
parent
[
key
]
=
value
_access_structure
(
obj
,
key
,
callback
=
f
)
imperative/python/test/unit/module/test_module.py
浏览文件 @
5a38ad39
...
...
@@ -6,8 +6,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
os
import
tempfile
from
collections
import
OrderedDict
from
io
import
BytesIO
...
...
@@ -29,7 +27,9 @@ from megengine.module import (
Sequential
,
Softmax
,
)
from
megengine.module.module
import
_access_structure
from
megengine.quantization.quantize
import
quantize
,
quantize_qat
from
megengine.utils.module_utils
import
get_expand_structure
,
set_expand_structure
class
MLP
(
Module
):
...
...
@@ -45,146 +45,6 @@ class MLP(Module):
return
x
def
has_gpu
(
num
=
1
):
try
:
mgb
.
comp_node
(
"gpu{}"
.
format
(
num
-
1
))
except
mgb
.
MegBrainError
:
return
False
return
True
def
randomNp
(
*
args
):
for
arg
in
args
:
assert
isinstance
(
arg
,
int
)
return
np
.
random
.
random
(
args
)
def
randomTorch
(
*
args
):
import
torch
# pylint: disable=import-outside-toplevel
for
arg
in
args
:
assert
isinstance
(
arg
,
int
)
return
torch
.
tensor
(
randomNp
(
*
args
),
dtype
=
torch
.
float32
)
def
graph_mode
(
*
modes
):
if
not
set
(
modes
).
issubset
({
"eager"
,
"static"
}):
raise
ValueError
(
"graph mode must be in (eager, static)"
)
def
decorator
(
func
):
def
wrapper
(
*
args
,
**
kwargs
):
if
"eager"
in
set
(
modes
):
func
(
*
args
,
**
kwargs
)
if
"static"
in
set
(
modes
):
with
Graph
()
as
cg
:
cg
.
set_option
(
"eager_evaluation"
,
False
)
func
(
*
args
,
**
kwargs
)
return
wrapper
return
decorator
def
_default_compare_fn
(
x
,
y
):
np
.
testing
.
assert_allclose
(
x
.
numpy
(),
y
,
rtol
=
1e-6
)
def
opr_test
(
cases
,
func
,
mode
=
(
"eager"
,
"static"
,
"dynamic_shape"
),
compare_fn
=
_default_compare_fn
,
ref_fn
=
None
,
**
kwargs
):
"""
mode: the list of test mode which are eager, static and dynamic_shape
will test all the cases if None.
func: the function to run opr.
compare_fn: the function to compare the result and expected, use np.testing.assert_allclose if None.
ref_fn: the function to generate expected data, should assign output if None.
cases: the list which have dict element, the list length should be 2 for dynamic shape test.
and the dict should have input,
and should have output if ref_fn is None.
should use list for multiple inputs and outputs for each case.
kwargs: The additional kwargs for opr func.
simple examples:
dtype = np.float32
cases = [{"input": [10, 20]}, {"input": [20, 30]}]
opr_test(cases,
F.eye,
ref_fn=lambda n, m: np.eye(n, m).astype(dtype),
dtype=dtype)
"""
def
check_results
(
results
,
expected
):
if
not
isinstance
(
results
,
Tuple
):
results
=
(
results
,)
for
r
,
e
in
zip
(
results
,
expected
):
compare_fn
(
r
,
e
)
def
get_trace_fn
(
func
,
enabled
,
symbolic
):
jit
.
trace
.
enabled
=
enabled
return
jit
.
trace
(
func
,
symbolic
=
symbolic
)
def
get_param
(
cases
,
idx
):
case
=
cases
[
idx
]
inp
=
case
.
get
(
"input"
,
None
)
outp
=
case
.
get
(
"output"
,
None
)
if
inp
is
None
:
raise
ValueError
(
"the test case should have input"
)
if
not
isinstance
(
inp
,
List
):
inp
=
(
inp
,)
else
:
inp
=
tuple
(
inp
)
if
ref_fn
is
not
None
and
callable
(
ref_fn
):
outp
=
ref_fn
(
*
inp
)
if
outp
is
None
:
raise
ValueError
(
"the test case should have output or reference function"
)
if
not
isinstance
(
outp
,
List
):
outp
=
(
outp
,)
else
:
outp
=
tuple
(
outp
)
return
inp
,
outp
if
not
set
(
mode
).
issubset
({
"eager"
,
"static"
,
"dynamic_shape"
}):
raise
ValueError
(
"opr test mode must be in (eager, static, dynamic_shape)"
)
if
len
(
cases
)
==
0
:
raise
ValueError
(
"should give one case at least"
)
if
"dynamic_shape"
in
set
(
mode
):
if
len
(
cases
)
!=
2
:
raise
ValueError
(
"should give 2 cases for dynamic shape test"
)
if
not
callable
(
func
):
raise
ValueError
(
"the input func should be callable"
)
inp
,
outp
=
get_param
(
cases
,
0
)
def
run
(
*
args
,
**
kwargs
):
return
func
(
*
args
,
**
kwargs
)
if
"eager"
in
set
(
mode
):
f
=
get_trace_fn
(
run
,
False
,
False
)
results
=
f
(
*
inp
,
**
kwargs
)
check_results
(
results
,
outp
)
if
"static"
in
set
(
mode
)
or
"dynamic_shape"
in
set
(
mode
):
f
=
get_trace_fn
(
run
,
True
,
True
)
results
=
f
(
*
inp
,
**
kwargs
)
check_results
(
results
,
outp
)
if
"dynamic_shape"
in
set
(
mode
):
inp
,
outp
=
get_param
(
cases
,
1
)
results
=
f
(
*
inp
,
**
kwargs
)
check_results
(
results
,
outp
)
class
MyModule
(
Module
):
class
InnerModule
(
Module
):
def
__init__
(
self
):
...
...
@@ -306,13 +166,13 @@ def test_module_api_hooks():
post_hook_num
=
0
hooks
=
[]
def
pre_hook
(
module
,
inputs
):
def
pre_hook
(
_
,
inputs
):
nonlocal
pre_hook_num
pre_hook_num
+=
1
modified_inputs
=
tuple
(
inp
+
1
for
inp
in
inputs
)
return
modified_inputs
def
post_hook
(
module
,
inputs
,
outputs
):
def
post_hook
(
_
,
__
,
outputs
):
nonlocal
post_hook_num
post_hook_num
+=
1
outputs
+=
1
...
...
@@ -376,7 +236,7 @@ class MyModule2(Module):
def
test_expand_structure
():
m
=
MyModule2
()
assert
list
(
m
.
named_modules
())
=
=
[
rst
=
[
(
""
,
m
),
(
"a.0"
,
m
.
a
[
0
]),
(
"a.1.x"
,
m
.
a
[
1
][
"x"
]),
...
...
@@ -387,6 +247,16 @@ def test_expand_structure():
(
"a.2.0.bn"
,
m
.
a
[
2
][
0
].
bn
),
(
"bn"
,
m
.
bn
),
]
assert
list
(
m
.
named_modules
())
==
rst
for
item
in
rst
[
1
:]:
assert
get_expand_structure
(
m
,
item
[
0
])
==
item
[
1
]
for
item
in
reversed
(
rst
[
1
:]):
if
_access_structure
(
m
,
item
[
0
],
lambda
p
,
k
,
o
:
isinstance
(
p
,
tuple
)):
continue
set_expand_structure
(
m
,
item
[
0
],
"TEST_VALUE"
)
assert
get_expand_structure
(
m
,
item
[
0
])
==
"TEST_VALUE"
def
test_flatten_others
():
...
...
@@ -603,21 +473,6 @@ def test_pickle_module():
np
.
testing
.
assert_allclose
(
pred0
.
numpy
(),
pred2
.
numpy
(),
atol
=
5e-6
)
@
pytest
.
mark
.
skip
(
reason
=
"under development"
)
def
test_dump_model
():
data_shape
=
(
2
,
28
)
data
=
Tensor
(
np
.
random
.
random
(
data_shape
))
mlp
=
MLP
()
pred
=
mlp
(
data
)
f
=
tempfile
.
NamedTemporaryFile
(
delete
=
False
)
f_name
=
f
.
name
try
:
mge
.
dump
(
pred
,
f_name
)
finally
:
f
.
close
()
os
.
unlink
(
f_name
)
def
test_load_quantized
():
from
megengine.core.tensor
import
dtype
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录