Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
6e759cd4
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6e759cd4
编写于
8月 01, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 01, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3815 compute threshold only once in post training quantization
Merge pull request !3815 from xutianchun/quant_0731
上级
6c4ee3f3
fae78e11
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
35 addition
and
18 deletion
+35
-18
mindspore/lite/src/common/anf_exporter/anf_exporter.cc
mindspore/lite/src/common/anf_exporter/anf_exporter.cc
+8
-11
mindspore/lite/src/ir/primitive_t_value.h
mindspore/lite/src/ir/primitive_t_value.h
+2
-2
mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc
...lite/tools/converter/optimizer/node/weight_format_pass.cc
+3
-2
mindspore/lite/tools/converter/quantizer/post_training.cc
mindspore/lite/tools/converter/quantizer/post_training.cc
+22
-3
未找到文件。
mindspore/lite/src/common/anf_exporter/anf_exporter.cc
浏览文件 @
6e759cd4
...
...
@@ -98,7 +98,6 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
}
node
->
primitive
=
std
::
unique_ptr
<
schema
::
PrimitiveT
>
(
primitiveT_value
->
GetPrimitiveT
());
primitiveT_value
->
SetPrimitiveT
(
nullptr
);
std
::
vector
<
schema
::
TensorT
*>
outputs
;
SetOpInputNode
(
cnode
,
metaGraphT
.
get
(),
node
.
get
());
SetOpOutputNode
(
outputs
,
metaGraphT
.
get
(),
node
.
get
());
...
...
@@ -113,24 +112,22 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
auto
input_quant_params
=
primitiveT_value
->
GetInputQuantParams
();
if
(
input_quant_params
.
empty
())
{
MS_LOG
(
WARNING
)
<<
"node: "
<<
node
->
name
<<
" input quant params is empty"
;
continue
;
}
else
{
std
::
unique_ptr
<
schema
::
QuantParamT
>
input_quant_param
=
std
::
make_unique
<
schema
::
QuantParamT
>
(
input_quant_params
[
0
]);
tensor_input
->
quantParams
.
emplace_back
(
std
::
move
(
input_quant_param
));
}
std
::
unique_ptr
<
schema
::
QuantParamT
>
input_quant_param
=
std
::
make_unique
<
schema
::
QuantParamT
>
(
input_quant_params
[
0
]);
tensor_input
->
quantParams
.
emplace_back
(
std
::
move
(
input_quant_param
));
// output
auto
output_index
=
node
->
outputIndex
[
0
];
auto
tensor_output
=
metaGraphT
->
allTensors
[
output_index
].
get
();
auto
output_quant_params
=
primitiveT_value
->
GetOutputQuantParams
();
if
(
output_quant_params
.
empty
())
{
MS_LOG
(
WARNING
)
<<
"node: "
<<
node
->
name
<<
" output quant params is empty"
;
continue
;
}
else
{
std
::
unique_ptr
<
schema
::
QuantParamT
>
output_quant_param
=
std
::
make_unique
<
schema
::
QuantParamT
>
(
output_quant_params
[
0
]);
tensor_output
->
quantParams
.
emplace_back
(
std
::
move
(
output_quant_param
));
}
std
::
unique_ptr
<
schema
::
QuantParamT
>
output_quant_param
=
std
::
make_unique
<
schema
::
QuantParamT
>
(
output_quant_params
[
0
]);
tensor_output
->
quantParams
.
emplace_back
(
std
::
move
(
output_quant_param
));
// // TensorType
// valuePtr = primitive->GetAttr(kInputTensorDataType);
// if (valuePtr != nullptr) {
...
...
mindspore/lite/src/ir/primitive_t_value.h
浏览文件 @
6e759cd4
...
...
@@ -26,8 +26,8 @@ namespace mindspore::lite {
class
PrimitiveTValue
:
public
Value
{
public:
explicit
PrimitiveTValue
(
schema
::
PrimitiveT
*
primt
)
:
primitive
(
primt
)
{}
~
PrimitiveTValue
()
override
{
delete
this
->
primitive
;
}
// not responsible to free primitive, the one created the dynamic memory is responsible to free it.
~
PrimitiveTValue
()
override
=
default
;
MS_DECLARE_PARENT
(
PrimitiveTValue
,
Value
)
...
...
mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc
浏览文件 @
6e759cd4
...
...
@@ -27,7 +27,7 @@ int WeightFormatPass::Run(GraphNode *graphNode) {
MS_LOG
(
ERROR
)
<<
"ShapeFormatTrans failed: "
<<
status
;
return
status
;
}
if
(
this
->
quantType
==
QuantType_AwareTrainning
)
{
if
(
this
->
quantType
==
QuantType_AwareTrainning
||
this
->
quantType
==
QuantType_PostTraining
)
{
status
=
QuantDataFormatTrans
(
graphNode
);
if
(
status
!=
0
)
{
MS_LOG
(
ERROR
)
<<
"QuantDataFormatTrans failed: "
<<
status
;
...
...
@@ -147,7 +147,8 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
}
else
if
(
fmkType
==
converter
::
FmkType_TFLITE
)
{
switch
(
node
->
quantType
)
{
case
QuantType_QUANT_NONE
:
case
QuantType_AwareTrainning
:
{
case
QuantType_AwareTrainning
:
case
QuantType_PostTraining
:
{
if
(
opType
==
schema
::
PrimitiveType_Conv2D
)
{
weightTensor
->
format
=
schema
::
Format_KHWC
;
}
else
if
(
opType
==
schema
::
PrimitiveType_DepthwiseConv2D
)
{
...
...
mindspore/lite/tools/converter/quantizer/post_training.cc
浏览文件 @
6e759cd4
...
...
@@ -292,13 +292,32 @@ STATUS Calibrator::RecordMaxValue(std::string opName, vector<float> data,
}
STATUS
Calibrator
::
ComputeThreshold
()
{
for
(
auto
iter
=
this
->
input_diverg_info_
.
begin
();
iter
!=
this
->
in
put_diverg_info_
.
end
();
iter
++
)
{
for
(
auto
iter
=
this
->
output_diverg_info_
.
begin
();
iter
!=
this
->
out
put_diverg_info_
.
end
();
iter
++
)
{
DivergInfo
*
info
=
iter
->
second
.
get
();
info
->
ComputeThreshold
();
}
for
(
auto
iter
=
this
->
output_diverg_info_
.
begin
();
iter
!=
this
->
output_diverg_info_
.
end
();
iter
++
)
{
// node A's input may be node B's output, no need to re-compute the node A's input quant param which is the same as
for
(
auto
iter
=
this
->
input_diverg_info_
.
begin
();
iter
!=
this
->
input_diverg_info_
.
end
();
iter
++
)
{
DivergInfo
*
info
=
iter
->
second
.
get
();
info
->
ComputeThreshold
();
auto
cnode
=
info
->
cnode
;
bool
already_computed
=
false
;
auto
input
=
cnode
->
input
(
1
);
if
(
input
->
isa
<
mindspore
::
CNode
>
())
{
auto
input_cnode
=
std
::
dynamic_pointer_cast
<
mindspore
::
CNode
>
(
input
);
for
(
const
auto
&
output_diverg_info
:
output_diverg_info_
)
{
auto
output_diverg_cnode
=
output_diverg_info
.
second
->
cnode
;
if
(
output_diverg_cnode
==
input_cnode
)
{
*
info
=
*
(
output_diverg_info
.
second
);
info
->
cnode
=
cnode
;
already_computed
=
true
;
break
;
}
}
}
if
(
!
already_computed
)
{
info
->
ComputeThreshold
();
}
}
return
RET_OK
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录