Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
6d500c86
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6d500c86
编写于
9月 01, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
9月 01, 2020
浏览文件
操作
浏览文件
下载
差异文件
!5480 add infershape and trans op optimize
Merge pull request !5480 from zhengjun10/master
上级
9aa322bd
63343d10
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
289 addition
and
15 deletion
+289
-15
mindspore/lite/tools/common/node_util.cc
mindspore/lite/tools/common/node_util.cc
+6
-0
mindspore/lite/tools/common/node_util.h
mindspore/lite/tools/common/node_util.h
+2
-0
mindspore/lite/tools/converter/graphdef_transform.cc
mindspore/lite/tools/converter/graphdef_transform.cc
+8
-3
mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt
...ite/tools/converter/legacy_optimizer/graph/CMakeLists.txt
+3
-1
mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc
...tools/converter/legacy_optimizer/graph/infershape_pass.cc
+123
-0
mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.h
.../tools/converter/legacy_optimizer/graph/infershape_pass.h
+40
-0
mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc
...verter/legacy_optimizer/graph/trans_format_insert_pass.cc
+13
-6
mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h
...nverter/legacy_optimizer/graph/trans_format_insert_pass.h
+3
-3
mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.cc
...verter/legacy_optimizer/graph/trans_format_remove_pass.cc
+49
-0
mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h
...nverter/legacy_optimizer/graph/trans_format_remove_pass.h
+40
-0
mindspore/lite/tools/converter/optimizer.cc
mindspore/lite/tools/converter/optimizer.cc
+2
-2
未找到文件。
mindspore/lite/tools/common/node_util.cc
浏览文件 @
6d500c86
...
@@ -52,6 +52,12 @@ static const std::vector<schema::PrimitiveType> int8OpList = {
...
@@ -52,6 +52,12 @@ static const std::vector<schema::PrimitiveType> int8OpList = {
schema
::
PrimitiveType_Squeeze
,
schema
::
PrimitiveType_Sub
,
schema
::
PrimitiveType_Squeeze
,
schema
::
PrimitiveType_Sub
,
schema
::
PrimitiveType_TopK
,
schema
::
PrimitiveType_Unsqueeze
};
schema
::
PrimitiveType_TopK
,
schema
::
PrimitiveType_Unsqueeze
};
static
const
std
::
vector
<
schema
::
PrimitiveType
>
needInsertOpList
=
{
schema
::
PrimitiveType_Eltwise
,
schema
::
PrimitiveType_Activation
,
schema
::
PrimitiveType_Concat
,
schema
::
PrimitiveType_Power
};
std
::
vector
<
schema
::
PrimitiveType
>
GetInsertOpList
()
{
return
needInsertOpList
;
}
std
::
vector
<
schema
::
PrimitiveType
>
Getfp32FullOpList
()
{
return
fp32FullOpList
;
}
std
::
vector
<
schema
::
PrimitiveType
>
Getfp32FullOpList
()
{
return
fp32FullOpList
;
}
std
::
vector
<
schema
::
PrimitiveType
>
GetNhwcOpList
()
{
return
nhwcOpList
;
}
std
::
vector
<
schema
::
PrimitiveType
>
GetNhwcOpList
()
{
return
nhwcOpList
;
}
...
...
mindspore/lite/tools/common/node_util.h
浏览文件 @
6d500c86
...
@@ -30,6 +30,8 @@ namespace lite {
...
@@ -30,6 +30,8 @@ namespace lite {
using
STATUS
=
int
;
using
STATUS
=
int
;
STATUS
BroadCastQuantParam
(
schema
::
MetaGraphT
*
graphT
,
const
std
::
unique_ptr
<
schema
::
CNodeT
>
&
node
);
STATUS
BroadCastQuantParam
(
schema
::
MetaGraphT
*
graphT
,
const
std
::
unique_ptr
<
schema
::
CNodeT
>
&
node
);
std
::
vector
<
schema
::
PrimitiveType
>
GetInsertOpList
();
std
::
vector
<
schema
::
PrimitiveType
>
GetNhwcOpList
();
std
::
vector
<
schema
::
PrimitiveType
>
GetNhwcOpList
();
std
::
vector
<
schema
::
PrimitiveType
>
Getfp32FullOpList
();
std
::
vector
<
schema
::
PrimitiveType
>
Getfp32FullOpList
();
...
...
mindspore/lite/tools/converter/graphdef_transform.cc
浏览文件 @
6d500c86
...
@@ -25,11 +25,13 @@
...
@@ -25,11 +25,13 @@
#include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.h"
#include "tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/infershape_pass.h"
#include "tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.h"
#include "tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.h"
#include "tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h"
#include "tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h"
#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
#include "tools/converter/legacy_optimizer/graph/
eltwise_format_trans
_pass.h"
#include "tools/converter/legacy_optimizer/graph/
trans_format_insert
_pass.h"
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
...
@@ -145,11 +147,14 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
...
@@ -145,11 +147,14 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
formatTransPass
->
SetQuantType
(
ctx
.
quantType
);
formatTransPass
->
SetQuantType
(
ctx
.
quantType
);
formatTransPass
->
SetFmk
(
ctx
.
fmk
);
formatTransPass
->
SetFmk
(
ctx
.
fmk
);
formatTransOptimizer
.
AddPass
(
formatTransPass
);
formatTransOptimizer
.
AddPass
(
formatTransPass
);
formatTransOptimizer
.
AddPass
(
new
EltwiseFormatTransPass
());
formatTransOptimizer
.
AddPass
(
new
(
std
::
nothrow
)
TopologicalSortPass
());
formatTransOptimizer
.
AddPass
(
new
(
std
::
nothrow
)
InferShapePass
());
formatTransOptimizer
.
AddPass
(
new
(
std
::
nothrow
)
TransOpRemovePass
());
formatTransOptimizer
.
AddPass
(
new
(
std
::
nothrow
)
TransOpInsertPass
());
formatTransOptimizer
.
AddPass
(
new
(
std
::
nothrow
)
FormatTransFusionPass
());
formatTransOptimizer
.
AddPass
(
new
(
std
::
nothrow
)
FormatTransFusionPass
());
formatTransOptimizer
.
AddPass
(
new
(
std
::
nothrow
)
IsolatedNodeRemovePass
());
formatTransOptimizer
.
AddPass
(
new
(
std
::
nothrow
)
IsolatedNodeRemovePass
());
status
=
formatTransOptimizer
.
Run
(
graphDefT
);
status
=
formatTransOptimizer
.
Run
(
graphDefT
);
if
(
status
!=
RET_OK
&&
status
!=
RET_NO_CHANGE
)
{
if
(
status
!=
RET_OK
&&
status
!=
RET_NO_CHANGE
&&
status
!=
RET_INFER_ERR
)
{
MS_LOG
(
ERROR
)
<<
"Run formatTransOptimizer graphPasses Failed"
;
MS_LOG
(
ERROR
)
<<
"Run formatTransOptimizer graphPasses Failed"
;
return
status
;
return
status
;
}
}
...
...
mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt
浏览文件 @
6d500c86
add_library
(
graph_pass_mid OBJECT
add_library
(
graph_pass_mid OBJECT
${
CMAKE_CURRENT_SOURCE_DIR
}
/format_trans_pass.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/format_trans_pass.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/
eltwise_format_trans
_pass.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/
trans_format_insert
_pass.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/dtype_trans_pass.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/dtype_trans_pass.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/isolated_node_remove_pass.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/isolated_node_remove_pass.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/model_input_format_preprocess_pass.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/model_input_format_preprocess_pass.cc
...
@@ -9,4 +9,6 @@ add_library(graph_pass_mid OBJECT
...
@@ -9,4 +9,6 @@ add_library(graph_pass_mid OBJECT
${
CMAKE_CURRENT_SOURCE_DIR
}
/topological_sort_pass.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/topological_sort_pass.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/unused_node_remove_pass.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/unused_node_remove_pass.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/batchnorm_convert_scale_pass.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/batchnorm_convert_scale_pass.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/trans_format_remove_pass.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/infershape_pass.cc
)
)
mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc
0 → 100644
浏览文件 @
6d500c86
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/legacy_optimizer/graph/infershape_pass.h"
#include <vector>
#include "utils/log_adapter.h"
#include "include/errorcode.h"
#include "src/ir/tensor.h"
#include "src/ops/primitive_c.h"
using
mindspore
::
lite
::
tensor
::
Tensor
;
using
mindspore
::
lite
::
PrimitiveC
;
namespace
mindspore
{
namespace
lite
{
namespace
{
std
::
vector
<
tensor
::
Tensor
*>
ConvertTensorToLiteTensor
(
MetaGraphT
*
graph
,
const
std
::
vector
<
uint32_t
>
&
tensor_indexs
,
const
schema
::
PrimitiveType
node_type
)
{
std
::
vector
<
tensor
::
Tensor
*>
lite_tensors
;
for
(
size_t
i
=
0
;
i
<
tensor_indexs
.
size
();
i
++
)
{
auto
&
tensorT
=
graph
->
allTensors
.
at
(
tensor_indexs
[
i
]);
auto
tensor_shape
=
tensorT
->
dims
;
auto
lite_tensor
=
new
(
std
::
nothrow
)
tensor
::
Tensor
(
TypeId
(
tensorT
->
dataType
),
tensor_shape
,
tensorT
->
format
,
tensorT
->
nodeType
);
if
(
lite_tensor
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"lite tensor is nullptr"
;
return
std
::
vector
<
tensor
::
Tensor
*>
();
}
// reshape op must get tensor data to infershape
if
(
node_type
==
schema
::
PrimitiveType_Reshape
&&
i
==
1
&&
tensorT
->
nodeType
==
NodeType_ValueNode
)
{
auto
lite_tensor_size
=
tensorT
->
data
.
size
()
*
sizeof
(
uint8_t
);
// when tensorT as param input
if
(
lite_tensor_size
==
0
)
{
delete
lite_tensor
;
return
std
::
vector
<
tensor
::
Tensor
*>
();
}
auto
tensor_data
=
new
(
std
::
nothrow
)
char
[
lite_tensor_size
/
sizeof
(
char
)];
if
(
tensor_data
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"tensor_data is nullptr"
;
delete
lite_tensor
;
return
std
::
vector
<
tensor
::
Tensor
*>
();
}
auto
ret
=
memcpy_s
(
tensor_data
,
lite_tensor_size
,
tensorT
->
data
.
data
(),
lite_tensor_size
);
if
(
ret
!=
EOK
)
{
delete
lite_tensor
;
delete
[]
tensor_data
;
MS_LOG
(
ERROR
)
<<
"memcpy error: "
<<
ret
;
return
std
::
vector
<
tensor
::
Tensor
*>
();
}
lite_tensor
->
SetData
(
tensor_data
);
lite_tensors
.
emplace_back
(
lite_tensor
);
continue
;
}
lite_tensors
.
emplace_back
(
lite_tensor
);
}
return
lite_tensors
;
}
}
// namespace
STATUS
InferShapePass
::
Run
(
MetaGraphT
*
graph
)
{
MS_ASSERT
(
graph
!=
nullptr
);
for
(
auto
iter
=
graph
->
nodes
.
begin
();
iter
!=
graph
->
nodes
.
end
();
iter
++
)
{
auto
&
node
=
*
iter
;
auto
input_tensors
=
ConvertTensorToLiteTensor
(
graph
,
node
->
inputIndex
,
node
->
primitive
->
value
.
type
);
if
(
input_tensors
.
empty
()
||
input_tensors
.
size
()
!=
node
->
inputIndex
.
size
())
{
MS_LOG
(
ERROR
)
<<
"convert input lite tensor error"
;
return
RET_INFER_ERR
;
}
auto
output_tensors
=
ConvertTensorToLiteTensor
(
graph
,
node
->
outputIndex
,
node
->
primitive
->
value
.
type
);
if
(
output_tensors
.
empty
()
||
output_tensors
.
size
()
!=
node
->
outputIndex
.
size
())
{
MS_LOG
(
ERROR
)
<<
"convert output lite tensor error"
;
return
RET_INFER_ERR
;
}
std
::
unique_ptr
<
PrimitiveT
>
primitiveT
(
new
(
std
::
nothrow
)
PrimitiveT
(
*
node
->
primitive
));
if
(
primitiveT
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"copy primitiveT error"
;
return
RET_ERROR
;
}
auto
primitiveC
=
std
::
shared_ptr
<
PrimitiveC
>
(
PrimitiveC
::
UnPackFromSchemaPrimitiveT
(
primitiveT
.
release
()));
if
(
primitiveC
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"unpack primitiveT error"
;
return
RET_ERROR
;
}
auto
ret
=
primitiveC
->
InferShape
(
input_tensors
,
output_tensors
);
if
(
ret
==
RET_INFER_INVALID
)
{
MS_LOG
(
INFO
)
<<
"InferShape shouldn't be done before runtime, name: "
<<
node
->
name
<<
", type: "
<<
schema
::
EnumNamePrimitiveType
(
node
->
primitive
->
value
.
type
)
<<
"flag set to false."
;
}
else
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
WARNING
)
<<
"InferShape failed, name: "
<<
node
->
name
<<
", type: "
<<
schema
::
EnumNamePrimitiveType
(
node
->
primitive
->
value
.
type
);
return
RET_INFER_ERR
;
}
// copy output shape to tensorT
for
(
size_t
i
=
0
;
i
<
output_tensors
.
size
();
i
++
)
{
auto
output_dims
=
output_tensors
[
i
]
->
shape
();
auto
&
output_tensor
=
graph
->
allTensors
.
at
(
node
->
outputIndex
[
i
]);
output_tensor
->
dims
.
swap
(
output_dims
);
output_tensor
->
format
=
output_tensors
[
i
]
->
GetFormat
();
output_tensor
->
dataType
=
output_tensors
[
i
]
->
data_type
();
}
for
(
auto
input_tensor
:
input_tensors
)
{
delete
input_tensor
;
}
for
(
auto
output_tensor
:
output_tensors
)
{
delete
output_tensor
;
}
}
return
RET_OK
;
}
}
// namespace lite
}
// namespace mindspore
mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.h
0 → 100644
浏览文件 @
6d500c86
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_PREDICT_INFERSHAPE_PASS_H
#define MINDSPORE_PREDICT_INFERSHAPE_PASS_H
#include <unordered_map>
#include <memory>
#include <string>
#include <utility>
#include "tools/common/graph_util.h"
#include "tools/converter/optimizer.h"
using
mindspore
::
schema
::
TensorT
;
namespace
mindspore
{
namespace
lite
{
class
InferShapePass
:
public
GraphPass
{
public:
InferShapePass
()
=
default
;
~
InferShapePass
()
=
default
;
STATUS
Run
(
MetaGraphT
*
graph
)
override
;
};
}
// namespace lite
}
// namespace mindspore
#endif // MINDSPORE_PREDICT_INFERSHAPE_PASS_H
mindspore/lite/tools/converter/legacy_optimizer/graph/
eltwise_format_trans
_pass.cc
→
mindspore/lite/tools/converter/legacy_optimizer/graph/
trans_format_insert
_pass.cc
浏览文件 @
6d500c86
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
#include <string>
#include <string>
#include <memory>
#include <memory>
#include <utility>
#include <utility>
#include "tools/converter/legacy_optimizer/graph/
eltwise_format_trans
_pass.h"
#include "tools/converter/legacy_optimizer/graph/
trans_format_insert
_pass.h"
#include "tools/common/converter_op_utils.h"
#include "tools/common/converter_op_utils.h"
#include "tools/common/node_util.h"
#include "tools/common/node_util.h"
#include "utils/log_adapter.h"
#include "utils/log_adapter.h"
...
@@ -25,7 +25,7 @@
...
@@ -25,7 +25,7 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
lite
{
namespace
lite
{
bool
EltwiseFormatTrans
Pass
::
CanFusion
(
schema
::
MetaGraphT
*
graph
,
const
std
::
unique_ptr
<
CNodeT
>
&
node
)
{
bool
TransOpInsert
Pass
::
CanFusion
(
schema
::
MetaGraphT
*
graph
,
const
std
::
unique_ptr
<
CNodeT
>
&
node
)
{
auto
input_node_indexes
=
GetInputNodeIdx
(
*
graph
,
*
node
);
auto
input_node_indexes
=
GetInputNodeIdx
(
*
graph
,
*
node
);
pre_type_
=
schema
::
PrimitiveType_NONE
;
pre_type_
=
schema
::
PrimitiveType_NONE
;
size_t
has_trans_count
=
0
;
size_t
has_trans_count
=
0
;
...
@@ -95,7 +95,7 @@ bool EltwiseFormatTransPass::CanFusion(schema::MetaGraphT *graph, const std::uni
...
@@ -95,7 +95,7 @@ bool EltwiseFormatTransPass::CanFusion(schema::MetaGraphT *graph, const std::uni
return
can_fusion
;
return
can_fusion
;
}
}
STATUS
EltwiseFormatTrans
Pass
::
FindOutTransType
()
{
STATUS
TransOpInsert
Pass
::
FindOutTransType
()
{
pre_insert_trans_type_
=
kNHWC2NCHW
;
pre_insert_trans_type_
=
kNHWC2NCHW
;
post_insert_trans_type_
=
kNHWC2NCHW
;
post_insert_trans_type_
=
kNHWC2NCHW
;
if
(
pre_type_
==
PrimitiveType_NONE
&&
post_type_
!=
PrimitiveType_NONE
)
{
if
(
pre_type_
==
PrimitiveType_NONE
&&
post_type_
!=
PrimitiveType_NONE
)
{
...
@@ -117,12 +117,12 @@ STATUS EltwiseFormatTransPass::FindOutTransType() {
...
@@ -117,12 +117,12 @@ STATUS EltwiseFormatTransPass::FindOutTransType() {
return
RET_OK
;
return
RET_OK
;
}
}
STATUS
EltwiseFormatTrans
Pass
::
Run
(
schema
::
MetaGraphT
*
graph
)
{
STATUS
TransOpInsert
Pass
::
Run
(
schema
::
MetaGraphT
*
graph
)
{
MS_ASSERT
(
graph
!=
nullptr
);
MS_ASSERT
(
graph
!=
nullptr
);
for
(
auto
iter
=
graph
->
nodes
.
begin
();
iter
!=
graph
->
nodes
.
end
();
iter
++
)
{
for
(
auto
iter
=
graph
->
nodes
.
begin
();
iter
!=
graph
->
nodes
.
end
();
iter
++
)
{
auto
&
node
=
*
iter
;
auto
&
node
=
*
iter
;
auto
type
=
node
->
primitive
->
value
.
type
;
auto
type
=
node
->
primitive
->
value
.
type
;
if
(
type
!=
PrimitiveType_Eltwise
&&
type
!=
PrimitiveType_Activation
)
{
if
(
!
IsContain
(
GetInsertOpList
(),
type
)
)
{
continue
;
continue
;
}
}
auto
node_name
=
node
->
name
;
auto
node_name
=
node
->
name
;
...
@@ -134,7 +134,14 @@ STATUS EltwiseFormatTransPass::Run(schema::MetaGraphT *graph) {
...
@@ -134,7 +134,14 @@ STATUS EltwiseFormatTransPass::Run(schema::MetaGraphT *graph) {
MS_LOG
(
ERROR
)
<<
"FindOutTransType error"
;
MS_LOG
(
ERROR
)
<<
"FindOutTransType error"
;
return
ret
;
return
ret
;
}
}
// 4 dims means infershape success,can delete
if
(
type
==
PrimitiveType_Concat
)
{
if
(
graph
->
allTensors
.
at
(
node
->
inputIndex
[
0
])
->
dims
.
size
()
==
4
)
{
node
->
primitive
->
value
.
AsConcat
()
->
axis
=
-
1
;
}
else
{
continue
;
}
}
STATUS
status
=
RET_OK
;
STATUS
status
=
RET_OK
;
auto
input_tensor_size
=
(
*
iter
)
->
inputIndex
.
size
();
auto
input_tensor_size
=
(
*
iter
)
->
inputIndex
.
size
();
for
(
size_t
i
=
0
;
i
<
input_tensor_size
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
input_tensor_size
;
i
++
)
{
...
...
mindspore/lite/tools/converter/legacy_optimizer/graph/
eltwise_format_trans
_pass.h
→
mindspore/lite/tools/converter/legacy_optimizer/graph/
trans_format_insert
_pass.h
浏览文件 @
6d500c86
...
@@ -24,11 +24,11 @@
...
@@ -24,11 +24,11 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
lite
{
namespace
lite
{
class
EltwiseFormatTrans
Pass
:
public
FormatTransPass
{
class
TransOpInsert
Pass
:
public
FormatTransPass
{
public:
public:
EltwiseFormatTrans
Pass
()
:
FormatTransPass
()
{}
TransOpInsert
Pass
()
:
FormatTransPass
()
{}
~
EltwiseFormatTrans
Pass
()
override
=
default
;
~
TransOpInsert
Pass
()
override
=
default
;
STATUS
Run
(
schema
::
MetaGraphT
*
graph
)
override
;
STATUS
Run
(
schema
::
MetaGraphT
*
graph
)
override
;
...
...
mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.cc
0 → 100644
浏览文件 @
6d500c86
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h"
#include <vector>
#include "utils/log_adapter.h"
#include "include/errorcode.h"
#include "tools/common/graph_util.h"
#include "src/ir/tensor.h"
#include "src/ops/primitive_c.h"
using
mindspore
::
lite
::
tensor
::
Tensor
;
using
mindspore
::
lite
::
PrimitiveC
;
namespace
mindspore
{
namespace
lite
{
STATUS
TransOpRemovePass
::
Run
(
MetaGraphT
*
graph
)
{
MS_ASSERT
(
graph
!=
nullptr
);
for
(
auto
iter
=
graph
->
nodes
.
begin
();
iter
!=
graph
->
nodes
.
end
();
iter
++
)
{
auto
&
node
=
*
iter
;
auto
type
=
node
->
primitive
->
value
.
type
;
if
(
type
==
schema
::
PrimitiveType_Nchw2Nhwc
||
type
==
schema
::
PrimitiveType_Nhwc2Nchw
)
{
auto
&
input_tensor
=
graph
->
allTensors
.
at
(
node
->
inputIndex
.
at
(
0
));
// less than 4 dims can delete
if
(
!
input_tensor
->
dims
.
empty
()
&&
input_tensor
->
dims
.
size
()
<
4
)
{
auto
status
=
IsolateOneWayNode
(
graph
,
node
.
get
(),
true
);
if
(
status
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"IsolateOneWayNode failed, node: "
<<
node
->
name
.
c_str
()
<<
", error: "
<<
status
;
return
status
;
}
}
}
}
return
RET_OK
;
}
}
// namespace lite
}
// namespace mindspore
mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h
0 → 100644
浏览文件 @
6d500c86
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_PREDICT_TRANS_FORMAT_REMOVE_PASS_H
#define MINDSPORE_PREDICT_TRANS_FORMAT_REMOVE_PASS_H
#include <unordered_map>
#include <memory>
#include <string>
#include <utility>
#include "tools/common/graph_util.h"
#include "tools/converter/optimizer.h"
using
mindspore
::
schema
::
TensorT
;
namespace
mindspore
{
namespace
lite
{
class
TransOpRemovePass
:
public
GraphPass
{
public:
TransOpRemovePass
()
=
default
;
~
TransOpRemovePass
()
=
default
;
STATUS
Run
(
MetaGraphT
*
graph
)
override
;
};
}
// namespace lite
}
// namespace mindspore
#endif // MINDSPORE_PREDICT_TRANS_FORMAT_REMOVE_PASS_H
mindspore/lite/tools/converter/optimizer.cc
浏览文件 @
6d500c86
...
@@ -52,7 +52,7 @@ STATUS Optimizer::Run(schema::MetaGraphT *graphDefT) {
...
@@ -52,7 +52,7 @@ STATUS Optimizer::Run(schema::MetaGraphT *graphDefT) {
for
(
auto
&
opDef
:
graphDefT
->
nodes
)
{
for
(
auto
&
opDef
:
graphDefT
->
nodes
)
{
for
(
auto
pass
:
this
->
nodePasses
)
{
for
(
auto
pass
:
this
->
nodePasses
)
{
status
=
pass
->
Run
(
new
GraphNode
(
graphDefT
,
opDef
.
get
()));
status
=
pass
->
Run
(
new
GraphNode
(
graphDefT
,
opDef
.
get
()));
if
(
status
!=
RET_OK
&&
status
!=
RET_NO_CHANGE
)
{
if
(
status
!=
RET_OK
&&
status
!=
RET_NO_CHANGE
&&
status
!=
RET_INFER_ERR
)
{
MS_LOG
(
ERROR
)
<<
"Run NodePass failed"
;
MS_LOG
(
ERROR
)
<<
"Run NodePass failed"
;
return
status
;
return
status
;
}
else
{
}
else
{
...
@@ -65,7 +65,7 @@ STATUS Optimizer::Run(schema::MetaGraphT *graphDefT) {
...
@@ -65,7 +65,7 @@ STATUS Optimizer::Run(schema::MetaGraphT *graphDefT) {
for
(
auto
pass
:
this
->
graphPasses
)
{
for
(
auto
pass
:
this
->
graphPasses
)
{
status
=
pass
->
Run
(
graphDefT
);
status
=
pass
->
Run
(
graphDefT
);
if
(
status
!=
RET_OK
&&
status
!=
RET_NO_CHANGE
)
{
if
(
status
!=
RET_OK
&&
status
!=
RET_NO_CHANGE
&&
status
!=
RET_INFER_ERR
)
{
MS_LOG
(
ERROR
)
<<
"Run GraphPass failed"
;
MS_LOG
(
ERROR
)
<<
"Run GraphPass failed"
;
return
status
;
return
status
;
}
else
{
}
else
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录