Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
1fc91411
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1fc91411
编写于
8月 08, 2020
作者:
X
xutianchun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Post Training Quantization
上级
0df5a561
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
201 addition
and
99 deletion
+201
-99
mindspore/lite/src/common/anf_exporter/anf_exporter.cc
mindspore/lite/src/common/anf_exporter/anf_exporter.cc
+25
-12
mindspore/lite/src/lite_session.cc
mindspore/lite/src/lite_session.cc
+10
-0
mindspore/lite/src/ops/quant_dtype_cast.cc
mindspore/lite/src/ops/quant_dtype_cast.cc
+1
-0
mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc
...pore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc
+1
-2
mindspore/lite/src/runtime/kernel/arm/nnacl/int8/quant_dtype_cast.cc
...ite/src/runtime/kernel/arm/nnacl/int8/quant_dtype_cast.cc
+9
-2
mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc
...ols/converter/legacy_optimizer/node/weight_format_pass.cc
+11
-9
mindspore/lite/tools/converter/quantizer/post_training.cc
mindspore/lite/tools/converter/quantizer/post_training.cc
+27
-8
mindspore/lite/tools/converter/quantizer/post_training.h
mindspore/lite/tools/converter/quantizer/post_training.h
+6
-3
mindspore/lite/tools/converter/quantizer/quant_cast.cc
mindspore/lite/tools/converter/quantizer/quant_cast.cc
+6
-5
mindspore/lite/tools/converter/quantizer/quantize_util.cc
mindspore/lite/tools/converter/quantizer/quantize_util.cc
+93
-35
mindspore/lite/tools/converter/quantizer/quantize_util.h
mindspore/lite/tools/converter/quantizer/quantize_util.h
+12
-23
未找到文件。
mindspore/lite/src/common/anf_exporter/anf_exporter.cc
浏览文件 @
1fc91411
...
...
@@ -177,31 +177,44 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
if
(
node
->
quantType
==
schema
::
QuantType_PostTraining
)
{
MS_LOG
(
INFO
)
<<
"node: "
<<
node
->
name
<<
" add QuantParam"
;
// activation
auto
activate_index
=
node
->
inputIndex
[
0
];
auto
tensor_input
=
metaGraphT
->
allTensors
[
activate_index
].
get
();
auto
input_quant_params
=
primitiveT_value
->
GetInputQuantParams
();
if
(
input_quant_params
.
empty
())
{
MS_LOG
(
WARNING
)
<<
"node: "
<<
node
->
name
<<
" input quant params is empty"
;
}
else
{
auto
node_type
=
primitiveT_value
->
GetPrimitiveT
()
->
value
.
type
;
for
(
int
i
=
0
;
i
<
input_quant_params
.
size
();
i
++
)
{
if
(
i
>=
node
->
inputIndex
.
size
())
{
MS_LOG
(
ERROR
)
<<
"node: "
<<
node
->
name
<<
" input has "
<<
input_quant_params
.
size
()
<<
" quant_params; but only "
<<
node
->
inputIndex
.
size
()
<<
" input"
;
break
;
}
auto
activate_index
=
node
->
inputIndex
[
i
];
auto
tensor_input
=
metaGraphT
->
allTensors
[
activate_index
].
get
();
std
::
unique_ptr
<
schema
::
QuantParamT
>
input_quant_param
=
std
::
make_unique
<
schema
::
QuantParamT
>
(
input_quant_params
[
0
]);
std
::
make_unique
<
schema
::
QuantParamT
>
(
input_quant_params
[
i
]);
MS_LOG
(
DEBUG
)
<<
"[input]node: "
<<
node
->
name
<<
" scale: "
<<
input_quant_param
->
scale
<<
" zp: "
<<
input_quant_param
->
zeroPoint
;
tensor_input
->
quantParams
.
emplace_back
(
std
::
move
(
input_quant_param
));
}
if
(
!
(
node_type
==
schema
::
PrimitiveType_QuantDTypeCast
&&
primitiveT_value
->
GetPrimitiveT
()
->
value
.
AsQuantDTypeCast
()
->
srcT
==
kNumberTypeFloat32
))
{
tensor_input
->
dataType
=
kNumberTypeInt8
;
}
}
// output
auto
output_index
=
node
->
outputIndex
[
0
];
auto
tensor_output
=
metaGraphT
->
allTensors
[
output_index
].
get
();
auto
output_quant_params
=
primitiveT_value
->
GetOutputQuantParams
();
if
(
output_quant_params
.
empty
())
{
MS_LOG
(
WARNING
)
<<
"node: "
<<
node
->
name
<<
" output quant params is empty"
;
MS_LOG
(
WARNING
)
<<
"node: "
<<
node
->
name
<<
" output quant params is empty"
;
}
else
{
std
::
unique_ptr
<
schema
::
QuantParamT
>
output_quant_param
=
std
::
make_unique
<
schema
::
QuantParamT
>
(
output_quant_params
[
0
]);
MS_LOG
(
DEBUG
)
<<
"[output]node: "
<<
node
->
name
<<
" scale: "
<<
output_quant_param
->
scale
<<
" zp: "
<<
output_quant_param
->
zeroPoint
;
tensor_output
->
quantParams
.
emplace_back
(
std
::
move
(
output_quant_param
));
}
if
(
!
(
node_type
==
schema
::
PrimitiveType_QuantDTypeCast
&&
primitiveT_value
->
GetPrimitiveT
()
->
value
.
AsQuantDTypeCast
()
->
dstT
==
kNumberTypeFloat32
))
{
tensor_output
->
dataType
=
kNumberTypeInt8
;
}
// // TensorType
// valuePtr = primitive->GetAttr(kInputTensorDataType);
// if (valuePtr != nullptr) {
...
...
mindspore/lite/src/lite_session.cc
浏览文件 @
1fc91411
...
...
@@ -64,6 +64,16 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
// no copy data, do copy when call LiteKernel::Init
dstTensor
->
SetData
(
const_cast
<
unsigned
char
*>
(
srcTensor
->
data
()
->
data
()));
}
auto
quant_params
=
srcTensor
->
quantParams
();
if
(
quant_params
!=
nullptr
)
{
for
(
int
j
=
0
;
j
<
quant_params
->
size
();
j
++
)
{
tensor
::
QuantArg
quant_arg
{};
quant_arg
.
scale
=
quant_params
->
Get
(
j
)
->
scale
();
quant_arg
.
zeroPoint
=
quant_params
->
Get
(
j
)
->
zeroPoint
();
dstTensor
->
AddQuantParam
(
quant_arg
);
}
}
this
->
tensors
.
emplace_back
(
dstTensor
);
}
return
RET_OK
;
...
...
mindspore/lite/src/ops/quant_dtype_cast.cc
浏览文件 @
1fc91411
...
...
@@ -30,6 +30,7 @@ int QuantDTypeCast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
auto
param
=
primitive
->
value_as_QuantDTypeCast
();
MS_ASSERT
(
input
->
data_type
()
==
param
->
srcT
);
output
->
set_data_type
(
static_cast
<
TypeId
>
(
param
->
dstT
()));
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc
浏览文件 @
1fc91411
...
...
@@ -58,7 +58,7 @@ int QuantDTypeCastCPUKernel::Init() {
}
inverse_
=
true
;
}
else
{
MS_LOG
(
ERROR
)
<<
"param data type not supported
."
;
MS_LOG
(
ERROR
)
<<
"param data type not supported
:"
<<
" src: "
<<
param
->
srcT
<<
" dst: "
<<
param
->
dstT
;
return
RET_ERROR
;
}
...
...
@@ -143,7 +143,6 @@ kernel::LiteKernel *CpuQuantDTypeCastFp32KernelCreator(const std::vector<lite::t
}
return
kernel
;
}
REG_KERNEL
(
kCPU
,
kNumberTypeInt8
,
PrimitiveType_QuantDTypeCast
,
CpuQuantDTypeCastFp32KernelCreator
)
REG_KERNEL
(
kCPU
,
kNumberTypeFloat32
,
PrimitiveType_QuantDTypeCast
,
CpuQuantDTypeCastFp32KernelCreator
)
}
// namespace mindspore::kernel
mindspore/lite/src/runtime/kernel/arm/nnacl/int8/quant_dtype_cast.cc
浏览文件 @
1fc91411
...
...
@@ -23,7 +23,7 @@ int DequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_
}
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
real_values
[
i
]
=
(
quant_values
[
i
]
+
zp
)
*
scale
;
real_values
[
i
]
=
(
quant_values
[
i
]
-
zp
)
*
scale
;
}
return
NNACL_OK
;
}
...
...
@@ -34,7 +34,14 @@ int QuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_
}
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
quant_values
[
i
]
=
(
int8_t
)
round
(
real_values
[
i
]
/
scale
+
zp
);
float
temp
=
round
(
real_values
[
i
]
/
scale
+
zp
);
if
(
temp
>
127
)
{
quant_values
[
i
]
=
127
;
}
else
if
(
temp
<
-
128
)
{
quant_values
[
i
]
=
-
128
;
}
else
{
quant_values
[
i
]
=
(
int8_t
)
temp
;
}
}
return
NNACL_OK
;
}
mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc
浏览文件 @
1fc91411
...
...
@@ -166,6 +166,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
return
-
1
;
}
}
MS_LOG
(
DEBUG
)
<<
"weight_tensor_format: "
<<
weightTensor
->
format
;
return
0
;
}
else
if
(
fmkType
==
converter
::
FmkType_ONNX
)
{
switch
(
node
->
quantType
)
{
...
...
@@ -217,7 +218,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
auto
opType
=
node
->
primitive
->
value
.
type
;
if
(
opType
!=
schema
::
PrimitiveType_Conv2D
&&
opType
!=
schema
::
PrimitiveType_DepthwiseConv2D
&&
opType
!=
schema
::
PrimitiveType_DeConv2D
&&
opType
!=
schema
::
PrimitiveType_DeDepthwiseConv2D
)
{
return
0
;
return
RET_OK
;
}
MS_ASSERT
(
node
->
inputIndex
.
size
()
>=
2
);
...
...
@@ -225,7 +226,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
MS_ASSERT
(
subGraph
->
allTensors
.
size
()
>
weightIndex
);
auto
&
weightTensor
=
subGraph
->
allTensors
[
weightIndex
];
MS_ASSERT
(
weightTensor
->
dataType
==
kNumberTypeInt8
);
// DataType_DT_FLOAT
STATUS
status
;
STATUS
status
=
RET_OK
;
if
(
opType
==
schema
::
PrimitiveType_Conv2D
)
{
// weight should be HWCK
if
(
weightTensor
->
format
==
schema
::
Format_KCHW
)
{
// from caffe
if
(
weightTensor
->
dataType
==
kNumberTypeInt8
)
{
// DataType_DT_UINT8) {
...
...
@@ -238,11 +239,12 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
status
=
TransFilterFormat
<
float
>
(
weightTensor
.
get
(),
kKCHW2HWCK
);
}
}
else
if
(
weightTensor
->
format
==
schema
::
Format_KHWC
)
{
// from onnx
if
(
weightTensor
->
dataType
==
kNumberTypeInt8
)
{
// DataType_DT_UINT8) {
status
=
TransFilterFormat
<
int8_t
>
(
weightTensor
.
get
(),
kKHWC2HWCK
);
}
else
{
status
=
TransFilterFormat
<
float
>
(
weightTensor
.
get
(),
kKHWC2HWCK
);
}
return
RET_OK
;
// if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
// status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK);
// } else {
// status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK);
// }
}
else
if
(
weightTensor
->
format
==
schema
::
Format_HWCK
)
{
// from tf
return
0
;
}
else
{
...
...
@@ -274,7 +276,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
return
0
;
}
else
if
(
weightTensor
->
format
==
schema
::
Format_CHWK
)
{
// from onnx
if
(
weightTensor
->
dataType
==
kNumberTypeInt8
)
{
// DataType_DT_UINT8) {
status
=
TransFilterFormat
<
uint8_t
>
(
weightTensor
.
get
(),
kCHWK2HWCK
);
status
=
TransFilterFormat
<
int8_t
>
(
weightTensor
.
get
(),
kCHWK2KHWC
);
}
else
{
status
=
TransFilterFormat
<
float
>
(
weightTensor
.
get
(),
kCHWK2HWCK
);
}
...
...
mindspore/lite/tools/converter/quantizer/post_training.cc
浏览文件 @
1fc91411
...
...
@@ -54,7 +54,7 @@ struct DivergInfo {
size_t
bit_num
;
int
quant_max
=
255
;
int
quant_min
=
0
;
DivergInfo
(
CNodePtr
cnode
,
int
bins
,
size_t
bits
,
int
quant_max
=
255
,
int
quant_min
=
0
)
{
DivergInfo
(
CNodePtr
cnode
,
int
bins
,
size_t
bits
,
int
quant_max
,
int
quant_min
)
{
this
->
cnode
=
cnode
;
this
->
bin_num
=
bins
;
this
->
bit_num
=
bits
;
...
...
@@ -81,6 +81,9 @@ struct DivergInfo {
STATUS
UpdateHistogram
(
const
std
::
vector
<
float
>
&
data
,
const
std
::
vector
<
int
>
&
shape
)
{
for
(
auto
value
:
data
)
{
if
(
value
==
0
)
{
continue
;
}
int
bin_index
=
std
::
min
(
static_cast
<
int
>
(
std
::
fabs
(
value
)
/
this
->
interval
),
bin_num
-
1
);
this
->
histogram
[
bin_index
]
++
;
}
...
...
@@ -470,8 +473,10 @@ STATUS Calibrator::ReadConfig() {
Calibrator
::
Calibrator
(
string
path
,
size_t
bitNum
,
int
quantMax
,
int
quantMin
)
:
config_path_
(
path
),
bit_num_
(
bitNum
),
quant_max_
(
quantMax
),
quant_min_
(
quantMin
)
{}
PostTrainingQuantizer
::
PostTrainingQuantizer
(
FuncGraphPtr
graph
,
string
path
,
int
bit_num
,
TypeId
target_type
)
PostTrainingQuantizer
::
PostTrainingQuantizer
(
FuncGraphPtr
graph
,
string
path
,
int
bit_num
,
TypeId
target_type
,
bool
per_channel
)
:
Quantizer
(
graph
)
{
this
->
per_channel_
=
per_channel
;
this
->
bit_num
=
bit_num
;
this
->
target_type_
=
target_type
;
if
(
target_type
==
kNumberTypeInt8
)
{
...
...
@@ -533,7 +538,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr node) {
}
auto
parameter
=
std
::
dynamic_pointer_cast
<
Parameter
>
(
node
);
ParamValueLitePtr
paramValue
=
std
::
dynamic_pointer_cast
<
ParamValueLite
>
(
parameter
->
default_param
());
auto
status
=
QuantFilter
(
paramValue
,
QuantType_PostTraining
,
quant_max
,
quant_min
,
bit_num
);
auto
status
=
QuantFilter
(
paramValue
,
QuantType_PostTraining
,
quant_max
,
quant_min
,
bit_num
,
per_channel_
);
if
(
status
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"QuantFilter failed: "
<<
status
;
return
status
;
...
...
@@ -670,18 +675,32 @@ STATUS PostTrainingQuantizer::QuantNode() {
MS_LOG
(
ERROR
)
<<
"PrimitiveT_value is nullptr"
;
continue
;
}
if
(
input_scale
.
find
(
cnode
)
==
input_scale
.
end
())
{
primitiveT_value
->
SetQuantType
(
schema
::
QuantType_QUANT_NONE
);
continue
;
}
auto
input_vec
=
cnode
->
inputs
();
auto
op_name
=
cnode
->
fullname_with_scope
();
auto
op_type
=
primitiveT_value
->
GetPrimitiveT
()
->
value
.
type
;
MS_LOG
(
INFO
)
<<
"OpName: "
<<
op_name
;
if
(
input_vec
.
size
()
<=
3
&&
op_name
!=
"Conv2D"
&&
op_name
!=
"DepthwiseConv2D"
)
{
MS_LOG
(
INFO
)
<<
"todo(x): "
;
// int32_t qnodeOutputZeropoint = outputZeropoint[cnode];
// p->AddAttr(kInputTensorDataType, MakeValue((int)targetType));
if
(
op_type
!=
PrimitiveType_Conv2D
&&
op_type
!=
PrimitiveType_DepthwiseConv2D
)
{
for
(
auto
i
=
1
;
i
<
cnode
->
inputs
().
size
();
i
++
)
{
auto
input_node
=
cnode
->
input
(
i
);
if
(
!
input_node
->
isa
<
mindspore
::
CNode
>
())
{
MS_LOG
(
WARNING
)
<<
"node: "
<<
cnode_name
<<
" input "
<<
i
<<
" not a cnode"
;
continue
;
}
auto
input_cnode
=
std
::
dynamic_pointer_cast
<
mindspore
::
CNode
>
(
input_node
);
auto
input_cnode_primitiveT_value
=
GetValueNode
<
std
::
shared_ptr
<
PrimitiveTValue
>>
(
input_cnode
->
input
(
0
));
if
(
input_cnode_primitiveT_value
==
nullptr
)
{
MS_LOG
(
DEBUG
)
<<
"input: "
<<
i
<<
" "
<<
input_cnode
->
fullname_with_scope
()
<<
": "
<<
" PrimitiveTValue is null"
;
continue
;
}
for
(
auto
&
quant_param
:
input_cnode_primitiveT_value
->
GetOutputQuantParams
())
{
primitiveT_value
->
AddInputQuantParam
(
quant_param
);
}
}
}
else
{
// do input quant
double
scale
=
input_scale
[
cnode
];
...
...
mindspore/lite/tools/converter/quantizer/post_training.h
浏览文件 @
1fc91411
...
...
@@ -55,15 +55,18 @@ struct ConfigParam {
class
PostTrainingQuantizer
:
public
Quantizer
{
public:
PostTrainingQuantizer
(
FuncGraphPtr
graph
,
std
::
string
path
,
int
bit_num
,
TypeId
target_type
=
kNumberTypeInt8
);
PostTrainingQuantizer
(
FuncGraphPtr
graph
,
std
::
string
path
,
int
bit_num
,
TypeId
target_type
=
kNumberTypeInt8
,
bool
per_channel
=
false
);
STATUS
DoQuantize
(
FuncGraphPtr
funcGraph
)
override
;
size_t
bit_num
;
int
quant_max
{
255
};
int
quant_min
{
0
};
int
quant_max
{
127
};
int
quant_min
{
-
128
};
private:
bool
per_channel_
;
TypeId
target_type_
{
kNumberTypeInt8
};
std
::
unique_ptr
<
Calibrator
>
calibrator_
;
...
...
mindspore/lite/tools/converter/quantizer/quant_cast.cc
浏览文件 @
1fc91411
...
...
@@ -25,10 +25,11 @@ namespace mindspore::lite::quant {
ValueNodePtr
NewQuantCastValueNode
(
int
src_type
,
int
dst_type
,
const
std
::
vector
<
schema
::
QuantParamT
>
&
quant_params
)
{
std
::
unique_ptr
<
schema
::
PrimitiveT
>
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
schema
::
QuantDTypeCastT
quant_dtype_cast
;
quant_dtype_cast
.
srcT
=
src_type
;
// kNumberType
U
Int8;
quant_dtype_cast
.
srcT
=
src_type
;
// kNumberTypeInt8;
quant_dtype_cast
.
dstT
=
dst_type
;
// kNumberTypeFloat32;
primitive
->
value
.
Set
(
quant_dtype_cast
);
auto
primTValue
=
std
::
make_shared
<
PrimitiveTValue
>
(
primitive
.
release
());
primTValue
->
SetQuantType
(
schema
::
QuantType_PostTraining
);
for
(
auto
&
quant_param
:
quant_params
)
{
primTValue
->
AddInputQuantParam
(
quant_param
);
}
...
...
@@ -52,7 +53,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
if
(
first
)
{
if
(
curnode_quant_type
==
schema
::
QuantType_PostTraining
&&
inputDataDType
==
kNumberTypeFloat32
)
{
auto
value_node
=
NewQuantCastValueNode
(
kNumberTypeFloat32
,
kNumberType
U
Int8
,
primitiveT_value
->
GetInputQuantParams
());
NewQuantCastValueNode
(
kNumberTypeFloat32
,
kNumberTypeInt8
,
primitiveT_value
->
GetInputQuantParams
());
std
::
vector
<
AnfNodePtr
>
op_inputs
=
{
value_node
,
cnode
->
input
(
1
)};
auto
quant_cast_cnode
=
graph
->
NewCNode
(
op_inputs
);
quant_cast_cnode
->
set_fullname_with_scope
(
cnode
->
fullname_with_scope
()
+
"_quant_cast"
);
...
...
@@ -82,11 +83,11 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
ValueNodePtr
value_node
=
nullptr
;
if
(
curnode_quant_type
==
schema
::
QuantType_PostTraining
&&
input_cnode_quant_type
==
schema
::
QuantType_QUANT_NONE
)
{
value_node
=
NewQuantCastValueNode
(
kNumberTypeFloat32
,
kNumberType
U
Int8
,
input_cnode_
primitiveT_value
->
GetInputQuantParams
());
value_node
=
NewQuantCastValueNode
(
kNumberTypeFloat32
,
kNumberTypeInt8
,
primitiveT_value
->
GetInputQuantParams
());
}
else
if
(
curnode_quant_type
==
schema
::
QuantType_QUANT_NONE
&&
input_cnode_quant_type
==
schema
::
QuantType_PostTraining
)
{
value_node
=
NewQuantCastValueNode
(
kNumberType
U
Int8
,
kNumberTypeFloat32
,
value_node
=
NewQuantCastValueNode
(
kNumberTypeInt8
,
kNumberTypeFloat32
,
input_cnode_primitiveT_value
->
GetInputQuantParams
());
}
if
(
value_node
==
nullptr
)
{
...
...
mindspore/lite/tools/converter/quantizer/quantize_util.cc
浏览文件 @
1fc91411
...
...
@@ -98,7 +98,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
static
const
std
::
vector
<
schema
::
PrimitiveType
>
uint8OpList
=
{
schema
::
PrimitiveType_Nchw2Nhwc
,
schema
::
PrimitiveType_Nhwc2Nchw
,
schema
::
PrimitiveType_Conv2D
,
schema
::
PrimitiveType_DepthwiseConv2D
,
schema
::
PrimitiveType_Add
,
schema
::
PrimitiveType_Pooling
,
schema
::
PrimitiveType_Concat
,
schema
::
PrimitiveType_SoftMax
,
schema
::
PrimitiveType_Reshape
,
schema
::
PrimitiveType_Concat
,
/*schema::PrimitiveType_SoftMax,*/
schema
::
PrimitiveType_Reshape
,
schema
::
PrimitiveType_Activation
};
return
IsContain
(
uint8OpList
,
type
);
}
...
...
@@ -242,13 +242,17 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double
return
RET_OK
;
}
STATUS
QuantFilter
(
ParamValueLitePtr
&
weightPtr
,
QuantType
quantType
,
int
quant_max
,
int
quant_min
,
size_t
bitNum
)
{
STATUS
QuantFilter
(
ParamValueLitePtr
&
weightPtr
,
QuantType
quantType
,
int
quant_max
,
int
quant_min
,
size_t
bitNum
,
bool
per_channel
)
{
if
(
per_channel
)
{
// per channel
auto
dims
=
weightPtr
->
tensor_shape
();
if
(
dims
.
size
()
<
1
)
{
MS_LOG
(
ERROR
)
<<
"weight dims size error"
;
return
RET_ERROR
;
}
uint32_t
channels
=
dims
[
0
];
// todo(x)
uint32_t
channels
=
dims
[
3
];
if
(
channels
==
0
)
{
MS_LOG
(
ERROR
)
<<
"channels error 0"
;
return
RET_ERROR
;
...
...
@@ -263,7 +267,7 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
}
weightPtr
->
quant_param
().
clear
();
vector
<
u
int8_t
>
qDatas
(
shapeSize
);
vector
<
int8_t
>
qDatas
(
shapeSize
);
for
(
uint32_t
i
=
0
;
i
<
channels
;
i
++
)
{
float
min
=
0
;
float
max
=
0
;
...
...
@@ -282,14 +286,14 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
// update data and datatype
for
(
uint32_t
j
=
0
;
j
<
oneFilterSize
;
j
++
)
{
float
rawData
=
rawDatas
[
j
+
i
*
oneFilterSize
];
auto
qData
=
QuantizeData
<
uint8_t
>
(
rawData
,
quantParam
.
get
()
);
auto
qData
=
QuantizeData
<
int8_t
>
(
rawData
,
quantParam
.
get
(),
quant_max
,
quant_min
);
qDatas
[
j
+
i
*
oneFilterSize
]
=
qData
;
}
weightPtr
->
set_quant_param
(
quantParam
);
}
auto
ret
=
memcpy_s
(
const_cast
<
float
*>
(
rawDatas
),
weightPtr
->
tensor_size
(),
qDatas
.
data
(),
shapeSize
*
sizeof
(
u
int8_t
));
qDatas
.
data
(),
shapeSize
*
sizeof
(
int8_t
));
if
(
ret
!=
EOK
)
{
MS_LOG
(
ERROR
)
<<
"memcpy error: "
<<
ret
;
return
RET_ERROR
;
...
...
@@ -300,6 +304,60 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
weightPtr
->
set_tensor_type
(
kNumberTypeInt8
);
weightPtr
->
set_tensor_size
(
shapeSize
*
sizeof
(
int8_t
));
}
else
{
// per layer
size_t
shapeSize
=
weightPtr
->
tensor_shape_size
();
auto
*
rawDatas
=
static_cast
<
float
*>
(
weightPtr
->
tensor_addr
());
if
(
rawDatas
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"rawDatas is nullptr"
;
return
RET_ERROR
;
}
weightPtr
->
quant_param
().
clear
();
vector
<
int8_t
>
qDatas
(
shapeSize
);
float
min
=
0
;
float
max
=
0
;
for
(
uint32_t
i
=
0
;
i
<
shapeSize
;
i
++
)
{
// find max min
min
=
std
::
min
(
min
,
rawDatas
[
i
]);
max
=
std
::
max
(
max
,
rawDatas
[
i
]);
}
std
::
unique_ptr
<
AnfQuantParam
>
quantParam
=
std
::
unique_ptr
<
AnfQuantParam
>
(
new
AnfQuantParam
);
STATUS
status
=
CalQuantizationParams
(
quantParam
,
min
,
max
,
false
,
quant_max
,
quant_min
,
bitNum
);
if
(
status
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"CalQuantizationParams failed"
<<
status
;
return
status
;
}
// update data and datatype
for
(
uint32_t
i
=
0
;
i
<
shapeSize
;
i
++
)
{
float
rawData
=
rawDatas
[
i
];
auto
quant_data
=
std
::
round
(
rawData
/
quantParam
->
scale
+
quantParam
->
zeroPoint
);
if
(
quant_data
>
quant_max
)
{
qDatas
[
i
]
=
quant_max
;
}
else
if
(
quant_data
<
quant_min
)
{
qDatas
[
i
]
=
quant_min
;
}
else
{
qDatas
[
i
]
=
static_cast
<
int8_t
>
(
quant_data
);
}
}
weightPtr
->
set_quant_param
(
quantParam
);
auto
ret
=
memcpy_s
(
rawDatas
,
weightPtr
->
tensor_size
()
*
sizeof
(
int8_t
),
qDatas
.
data
(),
shapeSize
*
sizeof
(
int8_t
));
if
(
ret
!=
EOK
)
{
MS_LOG
(
ERROR
)
<<
"memcpy error: "
<<
ret
;
return
RET_ERROR
;
}
if
(
quantType
==
QuantType_WeightQuant
)
{
PostBitPack
(
rawDatas
,
shapeSize
,
bitNum
);
}
weightPtr
->
set_tensor_type
(
kNumberTypeInt8
);
weightPtr
->
set_tensor_size
(
shapeSize
*
sizeof
(
int8_t
));
}
return
RET_OK
;
}
...
...
mindspore/lite/tools/converter/quantizer/quantize_util.h
浏览文件 @
1fc91411
...
...
@@ -63,41 +63,30 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double
bool
narrowRange
,
int
quant_max
,
int
quant_min
,
int
num_bits
);
template
<
typename
T
>
T
QuantizeData
(
const
float
originData
,
const
AnfQuantParam
*
quantParam
)
{
T
QuantizeData
(
float
originData
,
const
AnfQuantParam
*
quantParam
,
int
quant_max
,
int
quant_min
)
{
MS_ASSERT
(
quantParam
!=
nullptr
);
MS_ASSERT
(
quantParam
->
inited
);
const
auto
scale
=
quantParam
->
scale
;
const
auto
zeroPoint
=
quantParam
->
zeroPoint
;
const
auto
numBit
=
quantParam
->
numBits
;
const
int
zeroPoint
=
quantParam
->
zeroPoint
;
const
auto
narrowRange
=
quantParam
->
narrowRange
;
const
double
maxLimit
=
static_cast
<
float
>
((
1
<<
(
unsigned
int
)
numBit
)
-
1
-
zeroPoint
)
*
scale
;
double
minLimit
;
if
(
narrowRange
)
{
minLimit
=
static_cast
<
float
>
(
1
-
zeroPoint
)
*
scale
;
}
else
{
minLimit
=
static_cast
<
float
>
(
0
-
zeroPoint
)
*
scale
;
}
const
int
maxLimit
=
quant_max
;
const
int
minLimit
=
quant_min
;
return
[
maxLimit
,
minLimit
,
zeroPoint
,
scale
,
narrowRange
,
originData
]
{
double
tmp
=
0.0
f
;
if
(
originData
>
maxLimit
)
{
tmp
=
maxLimit
;
}
else
if
(
originData
<
minLimit
)
{
tmp
=
minLimit
;
}
else
{
tmp
=
originData
;
}
auto
quantData
=
static_cast
<
T
>
(
std
::
round
(
tmp
/
scale
+
zeroPoint
));
if
(
quantData
==
0
&&
narrowRange
)
{
quantData
++
;
int
quant_data
=
std
::
round
(
originData
/
scale
+
zeroPoint
);
if
(
quant_data
>
maxLimit
)
{
quant_data
=
maxLimit
;
}
else
if
(
quant_data
<
minLimit
)
{
quant_data
=
minLimit
;
}
return
quantData
;
return
static_cast
<
T
>
(
quant_data
)
;
}();
}
void
CalFakeNode
(
const
AnfNodePtr
&
inTensor
);
STATUS
QuantFilter
(
ParamValueLitePtr
&
weightPtr
,
QuantType
quantType
,
int
quant_max
,
int
quant_min
,
size_t
bitNum
=
UINT8_QUANTIZATION
);
size_t
bitNum
=
UINT8_QUANTIZATION
,
bool
per_channel
=
false
);
STATUS
PostBitPack
(
float
*
weights
,
size_t
shapeSize
,
size_t
bitNum
=
UINT8_QUANTIZATION
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录