Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
be511a56
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
be511a56
编写于
11月 03, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/imperative): fix tensor astype failed for quantized type
GitOrigin-RevId: 383458acbf18fa956ca1ccaa376255ff1b06735a
上级
b309890c
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
32 addition
and
1 deletion
+32
-1
imperative/python/megengine/core/tensor/dtype.py
imperative/python/megengine/core/tensor/dtype.py
+15
-0
imperative/python/megengine/core/tensor/utils.py
imperative/python/megengine/core/tensor/utils.py
+2
-1
imperative/python/test/unit/core/test_tensor_wrapper.py
imperative/python/test/unit/core/test_tensor_wrapper.py
+15
-0
未找到文件。
imperative/python/megengine/core/tensor/dtype.py
浏览文件 @
be511a56
...
@@ -62,6 +62,21 @@ def get_zero_point(dtype):
...
@@ -62,6 +62,21 @@ def get_zero_point(dtype):
return
metadata
[
"zero_point"
]
return
metadata
[
"zero_point"
]
def
is_equal
(
dt0
,
dt1
):
def
_get_zero_point
(
dtype
):
assert
is_quantize
(
dtype
)
metadata
=
dtype
.
metadata
[
"mgb_dtype"
]
return
metadata
.
get
(
"zero_point"
)
if
is_quantize
(
dt0
)
and
is_quantize
(
dt1
):
return
get_scale
(
dt0
)
==
get_scale
(
dt1
)
and
_get_zero_point
(
dt0
)
==
_get_zero_point
(
dt1
)
if
not
(
is_quantize
(
dt0
)
or
is_quantize
(
dt1
)):
return
dt0
==
dt1
return
False
def
_check_zero_point
(
zp
:
int
,
dtype_str
:
str
):
def
_check_zero_point
(
zp
:
int
,
dtype_str
:
str
):
qmin
=
_metadata_dict
[
dtype_str
].
qmin
qmin
=
_metadata_dict
[
dtype_str
].
qmin
qmax
=
_metadata_dict
[
dtype_str
].
qmax
qmax
=
_metadata_dict
[
dtype_str
].
qmax
...
...
imperative/python/megengine/core/tensor/utils.py
浏览文件 @
be511a56
...
@@ -14,6 +14,7 @@ import numpy as np
...
@@ -14,6 +14,7 @@ import numpy as np
from
..ops
import
builtin
from
..ops
import
builtin
from
..ops.special
import
Const
from
..ops.special
import
Const
from
..tensor.core
import
OpBase
,
TensorBase
,
TensorWrapperBase
,
apply
from
..tensor.core
import
OpBase
,
TensorBase
,
TensorWrapperBase
,
apply
from
.dtype
import
is_equal
def
dtype_promotion
(
inputs
):
def
dtype_promotion
(
inputs
):
...
@@ -112,7 +113,7 @@ def concatenate(inputs, axis=0, *, device=None):
...
@@ -112,7 +113,7 @@ def concatenate(inputs, axis=0, *, device=None):
def
astype
(
x
,
dtype
):
def
astype
(
x
,
dtype
):
dtype
=
np
.
dtype
(
dtype
)
dtype
=
np
.
dtype
(
dtype
)
if
x
.
dtype
!=
dtype
:
if
not
is_equal
(
x
.
dtype
,
dtype
)
:
(
x
,)
=
apply
(
builtin
.
TypeCvt
(
param
=
dtype
),
x
)
(
x
,)
=
apply
(
builtin
.
TypeCvt
(
param
=
dtype
),
x
)
return
x
return
x
...
...
imperative/python/test/unit/core/test_tensor_wrapper.py
浏览文件 @
be511a56
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
numpy
as
np
import
numpy
as
np
from
megengine.core.tensor.dtype
import
get_scale
,
get_zero_point
,
qint8
,
quint8
from
megengine.core.tensor.tensor_wrapper
import
TensorWrapper
from
megengine.core.tensor.tensor_wrapper
import
TensorWrapper
...
@@ -71,3 +72,17 @@ def test_transpose():
...
@@ -71,3 +72,17 @@ def test_transpose():
x
=
np
.
random
.
rand
(
2
,
5
).
astype
(
"float32"
)
x
=
np
.
random
.
rand
(
2
,
5
).
astype
(
"float32"
)
xx
=
TensorWrapper
(
x
)
xx
=
TensorWrapper
(
x
)
np
.
testing
.
assert_almost_equal
(
xx
.
T
.
numpy
(),
x
.
T
)
np
.
testing
.
assert_almost_equal
(
xx
.
T
.
numpy
(),
x
.
T
)
def
test_as_type
():
x
=
TensorWrapper
([
1
,
2
,
3
],
dtype
=
np
.
float32
)
y
=
x
.
astype
(
qint8
(
0.1
))
np
.
testing
.
assert_almost_equal
(
get_scale
(
y
.
dtype
),
0.1
)
z
=
y
.
astype
(
qint8
(
0.2
))
np
.
testing
.
assert_almost_equal
(
get_scale
(
z
.
dtype
),
0.2
)
a
=
z
.
astype
(
quint8
(
0.3
,
127
))
np
.
testing
.
assert_almost_equal
(
get_scale
(
a
.
dtype
),
0.3
)
np
.
testing
.
assert_equal
(
get_zero_point
(
a
.
dtype
),
127
)
b
=
a
.
astype
(
quint8
(
0.3
,
128
))
np
.
testing
.
assert_almost_equal
(
get_scale
(
b
.
dtype
),
0.3
)
np
.
testing
.
assert_equal
(
get_zero_point
(
b
.
dtype
),
128
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录