Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
b13343f8
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b13343f8
编写于
8月 21, 2020
作者:
W
Wei Luning
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix quant bug
上级
8878f448
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
24 addition
and
13 deletion
+24
-13
mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc
mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc
+1
-1
mindspore/common/api.py
mindspore/common/api.py
+1
-1
mindspore/ops/operations/_inner_ops.py
mindspore/ops/operations/_inner_ops.py
+2
-0
mindspore/ops/operations/math_ops.py
mindspore/ops/operations/math_ops.py
+2
-0
mindspore/train/quant/quant.py
mindspore/train/quant/quant.py
+18
-11
未找到文件。
mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc
浏览文件 @
b13343f8
...
...
@@ -29,7 +29,7 @@ GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<mindspore::tensor
// To-DO the format may read from ME tensor
MS_EXCEPTION_IF_NULL
(
value
);
auto
me_tensor
=
value
->
cast
<
MeTensorPtr
>
();
auto
ge_tensor
=
TransformUtil
::
ConvertTensor
(
me_tensor
,
kOpFormat_N
CHW
);
auto
ge_tensor
=
TransformUtil
::
ConvertTensor
(
me_tensor
,
kOpFormat_N
D
);
return
ge_tensor
==
nullptr
?
GeTensor
()
:
*
ge_tensor
;
}
...
...
mindspore/common/api.py
浏览文件 @
b13343f8
...
...
@@ -388,7 +388,7 @@ class _Executor:
dic
=
dict
(
zip
(
args_names
,
args_list
))
key
=
generate_key
(
phase
,
dic
)
self
.
phase_prefix
=
str
(
key
[
1
])
if
phase
==
'export'
:
if
'export'
in
phase
:
phase
=
phase
+
'.'
+
self
.
phase_prefix
+
'.'
+
str
(
obj
.
create_time
)
else
:
phase
=
self
.
phase_prefix
+
phase
+
'.'
+
str
(
obj
.
create_time
)
...
...
mindspore/ops/operations/_inner_ops.py
浏览文件 @
b13343f8
...
...
@@ -332,6 +332,7 @@ class Quant(PrimitiveWithInfer):
self
.
sqrt_mode
=
validator
.
check_value_type
(
"sqrt_mode"
,
sqrt_mode
,
[
bool
],
self
.
name
)
self
.
round_mode
=
validator
.
check_string
(
"round_mode"
,
round_mode
,
[
"Round"
,
"Floor"
,
"Ceil"
,
"Trunc"
],
self
.
name
)
self
.
add_prim_attr
(
"io_format"
,
"ND"
)
def
infer_shape
(
self
,
x_shape
):
return
x_shape
...
...
@@ -382,6 +383,7 @@ class Dequant(PrimitiveWithInfer):
self
.
sqrt_mode
=
validator
.
check_value_type
(
"sqrt_mode"
,
sqrt_mode
,
[
bool
],
self
.
name
)
self
.
relu_flag
=
validator
.
check_value_type
(
"relu_flag"
,
relu_flag
,
[
bool
],
self
.
name
)
self
.
add_prim_attr
(
"dtype"
,
mstype
.
float16
)
self
.
add_prim_attr
(
"io_format"
,
"ND"
)
def
infer_shape
(
self
,
x_shape
,
deq_scale_shape
):
return
x_shape
...
...
mindspore/ops/operations/math_ops.py
浏览文件 @
b13343f8
...
...
@@ -258,6 +258,7 @@ class _Reduce(PrimitiveWithInfer):
"""init Reduce"""
validator
.
check_value_type
(
'keep_dims'
,
keep_dims
,
[
bool
],
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'input_x'
,
'axis'
],
outputs
=
[
'y'
])
self
.
add_prim_attr
(
"io_format"
,
"ND"
)
def
__call__
(
self
,
x
,
axis
=
()):
args
=
[
x
,
axis
]
...
...
@@ -626,6 +627,7 @@ class MatMul(PrimitiveWithInfer):
cls_name
=
self
.
name
validator
.
check_value_type
(
"transpose_a"
,
transpose_a
,
[
bool
],
cls_name
)
validator
.
check_value_type
(
"transpose_b"
,
transpose_b
,
[
bool
],
cls_name
)
self
.
add_prim_attr
(
"io_format"
,
"ND"
)
def
check_shape_size
(
self
,
x
,
y
):
if
len
(
x
)
!=
2
or
len
(
y
)
!=
2
:
...
...
mindspore/train/quant/quant.py
浏览文件 @
b13343f8
...
...
@@ -314,8 +314,8 @@ class ExportToQuantInferNetwork:
network
=
validator
.
check_isinstance
(
'network'
,
network
,
(
nn
.
Cell
,))
# quantize for inputs: q = f / scale + zero_point
# dequantize for outputs: f = (q - zero_point) * scale
self
.
input_scale
=
round
(
mean
)
self
.
input_zero_point
=
1
/
std_dev
self
.
input_scale
=
1
/
std_dev
self
.
input_zero_point
=
round
(
mean
)
self
.
data_type
=
mstype
.
int8
self
.
network
=
copy
.
deepcopy
(
network
)
self
.
all_parameters
=
{
p
.
name
:
p
for
p
in
self
.
network
.
get_parameters
()}
...
...
@@ -351,20 +351,16 @@ class ExportToQuantInferNetwork:
else
:
maxq
=
self
.
all_parameters
[
minq_name
[:
-
4
]
+
"maxq"
]
minq
=
self
.
all_parameters
[
minq_name
]
scale_a_in
,
zp_a_in
=
quant_utils
.
scale_zp_from_data
(
fack_quant_a_in_op
,
m
axq
,
min
q
,
np_type
)
scale_a_in
,
zp_a_in
=
quant_utils
.
scale_zp_from_data
(
fack_quant_a_in_op
,
m
inq
,
max
q
,
np_type
)
else
:
logger
.
warning
(
f
"Do not find `fake_quant` from input with `fake_quant.minq`
{
w_minq_name
}
"
)
return
None
# Build the `Quant` `Dequant` op.
# Quant only support perlayer version. Need check here.
quant_op
=
inner
.
Quant
(
float
(
scale_a_in
),
float
(
zp_a_in
))
sqrt_mode
=
False
quant_op
=
inner
.
Quant
(
1
/
float
(
scale_a_in
),
float
(
zp_a_in
))
scale_deq
=
scale_a_out
*
scale_w
if
(
scale_deq
<
2
**
-
14
).
all
():
scale_deq
=
np
.
sqrt
(
scale_deq
)
sqrt_mode
=
True
dequant_op
=
inner
.
Dequant
(
sqrt_mode
)
dequant_op
=
inner
.
Dequant
()
if
isinstance
(
activation
,
_AddFakeQuantAfterSubCell
):
activation
=
activation
.
subcell
...
...
@@ -385,8 +381,19 @@ class ExportToQuantInferNetwork:
# apply the quant
weight
=
quant_utils
.
weight2int
(
weight
,
scale_w
,
zp_w
)
if
bias
is
not
None
:
bias
=
Tensor
(
scale_a_in
*
scale_w
*
bias
,
mstype
.
int32
)
scale_deq
=
Tensor
(
scale_deq
,
mstype
.
float16
)
bias
=
Tensor
(
bias
/
scale_a_in
/
scale_w
,
mstype
.
int32
)
# fuse parameter
# |--------|47:40|--------|39:32|--------|31:0|
# offset_w [8] shift_N [8] deq_scale [32]
float32_deq_scale
=
scale_deq
.
astype
(
np
.
float32
)
uint32_deq_scale
=
np
.
frombuffer
(
float32_deq_scale
,
np
.
uint32
)
scale_length
=
scale_deq
.
size
# channel
dequant_param
=
np
.
zeros
(
scale_length
,
dtype
=
np
.
uint64
)
for
index
in
range
(
scale_length
):
dequant_param
[
index
]
+=
uint32_deq_scale
[
index
]
scale_deq
=
Tensor
(
dequant_param
,
mstype
.
uint64
)
# get op
if
isinstance
(
cell_core
,
quant
.
DenseQuant
):
op_core
=
P
.
MatMul
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录