Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
c0c9fcd9
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c0c9fcd9
编写于
12月 16, 2018
作者:
N
nhzlx
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add source file
test=develop
上级
38895302
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
124 addition
and
0 deletion
+124
-0
paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc
paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc
+91
-0
paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h
paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h
+33
-0
未找到文件。
paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc
0 → 100644
浏览文件 @
c0c9fcd9
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(conv_op); \
GET_IR_NODE(conv_out); \
GET_IR_NODE(conv_filter); \
GET_IR_NODE(elementwise_add_op); \
GET_IR_NODE(elementwise_add_in_y); \
GET_IR_NODE(elementwise_add_out);
std
::
unique_ptr
<
ir
::
Graph
>
ConvElementwiseAddFusePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
const
std
::
string
pattern_name
=
"conv_elementwise_add_fuse"
;
FusePassBase
::
Init
(
pattern_name
,
graph
.
get
());
GraphPatternDetector
gpd
;
auto
*
x
=
gpd
.
mutable_pattern
()
->
NewNode
(
"x"
)
->
assert_is_op_input
(
"conv2d"
,
"Input"
)
->
AsInput
();
patterns
::
ConvElementwiseadd
pattern
(
gpd
.
mutable_pattern
(),
pattern_name
);
pattern
(
x
);
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
GET_NODES
;
auto
base_op_desc
=
*
conv_op
->
Op
()
->
Proto
();
std
::
string
bias_name
=
elementwise_add_in_y
->
Name
();
std
::
string
output_name
=
elementwise_add_out
->
Name
();
std
::
string
act_type
=
"identity"
;
framework
::
OpDesc
new_op_desc
(
base_op_desc
,
nullptr
);
new_op_desc
.
SetType
(
"conv2d_fusion"
);
new_op_desc
.
SetInput
(
"Bias"
,
{
bias_name
});
new_op_desc
.
SetInput
(
"ResidualData"
,
{});
new_op_desc
.
SetAttr
(
"activation"
,
act_type
);
new_op_desc
.
SetOutput
(
"Output"
,
{
output_name
});
new_op_desc
.
SetAttr
(
"is_test"
,
true
);
new_op_desc
.
SetAttr
(
"use_cudnn"
,
false
);
new_op_desc
.
Flush
();
// Create a new node for the fused op.
auto
*
new_conv_op
=
graph
->
CreateOpNode
(
&
new_op_desc
);
// Link inputs and outputs.
PADDLE_ENFORCE
(
subgraph
.
count
(
x
));
auto
*
conv_in_node
=
subgraph
.
at
(
x
);
IR_NODE_LINK_TO
(
conv_in_node
,
new_conv_op
);
// Input
IR_NODE_LINK_TO
(
conv_filter
,
new_conv_op
);
// Filter
IR_NODE_LINK_TO
(
elementwise_add_in_y
,
new_conv_op
);
// Bias
IR_NODE_LINK_TO
(
new_conv_op
,
elementwise_add_out
);
// Output
// Delete the unneeded nodes.
GraphSafeRemoveNodes
(
graph
.
get
(),
{
conv_op
,
conv_out
,
elementwise_add_op
});
};
gpd
(
graph
.
get
(),
handler
);
return
graph
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
conv_elementwise_add_fuse_pass
,
paddle
::
framework
::
ir
::
ConvElementwiseAddFusePass
);
paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h
0 → 100644
浏览文件 @
c0c9fcd9
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
ConvElementwiseAddFusePass
:
public
FusePassBase
{
public:
virtual
~
ConvElementwiseAddFusePass
()
{}
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录