magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit da1c32a7
fix bug for converter_flags
Authored on Aug 28, 2020 by cjh9368
Parent: 2ef32167

Showing 8 changed files with 33 additions and 16 deletions (+33 -16)
Changed files:

mindspore/lite/tools/converter/converter_flags.cc (+18 -8)
mindspore/lite/tools/converter/converter_flags.h (+2 -2)
mindspore/lite/tools/converter/graphdef_transform.cc (+1 -0)
mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc (+4 -1)
mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h (+3 -0)
mindspore/lite/tools/converter/quantizer/aware_quantizer.cc (+1 -1)
mindspore/lite/tools/converter/quantizer/aware_quantizer.h (+2 -2)
mindspore/lite/tools/converter/quantizer/quantize_util.cc (+2 -2)
mindspore/lite/tools/converter/converter_flags.cc

@@ -29,12 +29,12 @@ Flags::Flags() {
   AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
   AddFlag(&Flags::weightFile, "weightFile",
           "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", "");
-  AddFlag(&Flags::inferenceType, "inferenceType",
-          "Real data type saved in output file, reserved param, NOT used for now. FLOAT | FP16 | UINT8", "FLOAT");
-  AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | WeightQuant | PostTraining", "");
-  AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | UINT8", "FLOAT");
+  AddFlag(&Flags::inferenceTypeIn, "inferenceType",
+          "Real data type saved in output file, reserved param, NOT used for now. FLOAT | INT8", "FLOAT");
+  AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | PostTraining", "");
+  AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | INT8", "FLOAT");
   AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128");
-  AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "127");
+  AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5");
   AddFlag(&Flags::quantSize, "quantSize", "Weight quantization size threshold", "0");
   AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", "");
   AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true");
@@ -77,14 +77,24 @@ int Flags::Init(int argc, const char **argv) {
   }
   if (this->inputInferenceTypeIn == "FLOAT") {
     this->inputInferenceType = TypeId::kNumberTypeFloat;
-  } else if (this->inputInferenceTypeIn == "UINT8") {
-    this->inputInferenceType = TypeId::kNumberTypeUInt8;
+  } else if (this->inputInferenceTypeIn == "INT8") {
+    this->inputInferenceType = TypeId::kNumberTypeInt8;
   } else {
-    std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s", this->inputInferenceTypeIn.c_str();
+    std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s, supported inputInferenceType: FLOAT | INT8",
+      this->inputInferenceTypeIn.c_str();
     return 1;
   }
+  if (this->inferenceTypeIn == "FLOAT") {
+    this->inferenceType = TypeId::kNumberTypeFloat;
+  } else if (this->inferenceTypeIn == "INT8") {
+    this->inferenceType = TypeId::kNumberTypeInt8;
+  } else {
+    std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8",
+      this->inferenceTypeIn.c_str();
+    return 1;
+  }
   if (this->fmkIn == "CAFFE") {
     this->fmk = FmkType_CAFFE;
   } else if (this->fmkIn == "MS") {
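A side note on the error paths above (an observation about the code as committed, not part of this change): the pattern 'std::cerr << "... %s", x.c_str();' chains the comma operator, so the string literal, including the literal characters "%s", is printed, while the c_str() result is evaluated and discarded; operator<< does not interpret printf-style format specifiers. A minimal sketch of stream-idiomatic output that actually interpolates the offending value (ReportInvalidFlag is a hypothetical helper, not part of the converter):

#include <iostream>
#include <string>

// Hypothetical helper: report an invalid flag value by streaming each piece,
// instead of embedding a printf-style "%s" that operator<< prints verbatim.
void ReportInvalidFlag(const std::string &flagName, const std::string &value) {
  std::cerr << "INPUT INVALID: " << flagName << " is invalid: " << value
            << ", supported " << flagName << ": FLOAT | INT8" << std::endl;
}

int main() {
  ReportInvalidFlag("inputInferenceType", "UINT8");  // prints the actual value
  return 0;
}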
mindspore/lite/tools/converter/converter_flags.h

@@ -63,10 +63,10 @@ class Flags : public virtual mindspore::lite::FlagParser {
   // used for quantization
   std::string quantTypeIn;
   QuantType quantType;
-  std::string inferenceType;
+  std::string inferenceTypeIn;
+  TypeId inferenceType = TypeId::kNumberTypeFloat;
   // used for parse aware trainning
   std::string inputInferenceTypeIn;
-  // mindspore::predict::DataType inputInferenceType = DataType_DT_FLOAT;
   TypeId inputInferenceType = TypeId::kNumberTypeFloat;
   std::string stdDev;
   std::string mean;
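The header change follows the parser's existing naming convention: each option that needs parsing keeps a raw string member with an "In" suffix (here inferenceTypeIn), and Init() converts it into the typed member (TypeId inferenceType), exactly as inputInferenceTypeIn / inputInferenceType already did. A self-contained sketch of that two-member pattern, with TypeId reduced to the two values this commit accepts:

#include <iostream>
#include <string>

// Stand-in for mindspore's TypeId, reduced to the values used by this commit.
enum class TypeId { kNumberTypeFloat, kNumberTypeInt8 };

struct Flags {
  std::string inferenceTypeIn = "FLOAT";            // raw text from the command line
  TypeId inferenceType = TypeId::kNumberTypeFloat;  // typed value produced by Init()

  // Mirrors the validation added to Flags::Init: accept FLOAT | INT8 only.
  int ParseInferenceType() {
    if (inferenceTypeIn == "FLOAT") {
      inferenceType = TypeId::kNumberTypeFloat;
    } else if (inferenceTypeIn == "INT8") {
      inferenceType = TypeId::kNumberTypeInt8;
    } else {
      std::cerr << "unsupported inferenceType: " << inferenceTypeIn << std::endl;
      return 1;
    }
    return 0;
  }
};

int main() {
  Flags flags;
  flags.inferenceTypeIn = "INT8";
  return flags.ParseInferenceType();  // 0 on success
}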
mindspore/lite/tools/converter/graphdef_transform.cc

@@ -194,6 +194,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
       return RET_ERROR;
     }
     dTypeTransPass->SetInputDataDType(ctx.inputInferenceType);
+    dTypeTransPass->SetOutputDataDType(ctx.inferenceType);
     quantNodeOptimizer.AddPass(dTypeTransPass);
     quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
     quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
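The single added line propagates the newly parsed output type into the same pass that already received the input type, so one DTypeTransPass instance now carries both ends of the conversion. Note that the surrounding passes are allocated with new (std::nothrow), which yields nullptr on allocation failure instead of throwing std::bad_alloc. A minimal sketch of the configure-then-register pattern used here, with the pass and optimizer types stubbed:

#include <memory>
#include <new>
#include <vector>

enum class TypeId { kNumberTypeFloat, kNumberTypeInt8 };

// Stubs standing in for the converter's pass and optimizer types.
struct GraphPass { virtual ~GraphPass() = default; };
struct DTypeTransPass : GraphPass {
  TypeId inputDataDType = TypeId::kNumberTypeFloat;
  TypeId outputDataDType = TypeId::kNumberTypeFloat;
  void SetInputDataDType(TypeId t) { inputDataDType = t; }
  void SetOutputDataDType(TypeId t) { outputDataDType = t; }
};
struct Optimizer {
  std::vector<std::unique_ptr<GraphPass>> passes;
  void AddPass(GraphPass *p) { passes.emplace_back(p); }  // takes ownership
};

int main() {
  auto *dTypeTransPass = new (std::nothrow) DTypeTransPass();  // nullptr on failure, no throw
  if (dTypeTransPass == nullptr) return 1;
  dTypeTransPass->SetInputDataDType(TypeId::kNumberTypeInt8);    // from ctx.inputInferenceType
  dTypeTransPass->SetOutputDataDType(TypeId::kNumberTypeFloat);  // from ctx.inferenceType (new)
  Optimizer quantNodeOptimizer;
  quantNodeOptimizer.AddPass(dTypeTransPass);
  return 0;
}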
mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc

@@ -101,7 +101,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
 STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
-  if (inputDataDType == TypeId::kNumberTypeInt8) {
+  if (outputDataDType == TypeId::kNumberTypeInt8) {
     return RET_OK;
   }
   MS_ASSERT(inputDataDType == TypeId::kNumberTypeFloat);

@@ -231,5 +231,8 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIte
 }

 void DTypeTransPass::SetInputDataDType(TypeId dataType) { this->inputDataDType = dataType; }

+void DTypeTransPass::SetOutputDataDType(TypeId dataType) { this->outputDataDType = dataType; }
+
 }  // namespace lite
 }  // namespace mindspore
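This hunk is the heart of the fix: the early return that decides whether model outputs need a dtype transform was keyed on inputDataDType, so the output-side behavior silently followed the input setting. With the new setter, the two ends are configured independently. A reduced sketch of the corrected gating, under the reading that the early return means "outputs are already in the requested type" (TypeId and the status code stubbed):

enum class TypeId { kNumberTypeFloat, kNumberTypeInt8 };
constexpr int RET_OK = 0;

struct DTypeTransPass {
  TypeId inputDataDType = TypeId::kNumberTypeFloat;
  TypeId outputDataDType = TypeId::kNumberTypeFloat;

  // Before the fix this tested inputDataDType, so an int8-input model with a
  // float output request would skip the output-side transform entirely.
  int DoModelOutputDTypeTrans() {
    if (outputDataDType == TypeId::kNumberTypeInt8) {
      return RET_OK;  // outputs stay int8: nothing to insert
    }
    // ... otherwise insert int8 -> float cast nodes at each model output ...
    return RET_OK;
  }
};

int main() {
  DTypeTransPass pass;
  pass.outputDataDType = TypeId::kNumberTypeInt8;
  return pass.DoModelOutputDTypeTrans();
}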
mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h

@@ -38,6 +38,8 @@ class DTypeTransPass : public GraphPass {
   void SetInputDataDType(TypeId dataType);

+  void SetOutputDataDType(TypeId dataType);
+
  private:
   STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph);

@@ -51,6 +53,7 @@ class DTypeTransPass : public GraphPass {
  private:
   size_t id;
   TypeId inputDataDType = TypeId::kNumberTypeFloat;
+  TypeId outputDataDType = TypeId::kNumberTypeFloat;
   OpDefCopyer castOpCopyer = [](schema::CNodeT *inCNode) -> std::unique_ptr<schema::CNodeT> {
     std::unique_ptr<schema::CNodeT> newCNode(new (std::nothrow) schema::CNodeT);
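The new outputDataDType member defaults to kNumberTypeFloat, matching inputDataDType, so existing callers that never call SetOutputDataDType keep the old float behavior. The castOpCopyer visible in the context lines is a copy functor returning a unique_ptr; a standalone sketch of that idiom, under the assumption that OpDefCopyer is a std::function-like callable (schema::CNodeT stubbed):

#include <functional>
#include <memory>
#include <new>
#include <string>

// Stub standing in for schema::CNodeT.
struct CNodeT { std::string name; };

// Assumption: OpDefCopyer is a callable alias of roughly this shape.
using OpDefCopyer = std::function<std::unique_ptr<CNodeT>(CNodeT *)>;

int main() {
  OpDefCopyer castOpCopyer = [](CNodeT *inCNode) -> std::unique_ptr<CNodeT> {
    std::unique_ptr<CNodeT> newCNode(new (std::nothrow) CNodeT);
    if (newCNode == nullptr) {
      return nullptr;  // allocation failed; caller must check
    }
    newCNode->name = inCNode->name;  // copy the fields the pass needs
    return newCNode;
  };

  CNodeT original{"cast"};
  auto copy = castOpCopyer(&original);
  return copy != nullptr ? 0 : 1;
}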
mindspore/lite/tools/converter/quantizer/aware_quantizer.cc

@@ -88,7 +88,7 @@ AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph,
   if (inputInferType == "FLOAT") {
     inArr.reset(new (std::nothrow) InputArray(mean, stdValue));
   } else {
-    inArr.reset(new (std::nothrow) InputArray(mean, stdValue, TypeId::kNumberTypeUInt8));
+    inArr.reset(new (std::nothrow) InputArray(mean, stdValue, TypeId::kNumberTypeInt8));
   }
   mInputArray = inArr.get();
   mInputArray->InitQuantParam();
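Because inputInferenceType is now restricted to FLOAT | INT8, the non-FLOAT branch builds the InputArray as signed int8, consistent with the new qmin/qmax limits in aware_quantizer.h below. A reduced sketch of the selection, assuming a constructor with a defaulted dataType as in the header (InputArray stubbed):

#include <memory>
#include <new>
#include <string>

enum class TypeId { kNumberTypeFloat, kNumberTypeInt8 };

// Stub mirroring the shape of the real InputArray constructor.
struct InputArray {
  TypeId dataType;
  InputArray(float mean, float stdDev, TypeId t = TypeId::kNumberTypeFloat) : dataType(t) {
    (void)mean;
    (void)stdDev;
  }
};

std::unique_ptr<InputArray> MakeInputArray(const std::string &inputInferType,
                                           float mean, float stdValue) {
  std::unique_ptr<InputArray> inArr;
  if (inputInferType == "FLOAT") {
    inArr.reset(new (std::nothrow) InputArray(mean, stdValue));  // defaulted float type
  } else {
    inArr.reset(new (std::nothrow) InputArray(mean, stdValue, TypeId::kNumberTypeInt8));
  }
  return inArr;
}

int main() {
  auto arr = MakeInputArray("INT8", -0.5f, 128.0f);
  return (arr != nullptr && arr->dataType == TypeId::kNumberTypeInt8) ? 0 : 1;
}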
mindspore/lite/tools/converter/quantizer/aware_quantizer.h

@@ -37,8 +37,8 @@ struct InputArray {
   InputArray(float mean, float stdDev, TypeId dataType = TypeId::kNumberTypeFloat) {
     this->dataType = dataType;
-    constexpr float qmin = 0;
-    constexpr float qmax = 255;
+    constexpr float qmin = -128;
+    constexpr float qmax = 127;
     mMin = (qmin - mean) / stdDev;
     mMax = (qmax - mean) / stdDev;
   }
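Plugging in the new command-line defaults from converter_flags.cc (mean = -0.5, stdDev = 128), the signed limits keep the represented float range at roughly [-1, 1], essentially the same span the old unsigned pair (qmin = 0, qmax = 255 with the old mean of 127) covered; a quick check:

#include <cstdio>

int main() {
  const float stdDev = 128.0f;
  // New signed defaults: qmin/qmax from this diff, mean from converter_flags.cc.
  const float mean = -0.5f, qmin = -128.0f, qmax = 127.0f;
  std::printf("new: mMin = %f, mMax = %f\n", (qmin - mean) / stdDev, (qmax - mean) / stdDev);
  // prints: new: mMin = -0.996094, mMax = 0.996094

  // Old unsigned defaults for comparison: qmin = 0, qmax = 255, mean = 127.
  std::printf("old: mMin = %f, mMax = %f\n", (0.0f - 127.0f) / stdDev, (255.0f - 127.0f) / stdDev);
  // prints: old: mMin = -0.992188, mMax = 1.000000
  return 0;
}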
mindspore/lite/tools/converter/quantizer/quantize_util.cc

@@ -246,8 +246,8 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
     return RET_OK;
   }

-  int quantMin = narrowRange ? 1 : 0;
-  int quantMax = (1 << (unsigned int)numBits) - 1;
+  int quantMin = narrowRange ? 1 : 0 - 128;
+  int quantMax = (1 << (unsigned int)numBits) - 1 - 128;
   auto quantMinFloat = static_cast<double>(quantMin);
   auto quantMaxFloat = static_cast<double>(quantMax);
   double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
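One caveat worth flagging in the committed expression (an observation, not part of the diff): the conditional operator binds looser than subtraction, so "narrowRange ? 1 : 0 - 128" parses as "narrowRange ? 1 : (0 - 128)". quantMin is therefore 1, not -127, whenever narrowRange is true; shifting both branches would require parentheses. A quick demonstration:

#include <cstdio>

int main() {
  for (int narrowRange = 0; narrowRange <= 1; ++narrowRange) {
    int asCommitted = narrowRange ? 1 : 0 - 128;      // == narrowRange ? 1 : (0 - 128)
    int parenthesized = (narrowRange ? 1 : 0) - 128;  // shifts both branches by -128
    std::printf("narrowRange=%d  as-committed=%d  parenthesized=%d\n",
                narrowRange, asCommitted, parenthesized);
  }
  return 0;
}
// narrowRange=0  as-committed=-128  parenthesized=-128
// narrowRange=1  as-committed=1     parenthesized=-127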