Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
1237dfa6
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1237dfa6
编写于
4月 16, 2019
作者:
Z
Zhaolong Xing
提交者:
GitHub
4月 16, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #16885 from NHZlX/fix_anakin_subgraph_shufflenet
Fix anakin subgraph shufflenet
上级
5d48e9cc
e4726a06
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
264 addition
and
2 deletion
+264
-2
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+1
-0
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+31
-0
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+15
-0
paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc
paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc
+93
-0
paddle/fluid/framework/ir/shuffle_channel_detect_pass.h
paddle/fluid/framework/ir/shuffle_channel_detect_pass.h
+34
-0
paddle/fluid/inference/anakin/convert/CMakeLists.txt
paddle/fluid/inference/anakin/convert/CMakeLists.txt
+2
-2
paddle/fluid/inference/anakin/convert/shuffle_channel.cc
paddle/fluid/inference/anakin/convert/shuffle_channel.cc
+47
-0
paddle/fluid/inference/anakin/convert/shuffle_channel.h
paddle/fluid/inference/anakin/convert/shuffle_channel.h
+38
-0
paddle/fluid/inference/anakin/op_teller.cc
paddle/fluid/inference/anakin/op_teller.cc
+1
-0
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+1
-0
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+1
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
1237dfa6
...
...
@@ -70,6 +70,7 @@ pass_library(sync_batch_norm_pass base)
pass_library
(
runtime_context_cache_pass base
)
pass_library
(
quant_conv2d_dequant_fuse_pass inference
)
pass_library
(
fillconstant_elementwisemul_fuse inference
)
pass_library
(
shuffle_channel_detect_pass inference
)
if
(
ANAKIN_FOUND
)
pass_library
(
simplify_anakin_priorbox_detection_out_pass inference
)
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
1237dfa6
...
...
@@ -1706,6 +1706,37 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
}
}
void
patterns
::
ShuffleChannelPattern
::
operator
()(
PDNode
*
reshape1_in
)
{
auto
reshape1_op
=
pattern
->
NewNode
(
reshape1_op_repr
())
->
assert_is_op
(
"reshape2"
);
auto
reshape1_out
=
pattern
->
NewNode
(
reshape1_out_repr
())
->
assert_is_op_output
(
"reshape2"
,
"Out"
)
->
assert_is_op_input
(
"transpose2"
)
->
AsIntermediate
();
auto
transpose_op
=
pattern
->
NewNode
(
transpose_op_repr
())
->
assert_is_op
(
"transpose2"
);
auto
transpose_out
=
pattern
->
NewNode
(
transpose_out_repr
())
->
assert_is_op_output
(
"transpose2"
,
"Out"
)
->
assert_is_op_input
(
"reshape2"
)
->
AsIntermediate
();
auto
reshape2_op
=
pattern
->
NewNode
(
reshape2_op_repr
())
->
assert_is_op
(
"reshape2"
);
auto
reshape2_out
=
pattern
->
NewNode
(
reshape2_out_repr
())
->
assert_is_op_output
(
"reshape2"
,
"Out"
)
->
AsOutput
();
reshape1_op
->
LinksFrom
({
reshape1_in
});
reshape1_out
->
LinksFrom
({
reshape1_op
});
transpose_op
->
LinksFrom
({
reshape1_out
});
transpose_out
->
LinksFrom
({
transpose_op
});
reshape2_op
->
LinksFrom
({
transpose_out
});
reshape2_out
->
LinksFrom
({
reshape2_op
});
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
1237dfa6
...
...
@@ -892,6 +892,21 @@ struct QuantDequantOpFuse : public PatternBase {
}
};
struct
ShuffleChannelPattern
:
public
PatternBase
{
ShuffleChannelPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"shufflechannel_pattern"
)
{}
void
operator
()(
PDNode
*
reshape1_in
);
PATTERN_DECL_NODE
(
reshape1_op
);
PATTERN_DECL_NODE
(
reshape1_out
);
PATTERN_DECL_NODE
(
transpose_op
);
PATTERN_DECL_NODE
(
transpose_out
);
PATTERN_DECL_NODE
(
reshape2_op
);
PATTERN_DECL_NODE
(
reshape2_out
);
};
}
// namespace patterns
// Link two ir::Nodes from each other.
...
...
paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc
0 → 100644
浏览文件 @
1237dfa6
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/shuffle_channel_detect_pass.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(reshape1_op); \
GET_IR_NODE(reshape1_out); \
GET_IR_NODE(transpose_op); \
GET_IR_NODE(transpose_out); \
GET_IR_NODE(reshape2_op); \
GET_IR_NODE(reshape2_out);
void
ShuffleChannelDetectPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
const
std
::
string
pattern_name
=
"shufflechannel_pattern"
;
FusePassBase
::
Init
(
pattern_name
,
graph
);
GraphPatternDetector
gpd
;
auto
*
x
=
gpd
.
mutable_pattern
()
->
NewNode
(
"x"
)
->
assert_is_op_input
(
"reshape2"
,
"X"
)
->
AsInput
();
patterns
::
ShuffleChannelPattern
pattern
(
gpd
.
mutable_pattern
(),
pattern_name
);
pattern
(
x
);
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
GET_NODES
;
PADDLE_ENFORCE
(
subgraph
.
count
(
x
));
auto
*
input_node
=
subgraph
.
at
(
x
);
auto
reshape1_desc
=
reshape1_op
->
Op
();
auto
reshape2_desc
=
reshape2_op
->
Op
();
std
::
string
input_name
=
input_node
->
Name
();
std
::
string
output_name
=
reshape2_out
->
Name
();
auto
reshape1_shape
=
boost
::
get
<
std
::
vector
<
int
>>
(
reshape1_desc
->
GetAttr
(
"shape"
));
auto
reshape2_shape
=
boost
::
get
<
std
::
vector
<
int
>>
(
reshape2_desc
->
GetAttr
(
"shape"
));
int
i_c
=
reshape1_shape
[
2
];
int
o_c
=
reshape2_shape
[
1
];
int
group
=
o_c
/
i_c
;
framework
::
OpDesc
new_op_desc
;
new_op_desc
.
SetType
(
"shuffle_channel"
);
new_op_desc
.
SetInput
(
"X"
,
{
input_name
});
new_op_desc
.
SetOutput
(
"Out"
,
{
output_name
});
new_op_desc
.
SetAttr
(
"group"
,
group
);
new_op_desc
.
Flush
();
// Create a new node for the fused op.
auto
*
new_op
=
graph
->
CreateOpNode
(
&
new_op_desc
);
IR_NODE_LINK_TO
(
input_node
,
new_op
);
IR_NODE_LINK_TO
(
new_op
,
reshape2_out
);
// Delete the unneeded nodes.
GraphSafeRemoveNodes
(
graph
,
{
reshape1_op
,
reshape1_out
,
transpose_op
,
transpose_out
,
reshape2_op
});
};
gpd
(
graph
,
handler
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
shuffle_channel_detect_pass
,
paddle
::
framework
::
ir
::
ShuffleChannelDetectPass
);
paddle/fluid/framework/ir/shuffle_channel_detect_pass.h
0 → 100644
浏览文件 @
1237dfa6
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
ShuffleChannelDetectPass
:
public
FusePassBase
{
public:
virtual
~
ShuffleChannelDetectPass
()
{}
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/inference/anakin/convert/CMakeLists.txt
浏览文件 @
1237dfa6
...
...
@@ -2,8 +2,8 @@ cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc
elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc
batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc
detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc affine_channel.cc
roi_align.cc
helper.cc DEPS anakin_engine framework_proto scope op_registry
gtest
)
roi_align.cc
shuffle_channel.cc helper.cc DEPS anakin_engine framework_proto
scope op_registry
gtest
)
cc_test
(
test_anakin_fc SRCS test_fc_op.cc DEPS anakin_op_converter mul_op SERIAL
)
cc_test
(
test_anakin_conv2d SRCS test_conv2d_op.cc DEPS anakin_op_converter conv_op im2col vol2col depthwise_conv SERIAL
)
...
...
paddle/fluid/inference/anakin/convert/shuffle_channel.cc
0 → 100644
浏览文件 @
1237dfa6
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/anakin/convert/shuffle_channel.h"
#include <algorithm>
#include <string>
#include <vector>
using
anakin
::
PTuple
;
namespace
paddle
{
namespace
inference
{
namespace
anakin
{
template
<
typename
TargetT
,
::
anakin
::
Precision
PrecisionT
>
void
ShuffleChannelOpConverter
<
TargetT
,
PrecisionT
>::
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
BlockDesc
&
block_desc
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
{
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"X"
).
size
(),
1
);
PADDLE_ENFORCE_EQ
(
op_desc
.
Output
(
"Out"
).
size
(),
1
);
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
output
=
op_desc
.
Output
(
"Out"
).
front
();
auto
op_name
=
op_desc
.
Type
()
+
":"
+
op_desc
.
Output
(
"Out"
).
front
();
this
->
engine_
->
AddOp
(
op_name
,
"ShuffleChannel"
,
{
input
},
{
output
});
auto
group
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"group"
));
this
->
engine_
->
AddOpAttr
(
op_name
,
"group"
,
group
);
}
}
// namespace anakin
}
// namespace inference
}
// namespace paddle
REGISTER_ANAKIN_OP_CONVERTER
(
shuffle_channel
,
ShuffleChannelOpConverter
);
paddle/fluid/inference/anakin/convert/shuffle_channel.h
0 → 100644
浏览文件 @
1237dfa6
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
namespace
paddle
{
namespace
inference
{
namespace
anakin
{
template
<
typename
TargetT
,
::
anakin
::
Precision
PrecisionT
>
class
ShuffleChannelOpConverter
:
public
AnakinOpConverter
<
TargetT
,
PrecisionT
>
{
public:
ShuffleChannelOpConverter
()
=
default
;
virtual
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
BlockDesc
&
block_desc
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
;
virtual
~
ShuffleChannelOpConverter
()
{}
};
}
// namespace anakin
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/anakin/op_teller.cc
浏览文件 @
1237dfa6
...
...
@@ -48,6 +48,7 @@ struct SimpleOpTypeSetTeller : public Teller {
teller_set
.
insert
(
"affine_channel"
);
teller_set
.
insert
(
"relu6"
);
teller_set
.
insert
(
"swish"
);
teller_set
.
insert
(
"shuffle_channel"
);
}
bool
operator
()(
const
std
::
string
&
op_type
,
...
...
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
1237dfa6
...
...
@@ -896,4 +896,5 @@ USE_ANAKIN_CONVERTER(leaky_relu);
USE_ANAKIN_CONVERTER
(
affine_channel
);
USE_ANAKIN_CONVERTER
(
relu6
);
USE_ANAKIN_CONVERTER
(
swish
);
USE_ANAKIN_CONVERTER
(
shuffle_channel
);
#endif
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
1237dfa6
...
...
@@ -79,6 +79,7 @@ const std::vector<std::string> kAnakinSubgraphPasses({
"fc_fuse_pass"
,
//
"conv_elementwise_add_fuse_pass"
,
//
"fc_gru_fuse_pass"
,
//
"shuffle_channel_detect_pass"
,
//
"anakin_subgraph_pass"
,
//
"fc_gru_fuse_pass"
,
//
});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录