Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
28af1e50
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
28af1e50
编写于
8月 04, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 04, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3891 fix tflite reshape parser && post training quantization
Merge pull request !3891 from xutianchun/quant_0803
上级
3c717a97
8f334af0
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
100 addition
and
73 deletion
+100
-73
mindspore/lite/src/common/anf_exporter/anf_exporter.cc
mindspore/lite/src/common/anf_exporter/anf_exporter.cc
+14
-11
mindspore/lite/src/model_impl.cc
mindspore/lite/src/model_impl.cc
+2
-0
mindspore/lite/src/ops/reshape.cc
mindspore/lite/src/ops/reshape.cc
+24
-13
mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc
...lite/tools/converter/optimizer/node/weight_format_pass.cc
+3
-3
mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc
...te/tools/converter/parser/tflite/tflite_reshape_parser.cc
+17
-8
mindspore/lite/tools/converter/quantizer/post_training.cc
mindspore/lite/tools/converter/quantizer/post_training.cc
+13
-6
mindspore/lite/tools/converter/quantizer/post_training.h
mindspore/lite/tools/converter/quantizer/post_training.h
+1
-1
mindspore/lite/tools/converter/quantizer/quant_cast.cc
mindspore/lite/tools/converter/quantizer/quant_cast.cc
+10
-4
mindspore/lite/tools/converter/quantizer/quantize_util.cc
mindspore/lite/tools/converter/quantizer/quantize_util.cc
+12
-23
mindspore/lite/tools/converter/quantizer/quantize_util.h
mindspore/lite/tools/converter/quantizer/quantize_util.h
+2
-2
mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
+2
-2
未找到文件。
mindspore/lite/src/common/anf_exporter/anf_exporter.cc
浏览文件 @
28af1e50
...
@@ -118,6 +118,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
...
@@ -118,6 +118,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
std
::
make_unique
<
schema
::
QuantParamT
>
(
input_quant_params
[
0
]);
std
::
make_unique
<
schema
::
QuantParamT
>
(
input_quant_params
[
0
]);
tensor_input
->
quantParams
.
emplace_back
(
std
::
move
(
input_quant_param
));
tensor_input
->
quantParams
.
emplace_back
(
std
::
move
(
input_quant_param
));
}
}
tensor_input
->
dataType
=
kNumberTypeInt8
;
// output
// output
auto
output_index
=
node
->
outputIndex
[
0
];
auto
output_index
=
node
->
outputIndex
[
0
];
auto
tensor_output
=
metaGraphT
->
allTensors
[
output_index
].
get
();
auto
tensor_output
=
metaGraphT
->
allTensors
[
output_index
].
get
();
...
@@ -129,6 +130,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
...
@@ -129,6 +130,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
std
::
make_unique
<
schema
::
QuantParamT
>
(
output_quant_params
[
0
]);
std
::
make_unique
<
schema
::
QuantParamT
>
(
output_quant_params
[
0
]);
tensor_output
->
quantParams
.
emplace_back
(
std
::
move
(
output_quant_param
));
tensor_output
->
quantParams
.
emplace_back
(
std
::
move
(
output_quant_param
));
}
}
tensor_output
->
dataType
=
kNumberTypeInt8
;
// // TensorType
// // TensorType
// valuePtr = primitive->GetAttr(kInputTensorDataType);
// valuePtr = primitive->GetAttr(kInputTensorDataType);
// if (valuePtr != nullptr) {
// if (valuePtr != nullptr) {
...
@@ -210,17 +212,18 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta
...
@@ -210,17 +212,18 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta
paramTensor
->
data
.
resize
(
paramValue
->
tensor_size
());
paramTensor
->
data
.
resize
(
paramValue
->
tensor_size
());
memcpy
(
paramTensor
->
data
.
data
(),
paramValue
->
tensor_addr
(),
paramValue
->
tensor_size
());
memcpy
(
paramTensor
->
data
.
data
(),
paramValue
->
tensor_addr
(),
paramValue
->
tensor_size
());
}
}
// for (auto &ite : paramValue->quant_param()) {
for
(
auto
&
ite
:
paramValue
->
quant_param
())
{
// auto quantPar = std::make_unique<schema::QuantParamT>();
auto
quantPar
=
std
::
make_unique
<
schema
::
QuantParamT
>
();
// quantPar->scale = ite->scale;
quantPar
->
scale
=
ite
->
scale
;
// quantPar->zeroPoint = ite->zeroPoint;
quantPar
->
zeroPoint
=
ite
->
zeroPoint
;
// quantPar->min = ite->min;
quantPar
->
min
=
ite
->
min
;
// quantPar->max = ite->max;
quantPar
->
max
=
ite
->
max
;
// quantPar->narrowRange = ite->narrowRange;
quantPar
->
narrowRange
=
ite
->
narrowRange
;
// quantPar->inited = ite->inited;
quantPar
->
inited
=
ite
->
inited
;
// quantPar->numBits = ite->numBits;
quantPar
->
numBits
=
ite
->
numBits
;
// paramTensor->quantParams.emplace_back(std::move(quantPar));
paramTensor
->
quantParams
.
emplace_back
(
std
::
move
(
quantPar
));
// }
paramTensor
->
dataType
=
paramValue
->
tensor_type
();
}
nodeIdMap
[
paramNode
->
fullname_with_scope
()]
=
meta_graph
->
allTensors
.
size
();
nodeIdMap
[
paramNode
->
fullname_with_scope
()]
=
meta_graph
->
allTensors
.
size
();
fbNode
->
inputIndex
.
emplace_back
(
meta_graph
->
allTensors
.
size
());
fbNode
->
inputIndex
.
emplace_back
(
meta_graph
->
allTensors
.
size
());
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
paramTensor
));
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
paramTensor
));
...
...
mindspore/lite/src/model_impl.cc
浏览文件 @
28af1e50
...
@@ -140,6 +140,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
...
@@ -140,6 +140,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
return
new
lite
::
Flatten
(
const_cast
<
schema
::
Primitive
*>
(
srcPrim
));
return
new
lite
::
Flatten
(
const_cast
<
schema
::
Primitive
*>
(
srcPrim
));
case
schema
::
PrimitiveType_MatMul
:
case
schema
::
PrimitiveType_MatMul
:
return
new
lite
::
MatMul
(
const_cast
<
schema
::
Primitive
*>
(
srcPrim
));
return
new
lite
::
MatMul
(
const_cast
<
schema
::
Primitive
*>
(
srcPrim
));
case
schema
::
PrimitiveType_QuantDTypeCast
:
return
new
lite
::
QuantDTypeCast
(
const_cast
<
schema
::
Primitive
*>
(
srcPrim
));
default:
default:
break
;
break
;
}
}
...
...
mindspore/lite/src/ops/reshape.cc
浏览文件 @
28af1e50
...
@@ -57,6 +57,25 @@ int Reshape::CalNewShape(const tensor::Tensor *in_tensor, std::vector<int> *out_
...
@@ -57,6 +57,25 @@ int Reshape::CalNewShape(const tensor::Tensor *in_tensor, std::vector<int> *out_
return
RET_OK
;
return
RET_OK
;
}
}
template
<
typename
T
>
void
CalShape
(
const
T
*
data
,
const
std
::
vector
<
tensor
::
Tensor
*>
&
inputs
,
std
::
vector
<
int
>
*
out_shape
,
int
shape_size
)
{
int
input_count
=
inputs
[
0
]
->
ElementsNum
();
int
index
=
0
;
int
size
=
1
;
for
(
size_t
i
=
0
;
i
<
shape_size
;
i
++
)
{
if
(
data
[
i
]
==
-
1
)
{
index
=
i
;
}
else
{
size
*=
data
[
i
];
}
out_shape
->
push_back
(
data
[
i
]);
}
if
(
data
[
index
]
==
-
1
)
{
(
*
out_shape
)[
index
]
=
input_count
/
size
;
}
}
int
Reshape
::
InferShape
(
std
::
vector
<
tensor
::
Tensor
*>
inputs_
,
std
::
vector
<
tensor
::
Tensor
*>
outputs_
)
{
int
Reshape
::
InferShape
(
std
::
vector
<
tensor
::
Tensor
*>
inputs_
,
std
::
vector
<
tensor
::
Tensor
*>
outputs_
)
{
MS_ASSERT
(
this
->
primitive
!=
nullptr
);
MS_ASSERT
(
this
->
primitive
!=
nullptr
);
auto
input
=
inputs_
.
front
();
auto
input
=
inputs_
.
front
();
...
@@ -69,31 +88,23 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
...
@@ -69,31 +88,23 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
std
::
vector
<
int
>
out_shape
;
std
::
vector
<
int
>
out_shape
;
if
(
inputs_
.
size
()
==
kDoubleNum
)
{
if
(
inputs_
.
size
()
==
kDoubleNum
)
{
auto
shape_tensor
=
inputs_
.
at
(
1
);
auto
shape_tensor
=
inputs_
.
at
(
1
);
size_t
shape_size
=
shape_tensor
->
shape
().
size
();
size_t
shape_size
=
shape_tensor
->
ElementsNum
();
switch
(
shape_tensor
->
data_type
())
{
switch
(
shape_tensor
->
data_type
())
{
case
kNumberTypeInt8
:
{
case
kNumberTypeInt8
:
{
auto
data
=
reinterpret_cast
<
int8_t
*>
(
shape_tensor
->
Data
());
auto
data
=
reinterpret_cast
<
int8_t
*>
(
shape_tensor
->
Data
());
for
(
size_t
i
=
0
;
i
<
shape_size
;
i
++
)
{
CalShape
<
int8_t
>
(
data
,
inputs_
,
&
out_shape
,
shape_size
);
out_shape
.
push_back
(
data
[
i
]);
}
}
break
;
}
break
;
case
kNumberTypeInt32
:
{
case
kNumberTypeInt32
:
{
auto
data
=
reinterpret_cast
<
int32_t
*>
(
shape_tensor
->
Data
());
auto
data
=
reinterpret_cast
<
int32_t
*>
(
shape_tensor
->
Data
());
for
(
size_t
i
=
0
;
i
<
shape_size
;
i
++
)
{
CalShape
<
int32_t
>
(
data
,
inputs_
,
&
out_shape
,
shape_size
);
out_shape
.
push_back
(
data
[
i
]);
}
}
break
;
}
break
;
case
kNumberTypeFloat
:
{
case
kNumberTypeFloat
:
{
auto
data
=
reinterpret_cast
<
float
*>
(
shape_tensor
->
Data
());
auto
data
=
reinterpret_cast
<
float
*>
(
shape_tensor
->
Data
());
for
(
size_t
i
=
0
;
i
<
shape_size
;
i
++
)
{
CalShape
<
float
>
(
data
,
inputs_
,
&
out_shape
,
shape_size
);
out_shape
.
push_back
(
data
[
i
]);
}
}
break
;
}
break
;
case
kNumberTypeUInt32
:
{
case
kNumberTypeUInt32
:
{
auto
data
=
reinterpret_cast
<
uint32_t
*>
(
shape_tensor
->
Data
());
auto
data
=
reinterpret_cast
<
uint32_t
*>
(
shape_tensor
->
Data
());
for
(
size_t
i
=
0
;
i
<
shape_size
;
i
++
)
{
CalShape
<
uint32_t
>
(
data
,
inputs_
,
&
out_shape
,
shape_size
);
out_shape
.
push_back
(
data
[
i
]);
}
}
break
;
}
break
;
default:
{
default:
{
MS_LOG
(
ERROR
)
<<
"Reshape weight tensor has unsupported dataType: "
<<
shape_tensor
->
data_type
();
MS_LOG
(
ERROR
)
<<
"Reshape weight tensor has unsupported dataType: "
<<
shape_tensor
->
data_type
();
...
...
mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc
浏览文件 @
28af1e50
...
@@ -215,7 +215,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
...
@@ -215,7 +215,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
MS_ASSERT
(
node
!=
nullptr
);
MS_ASSERT
(
node
!=
nullptr
);
auto
opType
=
node
->
primitive
->
value
.
type
;
auto
opType
=
node
->
primitive
->
value
.
type
;
if
(
opType
!=
schema
::
PrimitiveType_Conv2D
&&
opType
!=
schema
::
PrimitiveType_DepthwiseConv2D
&&
if
(
opType
!=
schema
::
PrimitiveType_Conv2D
&&
opType
!=
schema
::
PrimitiveType_DepthwiseConv2D
&&
opType
!=
schema
::
PrimitiveType_DeConv2D
)
{
opType
!=
schema
::
PrimitiveType_DeConv2D
&&
opType
!=
schema
::
PrimitiveType_DeDepthwiseConv2D
)
{
return
0
;
return
0
;
}
}
...
@@ -230,7 +230,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
...
@@ -230,7 +230,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
if
(
weightTensor
->
dataType
==
kNumberTypeInt8
)
{
// DataType_DT_UINT8) {
if
(
weightTensor
->
dataType
==
kNumberTypeInt8
)
{
// DataType_DT_UINT8) {
MS_LOG
(
DEBUG
)
<<
"**weight tensor index: %d, format: %d, datatype: "
<<
weightIndex
<<
weightTensor
->
format
MS_LOG
(
DEBUG
)
<<
"**weight tensor index: %d, format: %d, datatype: "
<<
weightIndex
<<
weightTensor
->
format
<<
weightTensor
->
dataType
;
<<
weightTensor
->
dataType
;
status
=
TransFilterFormat
<
u
int8_t
>
(
weightTensor
.
get
(),
kKCHW2HWCK
);
status
=
TransFilterFormat
<
int8_t
>
(
weightTensor
.
get
(),
kKCHW2HWCK
);
}
else
{
}
else
{
MS_LOG
(
DEBUG
)
<<
"--weight tensor index: %d, format: %d, datatype: "
<<
weightIndex
<<
weightTensor
->
format
MS_LOG
(
DEBUG
)
<<
"--weight tensor index: %d, format: %d, datatype: "
<<
weightIndex
<<
weightTensor
->
format
<<
weightTensor
->
dataType
;
<<
weightTensor
->
dataType
;
...
@@ -238,7 +238,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
...
@@ -238,7 +238,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
}
}
}
else
if
(
weightTensor
->
format
==
schema
::
Format_KHWC
)
{
// from onnx
}
else
if
(
weightTensor
->
format
==
schema
::
Format_KHWC
)
{
// from onnx
if
(
weightTensor
->
dataType
==
kNumberTypeInt8
)
{
// DataType_DT_UINT8) {
if
(
weightTensor
->
dataType
==
kNumberTypeInt8
)
{
// DataType_DT_UINT8) {
status
=
TransFilterFormat
<
u
int8_t
>
(
weightTensor
.
get
(),
kKHWC2HWCK
);
status
=
TransFilterFormat
<
int8_t
>
(
weightTensor
.
get
(),
kKHWC2HWCK
);
}
else
{
}
else
{
status
=
TransFilterFormat
<
float
>
(
weightTensor
.
get
(),
kKHWC2HWCK
);
status
=
TransFilterFormat
<
float
>
(
weightTensor
.
get
(),
kKHWC2HWCK
);
}
}
...
...
mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc
浏览文件 @
28af1e50
...
@@ -31,14 +31,23 @@ STATUS TfliteReshapeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfli
...
@@ -31,14 +31,23 @@ STATUS TfliteReshapeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfli
const
auto
&
tfliteAttr
=
tfliteOp
->
builtin_options
.
AsReshapeOptions
();
const
auto
&
tfliteAttr
=
tfliteOp
->
builtin_options
.
AsReshapeOptions
();
if
(
tfliteAttr
==
nullptr
)
{
if
(
tfliteAttr
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"get op: "
<<
op
->
name
.
c_str
()
<<
" attr failed"
;
if
(
tfliteOp
->
inputs
.
size
()
<
2
)
{
return
RET_NULL_PTR
;
MS_LOG
(
ERROR
)
<<
"expected two input tensors, but got: "
<<
tfliteOp
->
inputs
.
size
();
}
return
RET_ERROR
;
}
attr
->
format
=
schema
::
Format_NHWC
;
auto
shape_tensor_index
=
tfliteOp
->
inputs
[
1
];
attr
->
shape
.
resize
(
tfliteAttr
->
new_shape
.
size
());
const
auto
&
shape_tensor
=
tfliteTensors
[
shape_tensor_index
];
for
(
size_t
i
=
0
;
i
<
tfliteAttr
->
new_shape
.
size
();
++
i
)
{
std
::
vector
<
tflite
::
TensorT
*>
shape_tensors
{
shape_tensor
.
get
()};
attr
->
shape
[
i
]
=
tfliteAttr
->
new_shape
[
i
];
if
(
RET_OK
!=
ParseWeight
(
shape_tensors
,
tfliteModelBuffer
,
tensor_cache
,
schema
::
Format_KHWC
))
{
MS_LOG
(
ERROR
)
<<
"parse shape tensor error"
;
return
RET_ERROR
;
}
}
else
{
attr
->
format
=
schema
::
Format_NHWC
;
attr
->
shape
.
resize
(
tfliteAttr
->
new_shape
.
size
());
for
(
size_t
i
=
0
;
i
<
tfliteAttr
->
new_shape
.
size
();
++
i
)
{
attr
->
shape
[
i
]
=
tfliteAttr
->
new_shape
[
i
];
}
}
}
if
(
op
!=
nullptr
)
{
if
(
op
!=
nullptr
)
{
...
...
mindspore/lite/tools/converter/quantizer/post_training.cc
浏览文件 @
28af1e50
...
@@ -230,6 +230,13 @@ struct DivergInfo {
...
@@ -230,6 +230,13 @@ struct DivergInfo {
}
else
{
}
else
{
zero_point
=
static_cast
<
int
>
(
std
::
round
(
zero_point_from_min
));
zero_point
=
static_cast
<
int
>
(
std
::
round
(
zero_point_from_min
));
}
}
MS_LOG
(
DEBUG
)
<<
"zero point:"
<<
zero_point
;
if
(
quant_min
==
0
&&
quant_max
==
255
)
{
zero_point
=
128
;
}
else
if
(
quant_min
==
-
128
&&
quant_max
==
127
)
{
zero_point
=
0
;
}
return
std
::
make_pair
(
this
->
cnode
,
zero_point
);
return
std
::
make_pair
(
this
->
cnode
,
zero_point
);
}
}
};
};
...
@@ -466,11 +473,6 @@ Calibrator::Calibrator(string path, size_t bitNum, int quantMax, int quantMin)
...
@@ -466,11 +473,6 @@ Calibrator::Calibrator(string path, size_t bitNum, int quantMax, int quantMin)
PostTrainingQuantizer
::
PostTrainingQuantizer
(
FuncGraphPtr
graph
,
string
path
,
int
bit_num
,
TypeId
target_type
)
PostTrainingQuantizer
::
PostTrainingQuantizer
(
FuncGraphPtr
graph
,
string
path
,
int
bit_num
,
TypeId
target_type
)
:
Quantizer
(
graph
)
{
:
Quantizer
(
graph
)
{
this
->
bit_num
=
bit_num
;
this
->
bit_num
=
bit_num
;
calibrator_
=
std
::
unique_ptr
<
Calibrator
>
(
new
Calibrator
(
path
,
this
->
bit_num
,
quant_max
,
quant_min
));
if
(
calibrator_
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"creat calibrator failed!"
;
return
;
}
this
->
target_type_
=
target_type
;
this
->
target_type_
=
target_type
;
if
(
target_type
==
kNumberTypeInt8
)
{
if
(
target_type
==
kNumberTypeInt8
)
{
quant_max
=
(
1
<<
(
this
->
bit_num
-
1
))
-
1
;
// 127
quant_max
=
(
1
<<
(
this
->
bit_num
-
1
))
-
1
;
// 127
...
@@ -481,6 +483,11 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, in
...
@@ -481,6 +483,11 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, in
}
else
{
}
else
{
MS_LOG
(
ERROR
)
<<
"unsupported quant value type: "
<<
target_type
;
MS_LOG
(
ERROR
)
<<
"unsupported quant value type: "
<<
target_type
;
}
}
calibrator_
=
std
::
unique_ptr
<
Calibrator
>
(
new
Calibrator
(
path
,
this
->
bit_num
,
quant_max
,
quant_min
));
if
(
calibrator_
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"creat calibrator failed!"
;
return
;
}
}
}
STATUS
PostTrainingQuantizer
::
DoQuantInput
(
double
scale
,
int
zeropoint
,
struct
MaxMin
*
max_min
,
STATUS
PostTrainingQuantizer
::
DoQuantInput
(
double
scale
,
int
zeropoint
,
struct
MaxMin
*
max_min
,
...
@@ -526,7 +533,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr node) {
...
@@ -526,7 +533,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr node) {
}
}
auto
parameter
=
std
::
dynamic_pointer_cast
<
Parameter
>
(
node
);
auto
parameter
=
std
::
dynamic_pointer_cast
<
Parameter
>
(
node
);
ParamValueLitePtr
paramValue
=
std
::
dynamic_pointer_cast
<
ParamValueLite
>
(
parameter
->
default_param
());
ParamValueLitePtr
paramValue
=
std
::
dynamic_pointer_cast
<
ParamValueLite
>
(
parameter
->
default_param
());
auto
status
=
QuantFilter
(
paramValue
,
QuantType_PostTraining
,
bit_num
);
auto
status
=
QuantFilter
(
paramValue
,
QuantType_PostTraining
,
quant_max
,
quant_min
,
bit_num
);
if
(
status
!=
RET_OK
)
{
if
(
status
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"QuantFilter failed: "
<<
status
;
MS_LOG
(
ERROR
)
<<
"QuantFilter failed: "
<<
status
;
return
status
;
return
status
;
...
...
mindspore/lite/tools/converter/quantizer/post_training.h
浏览文件 @
28af1e50
...
@@ -64,7 +64,7 @@ class PostTrainingQuantizer : public Quantizer {
...
@@ -64,7 +64,7 @@ class PostTrainingQuantizer : public Quantizer {
int
quant_min
{
0
};
int
quant_min
{
0
};
private:
private:
TypeId
target_type_
{
kNumberType
U
Int8
};
TypeId
target_type_
{
kNumberTypeInt8
};
std
::
unique_ptr
<
Calibrator
>
calibrator_
;
std
::
unique_ptr
<
Calibrator
>
calibrator_
;
...
...
mindspore/lite/tools/converter/quantizer/quant_cast.cc
浏览文件 @
28af1e50
...
@@ -22,13 +22,16 @@
...
@@ -22,13 +22,16 @@
namespace
mindspore
::
lite
::
quant
{
namespace
mindspore
::
lite
::
quant
{
ValueNodePtr
NewQuantCastValueNode
(
int
src_type
,
int
dst_type
)
{
ValueNodePtr
NewQuantCastValueNode
(
int
src_type
,
int
dst_type
,
const
std
::
vector
<
schema
::
QuantParamT
>
&
quant_params
)
{
std
::
unique_ptr
<
schema
::
PrimitiveT
>
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
std
::
unique_ptr
<
schema
::
PrimitiveT
>
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
schema
::
QuantDTypeCastT
quant_dtype_cast
;
schema
::
QuantDTypeCastT
quant_dtype_cast
;
quant_dtype_cast
.
srcT
=
src_type
;
// kNumberTypeUInt8;
quant_dtype_cast
.
srcT
=
src_type
;
// kNumberTypeUInt8;
quant_dtype_cast
.
dstT
=
dst_type
;
// kNumberTypeFloat32;
quant_dtype_cast
.
dstT
=
dst_type
;
// kNumberTypeFloat32;
primitive
->
value
.
Set
(
quant_dtype_cast
);
primitive
->
value
.
Set
(
quant_dtype_cast
);
auto
primTValue
=
std
::
make_shared
<
PrimitiveTValue
>
(
primitive
.
release
());
auto
primTValue
=
std
::
make_shared
<
PrimitiveTValue
>
(
primitive
.
release
());
for
(
auto
&
quant_param
:
quant_params
)
{
primTValue
->
AddInputQuantParam
(
quant_param
);
}
return
NewValueNode
(
primTValue
);
return
NewValueNode
(
primTValue
);
}
}
...
@@ -48,7 +51,8 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
...
@@ -48,7 +51,8 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
}
}
if
(
first
)
{
if
(
first
)
{
if
(
curnode_quant_type
==
schema
::
QuantType_PostTraining
&&
inputDataDType
==
kNumberTypeFloat32
)
{
if
(
curnode_quant_type
==
schema
::
QuantType_PostTraining
&&
inputDataDType
==
kNumberTypeFloat32
)
{
auto
value_node
=
NewQuantCastValueNode
(
kNumberTypeFloat32
,
kNumberTypeUInt8
);
auto
value_node
=
NewQuantCastValueNode
(
kNumberTypeFloat32
,
kNumberTypeUInt8
,
primitiveT_value
->
GetInputQuantParams
());
std
::
vector
<
AnfNodePtr
>
op_inputs
=
{
value_node
,
cnode
->
input
(
1
)};
std
::
vector
<
AnfNodePtr
>
op_inputs
=
{
value_node
,
cnode
->
input
(
1
)};
auto
quant_cast_cnode
=
graph
->
NewCNode
(
op_inputs
);
auto
quant_cast_cnode
=
graph
->
NewCNode
(
op_inputs
);
quant_cast_cnode
->
set_fullname_with_scope
(
cnode
->
fullname_with_scope
()
+
"_quant_cast"
);
quant_cast_cnode
->
set_fullname_with_scope
(
cnode
->
fullname_with_scope
()
+
"_quant_cast"
);
...
@@ -78,10 +82,12 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
...
@@ -78,10 +82,12 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
ValueNodePtr
value_node
=
nullptr
;
ValueNodePtr
value_node
=
nullptr
;
if
(
curnode_quant_type
==
schema
::
QuantType_PostTraining
&&
if
(
curnode_quant_type
==
schema
::
QuantType_PostTraining
&&
input_cnode_quant_type
==
schema
::
QuantType_QUANT_NONE
)
{
input_cnode_quant_type
==
schema
::
QuantType_QUANT_NONE
)
{
value_node
=
NewQuantCastValueNode
(
kNumberTypeFloat32
,
kNumberTypeUInt8
);
value_node
=
NewQuantCastValueNode
(
kNumberTypeFloat32
,
kNumberTypeUInt8
,
input_cnode_primitiveT_value
->
GetInputQuantParams
());
}
else
if
(
curnode_quant_type
==
schema
::
QuantType_QUANT_NONE
&&
}
else
if
(
curnode_quant_type
==
schema
::
QuantType_QUANT_NONE
&&
input_cnode_quant_type
==
schema
::
QuantType_PostTraining
)
{
input_cnode_quant_type
==
schema
::
QuantType_PostTraining
)
{
value_node
=
NewQuantCastValueNode
(
kNumberTypeUInt8
,
kNumberTypeFloat32
);
value_node
=
NewQuantCastValueNode
(
kNumberTypeUInt8
,
kNumberTypeFloat32
,
input_cnode_primitiveT_value
->
GetInputQuantParams
());
}
}
if
(
value_node
==
nullptr
)
{
if
(
value_node
==
nullptr
)
{
MS_LOG
(
WARNING
)
<<
"value_node is null! "
MS_LOG
(
WARNING
)
<<
"value_node is null! "
...
...
mindspore/lite/tools/converter/quantizer/quantize_util.cc
浏览文件 @
28af1e50
...
@@ -190,7 +190,7 @@ void CalFakeNode(const AnfNodePtr &inTensor) {
...
@@ -190,7 +190,7 @@ void CalFakeNode(const AnfNodePtr &inTensor) {
}
}
STATUS
CalQuantizationParams
(
std
::
unique_ptr
<
AnfQuantParam
>
&
quantParam
,
double
mMin
,
STATUS
CalQuantizationParams
(
std
::
unique_ptr
<
AnfQuantParam
>
&
quantParam
,
double
mMin
,
double
mMax
,
bool
narrowRange
,
int
numB
its
)
{
double
mMax
,
bool
narrowRange
,
int
quant_max
,
int
quant_min
,
int
num_b
its
)
{
MS_ASSERT
(
quantParam
!=
nullptr
);
MS_ASSERT
(
quantParam
!=
nullptr
);
if
(
mMin
>
0.0
f
)
{
if
(
mMin
>
0.0
f
)
{
MS_LOG
(
ERROR
)
<<
"min "
<<
mMin
<<
" is bigger then 0, set to 0, this may course low precision"
;
MS_LOG
(
ERROR
)
<<
"min "
<<
mMin
<<
" is bigger then 0, set to 0, this may course low precision"
;
...
@@ -215,28 +215,17 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double
...
@@ -215,28 +215,17 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double
quantParam
->
scale
=
0.0
f
;
quantParam
->
scale
=
0.0
f
;
quantParam
->
zeroPoint
=
0
;
quantParam
->
zeroPoint
=
0
;
quantParam
->
narrowRange
=
narrowRange
;
quantParam
->
narrowRange
=
narrowRange
;
quantParam
->
numBits
=
num
B
its
;
quantParam
->
numBits
=
num
_b
its
;
return
RET_OK
;
return
RET_OK
;
}
}
int
quantMin
=
narrowRange
?
1
:
0
;
auto
quantMinFloat
=
static_cast
<
double
>
(
quant_min
);
int
quantMax
=
(
1
<<
(
unsigned
int
)
numBits
)
-
1
;
auto
quantMaxFloat
=
static_cast
<
double
>
(
quant_max
);
auto
quantMinFloat
=
static_cast
<
double
>
(
quantMin
);
auto
quantMaxFloat
=
static_cast
<
double
>
(
quantMax
);
double
scale
=
(
mMax
-
mMin
)
/
(
quantMaxFloat
-
quantMinFloat
);
double
scale
=
(
mMax
-
mMin
)
/
(
quantMaxFloat
-
quantMinFloat
);
const
double
zeroPointFromMin
=
quantMinFloat
-
mMin
/
scale
;
const
double
zeroPointFromMin
=
quantMinFloat
-
mMin
/
scale
;
const
double
zeroPointFromMax
=
quantMaxFloat
-
mMax
/
scale
;
// const double zeroPointFromMax = quantMaxFloat - mMax / scale;
const
double
zpFromMinError
=
std
::
abs
(
quantMinFloat
)
+
std
::
abs
(
mMin
/
scale
);
int
zeroPoint
=
static_cast
<
int32_t
>
(
std
::
round
(
zeroPointFromMin
));
const
double
zpFromMaxError
=
std
::
abs
(
quantMaxFloat
)
+
std
::
abs
(
mMax
/
scale
);
const
double
zpDouble
=
zpFromMinError
<
zpFromMaxError
?
zeroPointFromMin
:
zeroPointFromMax
;
int
zeroPoint
;
if
(
zpDouble
<
quantMinFloat
)
{
zeroPoint
=
quantMin
;
}
else
if
(
zpDouble
>
quantMaxFloat
)
{
zeroPoint
=
quantMax
;
}
else
{
zeroPoint
=
static_cast
<
int32_t
>
(
std
::
round
(
zpDouble
));
}
// The zero point should always be in the range of quantized value,
// The zero point should always be in the range of quantized value,
// [qmin, qmax].
// [qmin, qmax].
MS_ASSERT
(
zeroPoint
>=
quantMin
);
MS_ASSERT
(
zeroPoint
>=
quantMin
);
...
@@ -247,12 +236,12 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double
...
@@ -247,12 +236,12 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double
quantParam
->
scale
=
scale
;
quantParam
->
scale
=
scale
;
quantParam
->
zeroPoint
=
zeroPoint
;
quantParam
->
zeroPoint
=
zeroPoint
;
quantParam
->
narrowRange
=
narrowRange
;
quantParam
->
narrowRange
=
narrowRange
;
quantParam
->
numBits
=
num
B
its
;
quantParam
->
numBits
=
num
_b
its
;
return
RET_OK
;
return
RET_OK
;
}
}
STATUS
QuantFilter
(
ParamValueLitePtr
&
weightPtr
,
QuantType
quantType
,
size_t
bitNum
)
{
STATUS
QuantFilter
(
ParamValueLitePtr
&
weightPtr
,
QuantType
quantType
,
int
quant_max
,
int
quant_min
,
size_t
bitNum
)
{
auto
dims
=
weightPtr
->
tensor_shape
();
auto
dims
=
weightPtr
->
tensor_shape
();
if
(
dims
.
size
()
<
1
)
{
if
(
dims
.
size
()
<
1
)
{
MS_LOG
(
ERROR
)
<<
"weight dims size error"
;
MS_LOG
(
ERROR
)
<<
"weight dims size error"
;
...
@@ -284,7 +273,7 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, size_t bit
...
@@ -284,7 +273,7 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, size_t bit
}
}
std
::
unique_ptr
<
AnfQuantParam
>
quantParam
=
std
::
unique_ptr
<
AnfQuantParam
>
(
new
AnfQuantParam
);
std
::
unique_ptr
<
AnfQuantParam
>
quantParam
=
std
::
unique_ptr
<
AnfQuantParam
>
(
new
AnfQuantParam
);
STATUS
status
=
CalQuantizationParams
(
quantParam
,
min
,
max
,
false
,
bitNum
);
STATUS
status
=
CalQuantizationParams
(
quantParam
,
min
,
max
,
false
,
quant_max
,
quant_min
,
bitNum
);
if
(
status
!=
RET_OK
)
{
if
(
status
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"CalQuantizationParams failed"
<<
status
;
MS_LOG
(
ERROR
)
<<
"CalQuantizationParams failed"
<<
status
;
return
status
;
return
status
;
...
@@ -308,8 +297,8 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, size_t bit
...
@@ -308,8 +297,8 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, size_t bit
PostBitPack
(
const_cast
<
float
*>
(
rawDatas
),
shapeSize
,
bitNum
);
PostBitPack
(
const_cast
<
float
*>
(
rawDatas
),
shapeSize
,
bitNum
);
}
}
weightPtr
->
set_tensor_type
(
kNumberType
U
Int8
);
weightPtr
->
set_tensor_type
(
kNumberTypeInt8
);
weightPtr
->
set_tensor_size
(
shapeSize
*
sizeof
(
u
int8_t
));
weightPtr
->
set_tensor_size
(
shapeSize
*
sizeof
(
int8_t
));
return
RET_OK
;
return
RET_OK
;
}
}
...
...
mindspore/lite/tools/converter/quantizer/quantize_util.h
浏览文件 @
28af1e50
...
@@ -60,7 +60,7 @@ class QuantStrategy {
...
@@ -60,7 +60,7 @@ class QuantStrategy {
};
};
STATUS
CalQuantizationParams
(
std
::
unique_ptr
<
AnfQuantParam
>
&
quantParam
,
double
mMin
,
double
mMax
,
STATUS
CalQuantizationParams
(
std
::
unique_ptr
<
AnfQuantParam
>
&
quantParam
,
double
mMin
,
double
mMax
,
bool
narrowRange
=
false
,
int
numBits
=
UINT8_QUANTIZATION
);
bool
narrowRange
,
int
quant_max
,
int
quant_min
,
int
num_bits
);
template
<
typename
T
>
template
<
typename
T
>
T
QuantizeData
(
const
float
originData
,
const
AnfQuantParam
*
quantParam
)
{
T
QuantizeData
(
const
float
originData
,
const
AnfQuantParam
*
quantParam
)
{
...
@@ -96,7 +96,7 @@ T QuantizeData(const float originData, const AnfQuantParam *quantParam) {
...
@@ -96,7 +96,7 @@ T QuantizeData(const float originData, const AnfQuantParam *quantParam) {
void
CalFakeNode
(
const
AnfNodePtr
&
inTensor
);
void
CalFakeNode
(
const
AnfNodePtr
&
inTensor
);
STATUS
QuantFilter
(
ParamValueLitePtr
&
weightPtr
,
QuantType
quantType
=
QuantType_AwareTraining
,
STATUS
QuantFilter
(
ParamValueLitePtr
&
weightPtr
,
QuantType
quantType
,
int
quant_max
,
int
quant_min
,
size_t
bitNum
=
UINT8_QUANTIZATION
);
size_t
bitNum
=
UINT8_QUANTIZATION
);
STATUS
PostBitPack
(
float
*
weights
,
size_t
shapeSize
,
size_t
bitNum
=
UINT8_QUANTIZATION
);
STATUS
PostBitPack
(
float
*
weights
,
size_t
shapeSize
,
size_t
bitNum
=
UINT8_QUANTIZATION
);
...
...
mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
浏览文件 @
28af1e50
...
@@ -81,7 +81,7 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
...
@@ -81,7 +81,7 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
}
}
ParamValueLitePtr
paramValue
=
std
::
static_pointer_cast
<
ParamValueLite
>
(
paramNode
->
default_param
());
ParamValueLitePtr
paramValue
=
std
::
static_pointer_cast
<
ParamValueLite
>
(
paramNode
->
default_param
());
auto
status
=
QuantFilter
(
paramValue
,
QuantType_WeightQuant
,
bitNum
);
auto
status
=
QuantFilter
(
paramValue
,
QuantType_WeightQuant
,
127
,
-
128
,
bitNum
);
if
(
status
!=
RET_OK
)
{
if
(
status
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"QuantFilter failed : "
<<
status
;
MS_LOG
(
ERROR
)
<<
"QuantFilter failed : "
<<
status
;
return
status
;
return
status
;
...
@@ -120,7 +120,7 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
...
@@ -120,7 +120,7 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
MS_LOG
(
ERROR
)
<<
"No valid input param node !"
;
MS_LOG
(
ERROR
)
<<
"No valid input param node !"
;
continue
;
continue
;
}
}
auto
status
=
QuantFilter
(
paramValue
,
QuantType_WeightQuant
,
bitNum
);
auto
status
=
QuantFilter
(
paramValue
,
QuantType_WeightQuant
,
127
,
-
128
,
bitNum
);
if
(
status
!=
RET_OK
)
{
if
(
status
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"QunatFilter failed"
<<
status
;
MS_LOG
(
ERROR
)
<<
"QunatFilter failed"
<<
status
;
return
RET_ERROR
;
return
RET_ERROR
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录