Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
603ba5e0
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
603ba5e0
编写于
10月 19, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add seqconv eltadd relu pass
上级
23fc896b
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
227 addition
and
11 deletion
+227
-11
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+1
-0
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+50
-0
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+25
-0
paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc
paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc
+101
-0
paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h
paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h
+38
-0
paddle/fluid/inference/analysis/analyzer.h
paddle/fluid/inference/analysis/analyzer.h
+12
-11
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
603ba5e0
...
@@ -37,6 +37,7 @@ pass_library(embedding_fc_lstm_fuse_pass inference)
...
@@ -37,6 +37,7 @@ pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library
(
fc_gru_fuse_pass inference
)
pass_library
(
fc_gru_fuse_pass inference
)
pass_library
(
seq_concat_fc_fuse_pass inference
)
pass_library
(
seq_concat_fc_fuse_pass inference
)
pass_library
(
conv_bn_fuse_pass inference
)
pass_library
(
conv_bn_fuse_pass inference
)
pass_library
(
seqconv_eltadd_relu_fuse_pass inference
)
if
(
WITH_MKLDNN
)
if
(
WITH_MKLDNN
)
pass_library
(
mkldnn_placement_pass base
)
pass_library
(
mkldnn_placement_pass base
)
pass_library
(
conv_relu_mkldnn_fuse_pass inference
)
pass_library
(
conv_relu_mkldnn_fuse_pass inference
)
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
603ba5e0
...
@@ -349,6 +349,11 @@ PDNode *PDNode::assert_is_op() {
...
@@ -349,6 +349,11 @@ PDNode *PDNode::assert_is_op() {
return
this
;
return
this
;
}
}
// PDNode *PDNode::assert_op_attr() {
// asserts_.emplace_back([](Node *x) { return x && x->IsOp(); });
// return this;
// }
PDNode
*
PDNode
::
assert_is_op
(
const
std
::
string
&
op_type
)
{
PDNode
*
PDNode
::
assert_is_op
(
const
std
::
string
&
op_type
)
{
asserts_
.
emplace_back
([
op_type
](
Node
*
x
)
{
asserts_
.
emplace_back
([
op_type
](
Node
*
x
)
{
return
x
&&
x
->
IsOp
()
&&
x
->
Op
()
->
Type
()
==
op_type
;
return
x
&&
x
->
IsOp
()
&&
x
->
Op
()
->
Type
()
==
op_type
;
...
@@ -761,6 +766,51 @@ PDNode *patterns::ConvReLU::operator()(
...
@@ -761,6 +766,51 @@ PDNode *patterns::ConvReLU::operator()(
return
relu_out_var
;
return
relu_out_var
;
}
}
PDNode
*
patterns
::
SeqConvEltAddRelu
::
operator
()(
paddle
::
framework
::
ir
::
PDNode
*
seqconv_input
)
{
// Create Operators
seqconv_input
->
assert_is_op_input
(
"sequence_conv"
,
"X"
);
auto
*
seqconv_op
=
pattern
->
NewNode
(
seqconv_repr
())
->
assert_is_op
(
"sequence_conv"
);
// ->assert_op_attr("paddingTrainable", false)
// ->assert_op_attr("contextStride", 1)
auto
*
eltadd_op
=
pattern
->
NewNode
(
eltadd_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
relu_op
=
pattern
->
NewNode
(
relu_repr
())
->
assert_is_op
(
"relu"
);
// Create variables
// Filter
auto
*
seqconv_weight_var
=
pattern
->
NewNode
(
seqconv_weight_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"sequence_conv"
,
"Filter"
);
// Bias
auto
*
eltadd_bias_var
=
pattern
->
NewNode
(
eltadd_bias_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
);
// intermediate variable, will be removed in the IR after fuse.
auto
*
seqconv_out_var
=
pattern
->
NewNode
(
seqconv_out_repr
())
->
AsIntermediate
()
->
assert_is_only_output_of_op
(
"sequence_conv"
)
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_out_var
=
pattern
->
NewNode
(
eltadd_out_repr
())
->
AsIntermediate
()
->
assert_is_only_output_of_op
(
"elementwise_add"
)
->
assert_is_only_input_of_op
(
"relu"
);
// output
auto
*
relu_out_var
=
pattern
->
NewNode
(
relu_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"relu"
);
seqconv_op
->
LinksFrom
({
seqconv_input
,
seqconv_weight_var
})
.
LinksTo
({
seqconv_out_var
});
eltadd_op
->
LinksFrom
({
seqconv_out_var
,
eltadd_bias_var
})
.
LinksTo
({
eltadd_out_var
});
relu_op
->
LinksFrom
({
eltadd_out_var
}).
LinksTo
({
relu_out_var
});
return
relu_out_var
;
}
PDNode
*
patterns
::
FC
::
operator
()(
paddle
::
framework
::
ir
::
PDNode
*
x
,
PDNode
*
patterns
::
FC
::
operator
()(
paddle
::
framework
::
ir
::
PDNode
*
x
,
bool
with_bias
)
{
bool
with_bias
)
{
// Create shared nodes.
// Create shared nodes.
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
603ba5e0
...
@@ -434,6 +434,31 @@ struct ConvReLU : public PatternBase {
...
@@ -434,6 +434,31 @@ struct ConvReLU : public PatternBase {
PATTERN_DECL_NODE
(
relu_out
);
PATTERN_DECL_NODE
(
relu_out
);
};
};
// SEQCONV with Elementwise_Add ReLU
// op: seqconv + elementwise_add + relu
// named nodes:
// seqconv_input, seqconv_weight,
// seqconv_out, seqconv,
// elementwise_add_bias, elementwise_add_out, elementwise_add
// relu_out, relu
struct
SeqConvEltAddRelu
:
public
PatternBase
{
SeqConvEltAddRelu
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"seqconv_eltadd_relu"
)
{}
PDNode
*
operator
()(
PDNode
*
seqconv_input
);
// declare operator node's name
PATTERN_DECL_NODE
(
seqconv
);
PATTERN_DECL_NODE
(
eltadd
);
PATTERN_DECL_NODE
(
relu
);
// declare variable node's name
PATTERN_DECL_NODE
(
seqconv_weight
);
PATTERN_DECL_NODE
(
seqconv_out
);
PATTERN_DECL_NODE
(
eltadd_bias
);
PATTERN_DECL_NODE
(
eltadd_out
);
PATTERN_DECL_NODE
(
relu_out
);
};
// FC with bias
// FC with bias
// op: mul + elementwise_add
// op: mul + elementwise_add
// named nodes:
// named nodes:
...
...
paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc
0 → 100644
浏览文件 @
603ba5e0
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
)
{
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
PDNode
*
x
=
pattern
->
NewNode
(
patterns
::
PDNodeName
(
name_scope
,
"X"
))
->
assert_is_op_input
(
"sequence_conv"
)
->
assert_var_not_persistable
();
patterns
::
SeqConvEltAddRelu
fuse_pattern
(
pattern
,
name_scope
);
fuse_pattern
(
x
);
// Create New OpDesc
auto
fuse_creator
=
[
&
](
Node
*
seqconv
,
Node
*
input
,
Node
*
seqconv_weight
,
Node
*
eltadd_bias
,
Node
*
relu_out
)
{
OpDesc
op_desc
;
op_desc
.
SetType
(
"fusion_seqconv_eltadd_relu"
);
op_desc
.
SetInput
(
"X"
,
{
input
->
Name
()});
op_desc
.
SetInput
(
"Filter"
,
{
seqconv_weight
->
Name
()});
op_desc
.
SetInput
(
"Bias"
,
{
eltadd_bias
->
Name
()});
op_desc
.
SetAttr
(
"contextLength"
,
seqconv
->
Op
()
->
GetAttr
(
"contextLength"
));
op_desc
.
SetAttr
(
"contextStart"
,
seqconv
->
Op
()
->
GetAttr
(
"contextStart"
));
op_desc
.
SetAttr
(
"contextStride"
,
seqconv
->
Op
()
->
GetAttr
(
"contextStride"
));
PADDLE_ENFORCE
(
graph
->
Has
(
kParamScopeAttr
));
auto
*
scope
=
graph
->
Get
<
Scope
*>
(
kParamScopeAttr
);
const
std
::
string
ColMat
=
patterns
::
UniqueKey
(
"SeqConvColMat"
);
op_desc
.
SetOutput
(
"ColMat"
,
{
ColMat
});
op_desc
.
SetOutput
(
"Out"
,
{
relu_out
->
Name
()});
scope
->
Var
(
ColMat
)
->
GetMutable
<
LoDTensor
>
();
auto
*
op
=
graph
->
CreateOpNode
(
&
op_desc
);
IR_NODE_LINK_TO
(
input
,
op
);
IR_NODE_LINK_TO
(
seqconv_weight
,
op
);
IR_NODE_LINK_TO
(
eltadd_bias
,
op
);
IR_NODE_LINK_TO
(
op
,
relu_out
);
return
op
;
};
int
fusion_count
{
0
};
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
VLOG
(
4
)
<<
"handle SeqConv EltAdd Relu fuse"
;
GET_IR_NODE_FROM_SUBGRAPH
(
seqconv
,
seqconv
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
seqconv_weight
,
seqconv_weight
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
seqconv_out
,
seqconv_out
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd
,
eltadd
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_bias
,
eltadd_bias
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_out
,
eltadd_out
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
relu
,
relu
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
relu_out
,
relu_out
,
fuse_pattern
);
fuse_creator
(
seqconv
,
subgraph
.
at
(
x
),
seqconv_weight
,
eltadd_bias
,
relu_out
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
(
{
seqconv
,
seqconv_out
,
eltadd
,
eltadd_out
,
relu
});
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
++
fusion_count
;
};
gpd
(
graph
,
handler
);
return
fusion_count
;
}
std
::
unique_ptr
<
ir
::
Graph
>
SeqConvEltAddReluFusePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
.
get
());
int
fusion_count
=
BuildFusion
(
graph
.
get
(),
name_scope_
,
param_scope
());
AddStatis
(
fusion_count
);
return
graph
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
seqconv_eltadd_relu_fuse_pass
,
paddle
::
framework
::
ir
::
SeqConvEltAddReluFusePass
);
paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h
0 → 100644
浏览文件 @
603ba5e0
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
SeqConvEltAddReluFusePass
:
public
FusePassBase
{
public:
virtual
~
SeqConvEltAddReluFusePass
()
{}
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
;
const
std
::
string
name_scope_
{
"seqconv_eltadd_relu_fuse"
};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/inference/analysis/analyzer.h
浏览文件 @
603ba5e0
...
@@ -67,17 +67,18 @@ class Analyzer : public OrderedRegistry<PassManager> {
...
@@ -67,17 +67,18 @@ class Analyzer : public OrderedRegistry<PassManager> {
// larger fusion.
// larger fusion.
const
std
::
vector
<
std
::
string
>
all_ir_passes_
{{
const
std
::
vector
<
std
::
string
>
all_ir_passes_
{{
// Manual update the passes here.
// Manual update the passes here.
"infer_clean_graph_pass"
,
//
"infer_clean_graph_pass"
,
//
"attention_lstm_fuse_pass"
,
//
"attention_lstm_fuse_pass"
,
//
"embedding_fc_lstm_fuse_pass"
,
//
"seqconv_eltadd_relu_fuse_pass"
,
//
"fc_lstm_fuse_pass"
,
//
"embedding_fc_lstm_fuse_pass"
,
//
"mul_lstm_fuse_pass"
,
//
"fc_lstm_fuse_pass"
,
//
"fc_gru_fuse_pass"
,
//
"mul_lstm_fuse_pass"
,
//
"mul_gru_fuse_pass"
,
//
"fc_gru_fuse_pass"
,
//
"seq_concat_fc_fuse_pass"
,
//
"mul_gru_fuse_pass"
,
//
"fc_fuse_pass"
,
//
"seq_concat_fc_fuse_pass"
,
//
"conv_bn_fuse_pass"
,
//
"fc_fuse_pass"
,
//
"conv_eltwiseadd_bn_fuse_pass"
,
//
"conv_bn_fuse_pass"
,
//
"conv_eltwiseadd_bn_fuse_pass"
,
//
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
"conv_relu_mkldnn_fuse_pass"
,
//
"conv_relu_mkldnn_fuse_pass"
,
//
#endif
#endif
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录