Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
406ce735
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
406ce735
编写于
8月 18, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 18, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4624 adjust MS model quant param
Merge pull request !4624 from yankai10/merge
上级
95afdb32
921e2cdb
变更
12
展开全部
隐藏空白更改
内联
并排
Showing
12 changed file
with
220 addition
and
111 deletion
+220
-111
mindspore/lite/src/ir/primitive_t_value.h
mindspore/lite/src/ir/primitive_t_value.h
+9
-1
mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.cc
mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.cc
+8
-2
mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.cc
mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.cc
+8
-2
mindspore/lite/src/runtime/kernel/arm/nnacl/pack.c
mindspore/lite/src/runtime/kernel/arm/nnacl/pack.c
+13
-0
mindspore/lite/src/runtime/kernel/arm/nnacl/pack.h
mindspore/lite/src/runtime/kernel/arm/nnacl/pack.h
+2
-0
mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.cc
...te/tools/anf_importer/anf_populater/anf_conv_populater.cc
+17
-10
mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h
...ite/tools/anf_importer/anf_populater/anf_conv_populater.h
+15
-8
mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc
...f_importer/anf_populater/anf_depthwiseconv2d_populater.cc
+17
-11
mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h
...nf_importer/anf_populater/anf_depthwiseconv2d_populater.h
+6
-2
mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.cc
.../tools/anf_importer/anf_populater/anf_matmul_populater.cc
+13
-10
mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.h
...e/tools/anf_importer/anf_populater/anf_matmul_populater.h
+6
-2
mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
+106
-63
未找到文件。
mindspore/lite/src/ir/primitive_t_value.h
浏览文件 @
406ce735
...
...
@@ -47,7 +47,15 @@ class PrimitiveTValue : public Value {
}
}
void
SetInputQuantParam
(
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
vec_quant_param
)
{}
void
SetInputQuantParam
(
const
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
&
input_quant_param
)
{
this
->
input_quant_param_
=
input_quant_param
;
}
void
SetOutputQuantParam
(
const
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
&
output_quant_param
)
{
this
->
output_quant_param_
=
output_quant_param
;
}
void
AddInputQuantParam
(
std
::
vector
<
schema
::
QuantParamT
>
quant_param
)
{
this
->
input_quant_param_
.
emplace_back
(
quant_param
);
...
...
mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.cc
浏览文件 @
406ce735
...
...
@@ -37,8 +37,13 @@ int Nchw2NhwcCPUKernel::Run() {
auto
output
=
out_tensors_
[
0
];
if
(
input
->
shape
().
size
()
==
4
)
{
PackNCHWToNHWCFp32
(
input
->
Data
(),
output
->
Data
(),
output
->
Batch
(),
output
->
Height
()
*
output
->
Width
(),
output
->
Channel
());
if
(
input
->
data_type
()
==
kNumberTypeFloat32
)
{
PackNCHWToNHWCFp32
(
input
->
Data
(),
output
->
Data
(),
output
->
Batch
(),
output
->
Height
()
*
output
->
Width
(),
output
->
Channel
());
}
else
if
(
input
->
data_type
()
==
kNumberTypeInt8
)
{
PackNCHWToNHWCInt8
(
input
->
Data
(),
output
->
Data
(),
output
->
Batch
(),
output
->
Height
()
*
output
->
Width
(),
output
->
Channel
());
}
}
else
{
memcpy
(
output
->
Data
(),
input
->
Data
(),
input
->
ElementsNum
()
*
sizeof
(
float
));
}
...
...
@@ -67,4 +72,5 @@ kernel::LiteKernel *CpuNchw2NhwcFp32KernelCreator(const std::vector<lite::tensor
}
REG_KERNEL
(
kCPU
,
kNumberTypeFloat32
,
PrimitiveType_Nchw2Nhwc
,
CpuNchw2NhwcFp32KernelCreator
)
REG_KERNEL
(
kCPU
,
kNumberTypeInt8
,
PrimitiveType_Nchw2Nhwc
,
CpuNchw2NhwcFp32KernelCreator
)
}
// namespace mindspore::kernel
mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.cc
浏览文件 @
406ce735
...
...
@@ -37,8 +37,13 @@ int Nhwc2NchwCPUKernel::Run() {
auto
output
=
out_tensors_
[
0
];
if
(
input
->
shape
().
size
()
==
4
)
{
PackNHWCToNCHWFp32
(
input
->
Data
(),
output
->
Data
(),
output
->
Batch
(),
output
->
Height
()
*
output
->
Width
(),
output
->
Channel
());
if
(
input
->
data_type
()
==
kNumberTypeFloat32
)
{
PackNHWCToNCHWFp32
(
input
->
Data
(),
output
->
Data
(),
output
->
Batch
(),
output
->
Height
()
*
output
->
Width
(),
output
->
Channel
());
}
else
if
(
input
->
data_type
()
==
kNumberTypeInt8
)
{
PackNHWCToNCHWInt8
(
input
->
Data
(),
output
->
Data
(),
output
->
Batch
(),
output
->
Height
()
*
output
->
Width
(),
output
->
Channel
());
}
}
else
{
memcpy
(
output
->
Data
(),
input
->
Data
(),
input
->
ElementsNum
()
*
sizeof
(
float
));
}
...
...
@@ -67,4 +72,5 @@ kernel::LiteKernel *CpuNhwc2NchwFp32KernelCreator(const std::vector<lite::tensor
}
REG_KERNEL
(
kCPU
,
kNumberTypeFloat32
,
PrimitiveType_Nhwc2Nchw
,
CpuNhwc2NchwFp32KernelCreator
)
REG_KERNEL
(
kCPU
,
kNumberTypeInt8
,
PrimitiveType_Nhwc2Nchw
,
CpuNhwc2NchwFp32KernelCreator
)
}
// namespace mindspore::kernel
mindspore/lite/src/runtime/kernel/arm/nnacl/pack.c
浏览文件 @
406ce735
...
...
@@ -978,6 +978,19 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int
return
;
}
void
PackNHWCToNCHWInt8
(
const
void
*
src
,
void
*
dst
,
int
batch
,
int
plane
,
int
channel
)
{
for
(
int
n
=
0
;
n
<
batch
;
n
++
)
{
for
(
int
c
=
0
;
c
<
channel
;
c
++
)
{
for
(
int
hw
=
0
;
hw
<
plane
;
hw
++
)
{
int
nhwc_index
=
n
*
channel
*
plane
+
hw
*
channel
+
c
;
int
nchw_index
=
n
*
channel
*
plane
+
c
*
plane
+
hw
;
((
int8_t
*
)
dst
)[
nchw_index
]
=
((
int8_t
*
)
src
)[
nhwc_index
];
}
}
}
return
;
}
void
PackNCHWToNHWCFp32
(
const
void
*
src
,
void
*
dst
,
int
batch
,
int
plane
,
int
channel
)
{
return
PackNHWCToNCHWFp32
(
src
,
dst
,
batch
,
channel
,
plane
);
}
...
...
mindspore/lite/src/runtime/kernel/arm/nnacl/pack.h
浏览文件 @
406ce735
...
...
@@ -60,6 +60,8 @@ void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int c
void
PackNHWCToNCHWFp32
(
const
void
*
src
,
void
*
dst
,
int
batch
,
int
plane
,
int
channel
);
void
PackNHWCToNCHWInt8
(
const
void
*
src
,
void
*
dst
,
int
batch
,
int
plane
,
int
channel
);
void
PackNCHWToNHWCFp32
(
const
void
*
src
,
void
*
dst
,
int
batch
,
int
plane
,
int
channel
);
void
PackNHWC4ToNHWCFp32
(
const
void
*
src
,
void
*
dst
,
int
batch
,
int
plane
,
int
channel
);
...
...
mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.cc
浏览文件 @
406ce735
...
...
@@ -122,8 +122,10 @@ void AnfConvPopulater::CalQuantParam(const double &mean, const double &stdDev, f
*
mMax
=
static_cast
<
float
>
((
qmax
-
mean
)
/
stdDev
);
}
void
AnfConvPopulater
::
PopulaterQuantParam
(
const
PrimitivePtr
&
prim
,
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
*
vecQuantParam
)
{
void
AnfConvPopulater
::
PopulaterQuantParam
(
const
PrimitivePtr
&
prim
,
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
*
vecInputQuantParam
,
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
*
vecOutputQuantParam
)
{
auto
narrow_range
=
prim
->
GetAttr
(
"narrow_range"
);
bool
narrowRangeQuantParam
=
GetValue
<
bool
>
(
narrow_range
);
auto
num_bits
=
prim
->
GetAttr
(
"num_bits"
);
...
...
@@ -154,7 +156,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant
::
CalQuantizationParams
(
&
quantParam
,
quantParam
.
min
,
quantParam
.
max
,
narrowRangeQuantParam
,
numbitsRangeQuantParam
);
quants
.
emplace_back
(
quantParam
);
vecQuantParam
->
emplace_back
(
quants
);
vec
Input
QuantParam
->
emplace_back
(
quants
);
quants
.
clear
();
int
biasQuantSize
=
0
;
...
...
@@ -173,7 +175,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
numbitsRangeQuantParam
);
quants
.
emplace_back
(
quantParam
);
}
vecQuantParam
->
emplace_back
(
quants
);
vec
Input
QuantParam
->
emplace_back
(
quants
);
}
quants
.
clear
();
...
...
@@ -181,10 +183,12 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quantParam
.
min
=
0.0
;
quantParam
.
max
=
0.0
;
quantParam
.
zeroPoint
=
0
;
quantParam
.
scale
=
vecQuantParam
->
at
(
0
).
at
(
0
).
scale
*
vecQuantParam
->
at
(
1
).
at
(
i
).
scale
;
quantParam
.
scale
=
vecInputQuantParam
->
at
(
0
).
at
(
0
).
scale
*
vecInputQuantParam
->
at
(
1
).
at
(
i
).
scale
;
quants
.
emplace_back
(
quantParam
);
}
vecQuantParam
->
emplace_back
(
quants
);
vec
Input
QuantParam
->
emplace_back
(
quants
);
quants
.
clear
();
auto
outputMin
=
prim
->
GetAttr
(
"output_minq"
);
...
...
@@ -199,7 +203,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant
::
CalQuantizationParams
(
&
quantParam
,
quantParam
.
min
,
quantParam
.
max
,
narrowRangeQuantParam
,
numbitsRangeQuantParam
);
quants
.
emplace_back
(
quantParam
);
vecQuantParam
->
emplace_back
(
quants
);
vec
Output
QuantParam
->
emplace_back
(
quants
);
}
}
...
...
@@ -215,10 +219,13 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primit
PopulaterConv2DSingleGroup
(
prim
,
primitive
,
group
);
}
primitiveTValuePtr
->
SetPrimitiveT
(
primitive
.
release
());
if
(
primitiveTValuePtr
->
GetQuantType
()
==
schema
::
QuantType_AwareTraining
)
{
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
vecQuantParam
;
PopulaterQuantParam
(
prim
,
&
vecQuantParam
);
primitiveTValuePtr
->
SetInputQuantParam
(
vecQuantParam
);
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
vecInputQuantParam
;
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
vecOutputQuantParam
;
PopulaterQuantParam
(
prim
,
&
vecInputQuantParam
,
&
vecOutputQuantParam
);
primitiveTValuePtr
->
SetInputQuantParam
(
vecInputQuantParam
);
primitiveTValuePtr
->
SetOutputQuantParam
(
vecOutputQuantParam
);
}
return
0
;
}
...
...
mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h
浏览文件 @
406ce735
...
...
@@ -20,9 +20,10 @@
#ifndef MINDSPORE_ANF_CONV_PARSER_H
#define MINDSPORE_ANF_CONV_PARSER_H
#include "tools/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
#include <memory>
#include <vector>
#include "tools/anf_importer/anf_populater/anf_node_populater.h"
namespace
mindspore
::
lite
{
class
AnfConvPopulater
:
public
AnfNodePopulater
{
public:
...
...
@@ -32,12 +33,18 @@ class AnfConvPopulater : public AnfNodePopulater {
const
std
::
vector
<
AnfNodePtr
>
&
inputs
)
override
;
private:
void
PopulaterConv2DMultiGroup
(
const
PrimitivePtr
&
prim
,
const
std
::
unique_ptr
<
schema
::
PrimitiveT
>
&
primitive
,
const
int
&
group
);
void
PopulaterConv2DSingleGroup
(
const
PrimitivePtr
&
prim
,
const
std
::
unique_ptr
<
schema
::
PrimitiveT
>
&
primitive
,
const
int
&
group
);
void
PopulaterQuantParam
(
const
PrimitivePtr
&
prim
,
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
*
vecQuantParam
);
void
CalQuantParam
(
const
double
&
mean
,
const
double
&
stdDev
,
float
*
mMin
,
float
*
mMax
);
void
PopulaterConv2DMultiGroup
(
const
PrimitivePtr
&
prim
,
const
std
::
unique_ptr
<
schema
::
PrimitiveT
>
&
primitive
,
const
int
&
group
);
void
PopulaterConv2DSingleGroup
(
const
PrimitivePtr
&
prim
,
const
std
::
unique_ptr
<
schema
::
PrimitiveT
>
&
primitive
,
const
int
&
group
);
void
PopulaterQuantParam
(
const
PrimitivePtr
&
prim
,
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
*
vecInputQuantParam
,
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
*
vecOutputQuantParam
);
void
CalQuantParam
(
const
double
&
mean
,
const
double
&
stdDev
,
float
*
mMin
,
float
*
mMax
);
};
}
// namespace mindspore::lite
...
...
mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc
浏览文件 @
406ce735
...
...
@@ -31,8 +31,10 @@ void AnfDepwiseconv2DPopulater::CalQuantParam(const double &mean, const double &
*
mMax
=
static_cast
<
float
>
((
qmax
-
mean
)
/
stdDev
);
}
void
AnfDepwiseconv2DPopulater
::
PopulaterQuantParam
(
const
PrimitivePtr
&
prim
,
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
*
vecQuantParam
)
{
void
AnfDepwiseconv2DPopulater
::
PopulaterQuantParam
(
const
PrimitivePtr
&
prim
,
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
*
vecInputQuantParam
,
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
*
vecOutputQuantParam
)
{
auto
narrow_range
=
prim
->
GetAttr
(
"narrow_range"
);
bool
narrowRangeQuantParam
=
GetValue
<
bool
>
(
narrow_range
);
auto
num_bits
=
prim
->
GetAttr
(
"num_bits"
);
...
...
@@ -63,7 +65,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant
::
CalQuantizationParams
(
&
quantParam
,
quantParam
.
min
,
quantParam
.
max
,
narrowRangeQuantParam
,
numbitsRangeQuantParam
);
quants
.
emplace_back
(
quantParam
);
vecQuantParam
->
emplace_back
(
quants
);
vec
Input
QuantParam
->
emplace_back
(
quants
);
quants
.
clear
();
int
biasQuantSize
=
0
;
...
...
@@ -82,7 +84,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
numbitsRangeQuantParam
);
quants
.
emplace_back
(
quantParam
);
}
vecQuantParam
->
emplace_back
(
quants
);
vec
Input
QuantParam
->
emplace_back
(
quants
);
}
quants
.
clear
();
...
...
@@ -90,10 +92,12 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quantParam
.
min
=
0.0
;
quantParam
.
max
=
0.0
;
quantParam
.
zeroPoint
=
0
;
quantParam
.
scale
=
vecQuantParam
->
at
(
0
).
at
(
0
).
scale
*
vecQuantParam
->
at
(
1
).
at
(
i
).
scale
;
quantParam
.
scale
=
vecInputQuantParam
->
at
(
0
).
at
(
0
).
scale
*
vecInputQuantParam
->
at
(
1
).
at
(
i
).
scale
;
quants
.
emplace_back
(
quantParam
);
}
vecQuantParam
->
emplace_back
(
quants
);
vec
Input
QuantParam
->
emplace_back
(
quants
);
quants
.
clear
();
auto
outputMin
=
prim
->
GetAttr
(
"output_minq"
);
...
...
@@ -108,7 +112,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant
::
CalQuantizationParams
(
&
quantParam
,
quantParam
.
min
,
quantParam
.
max
,
narrowRangeQuantParam
,
numbitsRangeQuantParam
);
quants
.
emplace_back
(
quantParam
);
vecQuantParam
->
emplace_back
(
quants
);
vec
Output
QuantParam
->
emplace_back
(
quants
);
}
}
...
...
@@ -177,10 +181,12 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu
MS_ASSERT
(
primitiveTValuePtr
!=
nullptr
);
primitiveTValuePtr
->
SetPrimitiveT
(
primitive
.
release
());
if
(
primitiveTValuePtr
->
GetQuantType
())
{
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
vecQuantParam
;
PopulaterQuantParam
(
prim
,
&
vecQuantParam
);
primitiveTValuePtr
->
SetInputQuantParam
(
vecQuantParam
);
if
(
primitiveTValuePtr
->
GetQuantType
()
==
schema
::
QuantType_AwareTraining
)
{
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
vecInputQuantParam
;
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
vecOutputQuantParam
;
PopulaterQuantParam
(
prim
,
&
vecInputQuantParam
,
&
vecOutputQuantParam
);
primitiveTValuePtr
->
SetInputQuantParam
(
vecInputQuantParam
);
primitiveTValuePtr
->
SetOutputQuantParam
(
vecOutputQuantParam
);
}
return
0
;
}
...
...
mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h
浏览文件 @
406ce735
...
...
@@ -28,8 +28,12 @@ class AnfDepwiseconv2DPopulater : public AnfNodePopulater {
const
std
::
vector
<
AnfNodePtr
>
&
inputs
)
override
;
private:
void
PopulaterQuantParam
(
const
PrimitivePtr
&
prim
,
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
*
vecQuantParam
);
void
CalQuantParam
(
const
double
&
mean
,
const
double
&
stdDev
,
float
*
mMin
,
float
*
mMax
);
void
PopulaterQuantParam
(
const
PrimitivePtr
&
prim
,
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
*
vecInputQuantParam
,
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
*
vecOutputQuantParam
);
void
CalQuantParam
(
const
double
&
mean
,
const
double
&
stdDev
,
float
*
mMin
,
float
*
mMax
);
};
}
// namespace mindspore::lite
...
...
mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.cc
浏览文件 @
406ce735
...
...
@@ -30,8 +30,10 @@ void AnfMatmulPopulater::CalQuantParam(const double &mean, const double &stdDev,
*
mMax
=
static_cast
<
float
>
((
qmax
-
mean
)
/
stdDev
);
}
void
AnfMatmulPopulater
::
PopulaterQuantParam
(
const
PrimitivePtr
&
prim
,
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
*
vecQuantParam
)
{
void
AnfMatmulPopulater
::
PopulaterQuantParam
(
const
PrimitivePtr
&
prim
,
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
*
vecInputQuantParam
,
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
*
vecOutputQuantParam
)
{
auto
narrow_range
=
prim
->
GetAttr
(
"narrow_range"
);
bool
narrowRangeQuantParam
=
GetValue
<
bool
>
(
narrow_range
);
auto
num_bits
=
prim
->
GetAttr
(
"num_bits"
);
...
...
@@ -62,7 +64,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant
::
CalQuantizationParams
(
&
quantParam
,
quantParam
.
min
,
quantParam
.
max
,
narrowRangeQuantParam
,
numbitsRangeQuantParam
);
quants
.
emplace_back
(
quantParam
);
vecQuantParam
->
emplace_back
(
quants
);
vec
Input
QuantParam
->
emplace_back
(
quants
);
quants
.
clear
();
auto
filterMin
=
prim
->
GetAttr
(
"filter_minq"
);
...
...
@@ -79,7 +81,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
numbitsRangeQuantParam
);
quants
.
emplace_back
(
quantParam
);
}
vecQuantParam
->
emplace_back
(
quants
);
vec
Input
QuantParam
->
emplace_back
(
quants
);
}
quants
.
clear
();
...
...
@@ -95,7 +97,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant
::
CalQuantizationParams
(
&
quantParam
,
quantParam
.
min
,
quantParam
.
max
,
narrowRangeQuantParam
,
numbitsRangeQuantParam
);
quants
.
emplace_back
(
quantParam
);
vecQuantParam
->
emplace_back
(
quants
);
vec
Output
QuantParam
->
emplace_back
(
quants
);
}
}
...
...
@@ -110,12 +112,13 @@ int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *prim
primitive
->
value
.
value
=
attr
.
release
();
MS_ASSERT
(
primitiveTValuePtr
!=
nullptr
);
primitiveTValuePtr
->
SetPrimitiveT
(
primitive
.
release
());
if
(
primitiveTValuePtr
->
GetQuantType
())
{
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
vecQuantParam
;
PopulaterQuantParam
(
prim
,
&
vecQuantParam
);
primitiveTValuePtr
->
SetInputQuantParam
(
vecQuantParam
);
if
(
primitiveTValuePtr
->
GetQuantType
()
==
schema
::
QuantType_AwareTraining
)
{
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
vecInputQuantParam
;
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
vecOutputQuantParam
;
PopulaterQuantParam
(
prim
,
&
vecInputQuantParam
,
&
vecOutputQuantParam
);
primitiveTValuePtr
->
SetInputQuantParam
(
vecInputQuantParam
);
primitiveTValuePtr
->
SetOutputQuantParam
(
vecOutputQuantParam
);
}
return
0
;
}
AnfNodePopulaterRegistrar
anfMatmulPopulater
(
"Matmul"
,
new
AnfMatmulPopulater
());
...
...
mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.h
浏览文件 @
406ce735
...
...
@@ -26,8 +26,12 @@ class AnfMatmulPopulater : public AnfNodePopulater {
const
std
::
vector
<
AnfNodePtr
>
&
inputs
)
override
;
private:
void
PopulaterQuantParam
(
const
PrimitivePtr
&
prim
,
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
*
vecQuantParam
);
void
CalQuantParam
(
const
double
&
mean
,
const
double
&
stdDev
,
float
*
mMin
,
float
*
mMax
);
void
PopulaterQuantParam
(
const
PrimitivePtr
&
prim
,
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
*
vecInputQuantParam
,
std
::
vector
<
std
::
vector
<
schema
::
QuantParamT
>>
*
vecOutputQuantParam
);
void
CalQuantParam
(
const
double
&
mean
,
const
double
&
stdDev
,
float
*
mMin
,
float
*
mMax
);
};
}
// namespace mindspore::lite
...
...
mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
浏览文件 @
406ce735
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录