Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
a095f72e
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a095f72e
编写于
8月 24, 2020
作者:
Y
yeyunpeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
delete GetPrimitiveT in project
上级
e3899c55
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
19 addition
and
39 deletion
+19
-39
mindspore/lite/tools/converter/converter.cc
mindspore/lite/tools/converter/converter.cc
+1
-3
mindspore/lite/tools/converter/quantizer/quantize_util.cc
mindspore/lite/tools/converter/quantizer/quantize_util.cc
+1
-1
mindspore/lite/tools/optimizer/common/gllo_utils.cc
mindspore/lite/tools/optimizer/common/gllo_utils.cc
+2
-2
mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
...re/lite/tools/optimizer/fusion/constant_folding_fusion.cc
+3
-29
mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc
mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc
+12
-4
未找到文件。
mindspore/lite/tools/converter/converter.cc
浏览文件 @
a095f72e
...
...
@@ -73,13 +73,11 @@ void Converter::FreeFuncGraph(const FuncGraphPtr &func_graph) {
return
;
}
if
(
primT
->
value
.
type
==
schema
::
PrimitiveType_TupleGetItem
||
primT
->
value
.
type
==
schema
::
PrimitiveType_MakeTuple
||
primT
->
value
.
type
==
schema
::
PrimitiveType_Return
)
{
primT
->
value
.
type
==
schema
::
PrimitiveType_MakeTuple
||
primT
->
value
.
type
==
schema
::
PrimitiveType_Return
)
{
delete
primT
;
primitiveT_value
->
SetPrimitiveT
(
nullptr
);
}
}
return
;
}
MetaGraphT
*
Converter
::
Convert
(
const
converter
::
Flags
*
flag
)
{
// parse the model and weight file to generate inference data structure
...
...
mindspore/lite/tools/converter/quantizer/quantize_util.cc
浏览文件 @
a095f72e
...
...
@@ -93,7 +93,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
return
false
;
}
auto
type
=
primitiveT_value
->
GetPrimitiveT
()
->
value
.
type
;
auto
type
=
(
schema
::
PrimitiveType
)
primitiveT_value
->
Type
()
;
MS_LOG
(
INFO
)
<<
"Primitive type: "
<<
type
;
static
const
std
::
vector
<
schema
::
PrimitiveType
>
uint8OpList
=
{
schema
::
PrimitiveType_Nchw2Nhwc
,
schema
::
PrimitiveType_Nhwc2Nchw
,
...
...
mindspore/lite/tools/optimizer/common/gllo_utils.cc
浏览文件 @
a095f72e
...
...
@@ -170,7 +170,7 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) {
if
(
a
.
m_ptr
->
isa
<
lite
::
PrimitiveC
>
()
&&
b
.
m_ptr
->
isa
<
lite
::
PrimitiveC
>
())
{
auto
a_value_node_ptr
=
a
.
m_ptr
->
cast
<
PrimitiveCPtr
>
();
auto
b_value_node_ptr
=
b
.
m_ptr
->
cast
<
PrimitiveCPtr
>
();
return
a_value_node_ptr
->
GetPrimitiveT
()
->
value
.
type
==
b_value_node_ptr
->
GetPrimitiveT
()
->
value
.
type
;
return
a_value_node_ptr
->
Type
()
==
b_value_node_ptr
->
Type
()
;
}
return
a
==
b
;
...
...
@@ -316,7 +316,7 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) {
if
(
utils
::
isa
<
PrimitiveCPtr
>
(
value
))
{
auto
primitive
=
value
->
cast
<
PrimitiveCPtr
>
();
MS_ASSERT
(
primitive
!=
nullptr
);
return
primitive
->
GetPrimitiveT
()
->
value
.
type
;
return
(
schema
::
PrimitiveType
)
primitive
->
Type
()
;
}
else
if
(
utils
::
isa
<
Primitive
>
(
value
))
{
auto
primitive
=
value
->
cast
<
PrimitivePtr
>
();
MS_ASSERT
(
primitive
!=
nullptr
);
...
...
mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
浏览文件 @
a095f72e
...
...
@@ -73,26 +73,6 @@ const std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
}
return
input_tensors
;
}
schema
::
Primitive
*
PackPrimitiveT
(
const
CNodePtr
&
cnode
)
{
auto
primitiveT_value
=
GetValueNode
<
std
::
shared_ptr
<
PrimitiveC
>>
(
cnode
->
input
(
0
));
if
(
primitiveT_value
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"PrimitiveT_value is nullptr"
;
return
nullptr
;
}
auto
*
lite_primitive
=
primitiveT_value
->
GetPrimitiveT
();
if
(
lite_primitive
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Primitive in primitiveT_value is nullptr"
;
return
nullptr
;
}
flatbuffers
::
FlatBufferBuilder
builder
(
1024
);
auto
offset
=
schema
::
Primitive
::
Pack
(
builder
,
lite_primitive
);
builder
.
Finish
(
offset
);
auto
buf
=
builder
.
GetBufferPointer
();
auto
primitive
=
flatbuffers
::
GetRoot
<
schema
::
Primitive
>
(
buf
);
return
const_cast
<
schema
::
Primitive
*>
(
primitive
);
}
const
ParameterPtr
CreateNewParamter
(
const
FuncGraphPtr
&
func_graph
,
Tensor
*
tensor
)
{
auto
parameter
=
func_graph
->
add_parameter
();
std
::
vector
<
int
>
shape
;
...
...
@@ -175,16 +155,10 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
}
MS_LOG
(
INFO
)
<<
"Begin fold node:"
<<
input_node
->
fullname_with_scope
();
auto
output_nums
=
GetOutputTensorNum
(
input_cnode
);
auto
primitiveT_value
=
GetValueNode
<
std
::
shared_ptr
<
PrimitiveC
>>
(
input_cnode
->
input
(
0
));
std
::
vector
<
Tensor
*>
output_tensors
{
output_nums
,
new
Tensor
()};
auto
scheam_primitive
=
PackPrimitiveT
(
input_cnode
);
auto
lite_primitive
=
mindspore
::
lite
::
PrimitiveC
::
UnPackFromSchemaPrimitive
(
scheam_primitive
);
if
(
lite_primitive
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"constant_folding schedule node lite primitive nullptr"
;
FreeInputTensor
(
&
input_tensors
);
return
nullptr
;
}
lite_primitive
->
InferShape
(
input_tensors
,
output_tensors
);
auto
lite_kernel
=
GetLiteKernel
(
input_tensors
,
output_tensors
,
lite_primitive
);
primitiveT_value
->
InferShape
(
input_tensors
,
output_tensors
);
auto
lite_kernel
=
GetLiteKernel
(
input_tensors
,
output_tensors
,
primitiveT_value
.
get
());
if
(
lite_kernel
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"constant_folding schedule node lite kernel nullptr"
;
FreeInputTensor
(
&
input_tensors
);
...
...
mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc
浏览文件 @
a095f72e
...
...
@@ -22,6 +22,8 @@
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"
#include "src/ops/batch_norm.h"
#include "src/ops/fused_batchnorm.h"
namespace
mindspore
::
opt
{
namespace
{
...
...
@@ -94,7 +96,7 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const {
auto
bn_mean_var
=
std
::
make_shared
<
CondVar
>
(
IsParamNode
);
auto
bn_variable_var
=
std
::
make_shared
<
CondVar
>
(
IsParamNode
);
auto
bn_other_var
=
std
::
make_shared
<
SeqVar
>
();
return
VectorRef
({
bn_var
,
conv_var
,
bn_mean_var
,
bn_variable_var
,
bn_other_var
});
;
return
VectorRef
({
bn_var
,
conv_var
,
bn_mean_var
,
bn_variable_var
,
bn_other_var
});
}
// BatchNorm weight Tensor definition:
// caffe
...
...
@@ -106,7 +108,7 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const {
// estimated_mean --2
// estimated_variance --3
const
void
ConvBatchNormFusion
::
InitTransParam
(
const
CNodePtr
&
bn_node
,
int
kernel_num
,
float
*
trans_scale
,
float
*
trans_bias
)
const
{
float
*
trans_bias
)
const
{
MS_ASSERT
(
bn_node
!=
nullptr
);
AnfNodePtr
bn_mean_node
=
nullptr
;
AnfNodePtr
bn_variance_node
=
nullptr
;
...
...
@@ -119,13 +121,19 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern
bn_variance_node
=
bn_node
->
input
(
kCaffeBNVarIndex
);
CheckIfNodeIsParam
(
bn_mean_node
);
CheckIfNodeIsParam
(
bn_variance_node
);
eps
=
primitiveT_value
->
GetPrimitiveT
()
->
value
.
AsBatchNorm
()
->
epsilon
;
MS_ASSERT
(
utils
::
isa
<
std
::
shared_ptr
<
mindspore
::
lite
::
BatchNorm
>>
(
primitiveT_value
));
auto
primc
=
utils
::
cast
<
std
::
shared_ptr
<
mindspore
::
lite
::
BatchNorm
>>
(
primitiveT_value
);
MS_ASSERT
(
primc
!=
nullptr
);
eps
=
primc
->
GetEpsilon
();
}
else
if
(
GetCNodeType
(
bn_node
)
==
schema
::
PrimitiveType_FusedBatchNorm
)
{
bn_scale_node
=
bn_node
->
input
(
kTFBNScaleIndex
);
bn_bias_node
=
bn_node
->
input
(
kTFBNBiasIndex
);
bn_mean_node
=
bn_node
->
input
(
kTFBNMeanIndex
);
bn_variance_node
=
bn_node
->
input
(
kTFBNVarIndex
);
eps
=
primitiveT_value
->
GetPrimitiveT
()
->
value
.
AsFusedBatchNorm
()
->
epsilon
;
MS_ASSERT
(
utils
::
isa
<
std
::
shared_ptr
<
mindspore
::
lite
::
FusedBatchNorm
>>
(
primitiveT_value
));
auto
primc
=
utils
::
cast
<
std
::
shared_ptr
<
mindspore
::
lite
::
FusedBatchNorm
>>
(
primitiveT_value
);
MS_ASSERT
(
primc
!=
nullptr
);
eps
=
primc
->
GetEpsilon
();
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"not caffe or tf batchnorm op."
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录