Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
caf77d00
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
caf77d00
编写于
7月 02, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/module): add quantize dtype load support for module load_state_dict
GitOrigin-RevId: 0a94cb6b17005dd5e6a81c6a4b8c51c7044d2751
上级
dedb7a3f
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
96 addition
and
0 deletion
+96
-0
python_module/megengine/core/tensor.py
python_module/megengine/core/tensor.py
+8
-0
python_module/megengine/module/module.py
python_module/megengine/module/module.py
+5
-0
python_module/test/unit/core/test_tensor.py
python_module/test/unit/core/test_tensor.py
+46
-0
python_module/test/unit/module/test_module.py
python_module/test/unit/module/test_module.py
+37
-0
未找到文件。
python_module/megengine/core/tensor.py
浏览文件 @
caf77d00
...
...
@@ -235,6 +235,14 @@ class Tensor:
return
self
.
__val
.
dtype
return
self
.
_symvar
.
dtype
def
set_dtype
(
self
,
dtype
:
str
=
None
):
r
"""Set the data type of the tensor.
"""
if
self
.
__val
is
not
None
:
self
.
__val
=
mgb
.
make_shared
(
self
.
device
,
value
=
self
.
astype
(
dtype
).
numpy
())
elif
self
.
__sym
is
not
None
:
self
.
__sym
=
self
.
__sym
.
astype
(
dtype
)
@
property
def
_comp_node
(
self
):
if
self
.
__val
is
not
None
:
...
...
python_module/megengine/module/module.py
浏览文件 @
caf77d00
...
...
@@ -11,6 +11,7 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union
import
numpy
as
np
from
.._internal.dtype
import
is_quantize
from
..core
import
Buffer
,
Parameter
,
Tensor
from
..logger
import
get_logger
...
...
@@ -460,6 +461,10 @@ class Module(metaclass=ABCMeta):
),
"param `{}` shape mismatch, should be {}, get {}"
.
format
(
k
,
var
.
shape
,
to_be_load
.
shape
)
# For quantized dtype, the initialized dtype
# scale/zero_points maybe invalid, use pretrained dtype instead.
if
is_quantize
(
to_be_load
.
dtype
)
and
is_quantize
(
var
.
dtype
):
var
.
set_dtype
(
to_be_load
.
dtype
)
var
.
set_value
(
to_be_load
)
loaded
.
append
(
k
)
...
...
python_module/test/unit/core/test_tensor.py
浏览文件 @
caf77d00
...
...
@@ -10,6 +10,7 @@ import numpy as np
import
pytest
import
megengine
as
mge
import
megengine._internal
as
mgb
def
test_wrong_dtype
():
...
...
@@ -26,3 +27,48 @@ def test_tensor_routine():
mge
.
tensor
([
1
])
mge
.
tensor
(
1.5
)
def
test_tensor_set_dtype
():
def
check_dtype_value
(
tensor
,
dtype_scale
,
value
):
if
mgb
.
dtype
.
is_quantize
(
tensor
.
dtype
):
if
np
.
abs
(
mgb
.
dtype
.
get_scale
(
tensor
.
dtype
)
-
dtype_scale
)
>
1e-5
:
raise
AssertionError
(
"compare scale failed expect {} got {}"
.
format
(
dtype_scale
,
mgb
.
dtype
.
get_scale
(
tensor
.
dtype
)
)
)
if
np
.
abs
(
tensor
.
numpy
()[
0
][
0
]
-
value
)
>
1e-5
:
raise
AssertionError
(
"compare value failed expect {} got {}"
.
format
(
tensor
.
numpy
()[
0
][
0
],
value
)
)
t
=
mge
.
Parameter
(
np
.
ones
((
3
,
4
),
dtype
=
"float32"
))
t
.
set_dtype
(
mgb
.
dtype
.
qint8
(
0.1
))
check_dtype_value
(
t
,
0.1
,
10
)
t
=
mge
.
Parameter
(
np
.
ones
((
3
,
4
),
dtype
=
mgb
.
dtype
.
qint8
(
1
)))
t
.
set_dtype
(
mgb
.
dtype
.
qint8
(
0.3
))
check_dtype_value
(
t
,
0.3
,
3
)
t
=
mge
.
Buffer
(
np
.
ones
((
3
,
4
),
dtype
=
"float32"
))
t
.
set_dtype
(
mgb
.
dtype
.
qint8
(
0.1
))
check_dtype_value
(
t
,
0.1
,
10
)
t
=
mge
.
Buffer
(
np
.
ones
((
3
,
4
),
dtype
=
mgb
.
dtype
.
qint8
(
1
)))
t
.
set_dtype
(
mgb
.
dtype
.
qint8
(
0.3
))
check_dtype_value
(
t
,
0.3
,
3
)
t
=
mge
.
Buffer
(
np
.
ones
((
3
,
4
),
dtype
=
"float32"
))
s
=
t
+
1
s
.
set_dtype
(
mgb
.
dtype
.
qint8
(
0.2
))
check_dtype_value
(
s
,
0.2
,
10
)
t
.
set_dtype
(
mgb
.
dtype
.
qint8
(
0.3
))
s
=
t
+
1
s
.
set_dtype
(
mgb
.
dtype
.
qint8
(
0.1
))
check_dtype_value
(
s
,
0.1
,
18
)
s
.
set_dtype
(
"float32"
)
check_dtype_value
(
s
,
0
,
1.8
)
python_module/test/unit/module/test_module.py
浏览文件 @
caf77d00
...
...
@@ -14,8 +14,10 @@ import pytest
from
helpers
import
MLP
import
megengine
as
mge
import
megengine._internal
as
mgb
from
megengine.core
import
Buffer
,
Parameter
,
Tensor
,
tensor
from
megengine.module
import
BatchNorm1d
,
BatchNorm2d
,
Conv2d
,
Module
,
Sequential
from
megengine.quantization.quantize
import
quantize
,
quantize_qat
from
megengine.test
import
assertTensorClose
...
...
@@ -347,3 +349,38 @@ def test_dump_model():
pred
=
mlp
(
data
)
with
tempfile
.
NamedTemporaryFile
()
as
f
:
mge
.
dump
(
pred
,
f
.
name
)
def
test_load_quantized
():
data_shape
=
(
2
,
28
)
data
=
tensor
(
np
.
random
.
random
(
data_shape
),
dtype
=
"float32"
)
data
=
data
.
astype
(
mgb
.
dtype
.
qint8
(
0.1
))
mlp
=
MLP
()
quantize_qat
(
mlp
)
quantize
(
mlp
)
mlp
.
dense0
.
weight
=
Parameter
(
mlp
.
dense0
.
weight
.
astype
(
mgb
.
dtype
.
qint8
(
0.001
)).
numpy
()
)
mlp
.
dense1
.
weight
=
Parameter
(
mlp
.
dense1
.
weight
.
astype
(
mgb
.
dtype
.
qint8
(
0.0002
)).
numpy
()
)
mlp
.
eval
()
pred0
=
mlp
(
data
)
with
BytesIO
()
as
fout
:
mge
.
save
(
mlp
.
state_dict
(),
fout
)
fout
.
seek
(
0
)
checkpoint
=
mge
.
load
(
fout
)
# change mlp weight.
mlp
.
dense0
.
weight
=
Parameter
(
mlp
.
dense0
.
weight
.
astype
(
mgb
.
dtype
.
qint8
(
0.00001
)).
numpy
()
)
mlp
.
dense1
.
weight
=
Parameter
(
mlp
.
dense1
.
weight
.
astype
(
mgb
.
dtype
.
qint8
(
0.2
)).
numpy
()
)
mlp
.
load_state_dict
(
checkpoint
)
pred1
=
mlp
(
data
)
assertTensorClose
(
pred0
.
astype
(
"float32"
).
numpy
(),
pred1
.
astype
(
"float32"
).
numpy
(),
max_err
=
5e-6
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录