Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
5d2834fc
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5d2834fc
编写于
8月 15, 2018
作者:
Y
Yan Chunwei
提交者:
GitHub
8月 15, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fea/ir support fuse, based on graph pattern detection helper (#12636)
上级
09e30eee
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
704 addition
and
0 deletion
+704
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+3
-0
paddle/fluid/framework/ir/graph_pattern_detecter.cc
paddle/fluid/framework/ir/graph_pattern_detecter.cc
+186
-0
paddle/fluid/framework/ir/graph_pattern_detecter.h
paddle/fluid/framework/ir/graph_pattern_detecter.h
+181
-0
paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc
paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc
+172
-0
paddle/fluid/framework/ir/graph_traits.cc
paddle/fluid/framework/ir/graph_traits.cc
+69
-0
paddle/fluid/framework/ir/graph_traits.h
paddle/fluid/framework/ir/graph_traits.h
+90
-0
paddle/fluid/framework/ir/node.h
paddle/fluid/framework/ir/node.h
+3
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
5d2834fc
...
...
@@ -3,7 +3,10 @@ cc_library(graph SRCS graph.cc DEPS node)
cc_library
(
graph_helper SRCS graph_helper.cc DEPS graph
)
cc_library
(
pass SRCS pass.cc DEPS graph node graph_helper
)
cc_library
(
graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper
)
cc_library
(
graph_traits SRCS graph_traits.cc DEPS graph
)
cc_library
(
graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits
)
cc_test
(
pass_test SRCS pass_test.cc DEPS graph pass graph_helper
)
cc_test
(
graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry
)
cc_test
(
graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry
)
cc_test
(
test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter
)
paddle/fluid/framework/ir/graph_pattern_detecter.cc
0 → 100644
浏览文件 @
5d2834fc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
PDNode
*
PDPattern
::
NewNode
(
PDNode
::
teller_t
&&
teller
,
const
std
::
string
&
name
)
{
nodes_
.
emplace_back
(
new
PDNode
(
std
::
move
(
teller
),
name
));
auto
*
cur
=
nodes_
.
back
().
get
();
return
cur
;
}
void
PDPattern
::
AddEdge
(
PDNode
*
a
,
PDNode
*
b
)
{
PADDLE_ENFORCE
(
a
);
PADDLE_ENFORCE
(
b
);
PADDLE_ENFORCE
(
a
!=
b
,
"can't connect to the same nodes."
);
edges_
.
emplace_back
(
a
,
b
);
}
void
GraphPatternDetecter
::
operator
()(
Graph
*
graph
,
GraphPatternDetecter
::
handle_t
handler
)
{
if
(
!
MarkPDNodesInGraph
(
*
graph
))
return
;
auto
subgraphs
=
DetectPatterns
();
UniquePatterns
(
&
subgraphs
);
RemoveOverlappedMatch
(
&
subgraphs
);
for
(
auto
&
g
:
subgraphs
)
{
handler
(
g
,
graph
);
}
}
bool
GraphPatternDetecter
::
MarkPDNodesInGraph
(
const
ir
::
Graph
&
graph
)
{
if
(
graph
.
Nodes
().
empty
())
return
false
;
for
(
auto
&
node
:
GraphTraits
::
DFS
(
graph
))
{
for
(
const
auto
&
pdnode
:
pattern_
.
nodes
())
{
if
(
pdnode
->
Tell
(
&
node
))
{
pdnodes2nodes_
[
pdnode
.
get
()].
insert
(
&
node
);
}
}
}
return
!
pdnodes2nodes_
.
empty
();
}
struct
HitGroup
{
std
::
unordered_map
<
PDNode
*
,
Node
*>
roles
;
bool
Match
(
Node
*
node
,
PDNode
*
pat
)
{
return
!
roles
.
count
(
pat
)
||
roles
.
at
(
pat
)
==
node
;
}
void
Register
(
Node
*
node
,
PDNode
*
pat
)
{
roles
[
pat
]
=
node
;
}
};
// Tell whether Node a links to b.
bool
IsNodesLink
(
Node
*
a
,
Node
*
b
)
{
for
(
auto
*
node
:
a
->
outputs
)
{
if
(
b
==
node
)
{
return
true
;
}
}
return
false
;
}
std
::
vector
<
GraphPatternDetecter
::
subgraph_t
>
GraphPatternDetecter
::
DetectPatterns
()
{
// Init empty subgraphs.
std
::
vector
<
GraphPatternDetecter
::
subgraph_t
>
result
;
std
::
vector
<
HitGroup
>
init_groups
;
PADDLE_ENFORCE
(
!
pattern_
.
edges
().
empty
(),
"At least one edge is needed"
);
auto
*
first_pnode
=
pattern_
.
edges
().
front
().
first
;
if
(
!
pdnodes2nodes_
.
count
(
first_pnode
))
return
result
;
for
(
auto
*
node
:
pdnodes2nodes_
[
first_pnode
])
{
HitGroup
group
;
group
.
roles
[
first_pnode
]
=
node
;
init_groups
.
emplace_back
(
group
);
}
int
step
=
0
;
std
::
array
<
std
::
vector
<
HitGroup
>
,
2
>
bi_records
;
bi_records
[
0
]
=
std
::
move
(
init_groups
);
// Extend a PDNode to subgraphs by deducing the connection relations defined
// in edges of PDNodes.
for
(
const
auto
&
edge
:
pattern_
.
edges
())
{
// Each role has two PDNodes, which indicates two roles.
// Detect two Nodes that can match these two roles and they are connected.
auto
&
pre_groups
=
bi_records
[
step
%
2
];
auto
&
cur_groups
=
bi_records
[
1
-
(
step
++
%
2
)];
cur_groups
.
clear
();
// source -> target
for
(
Node
*
source
:
pdnodes2nodes_
[
edge
.
first
])
{
for
(
Node
*
target
:
pdnodes2nodes_
[
edge
.
second
])
{
// TODO(Superjomn) add some prune strategies.
for
(
const
auto
&
group
:
pre_groups
)
{
HitGroup
new_group
=
group
;
if
(
IsNodesLink
(
source
,
target
)
&&
new_group
.
Match
(
source
,
edge
.
first
))
{
new_group
.
Register
(
source
,
edge
.
first
);
if
(
new_group
.
Match
(
target
,
edge
.
second
))
{
new_group
.
Register
(
target
,
edge
.
second
);
cur_groups
.
push_back
(
new_group
);
// TODO(Superjomn) need to unique
}
}
}
}
}
}
for
(
auto
&
group
:
bi_records
[
step
%
2
])
{
GraphPatternDetecter
::
subgraph_t
subgraph
;
for
(
auto
&
role
:
group
.
roles
)
{
subgraph
.
emplace
(
role
.
first
,
role
.
second
);
}
result
.
emplace_back
(
subgraph
);
}
return
result
;
}
void
GraphPatternDetecter
::
UniquePatterns
(
std
::
vector
<
GraphPatternDetecter
::
subgraph_t
>*
subgraphs
)
{
if
(
subgraphs
->
empty
())
return
;
std
::
vector
<
GraphPatternDetecter
::
subgraph_t
>
result
;
std
::
unordered_set
<
size_t
>
set
;
for
(
auto
&
g
:
*
subgraphs
)
{
size_t
key
=
0
;
for
(
auto
&
item
:
g
)
{
key
^=
std
::
hash
<
void
*>
{}(
item
.
first
);
key
^=
std
::
hash
<
void
*>
{}(
item
.
second
);
}
if
(
!
set
.
count
(
key
))
{
result
.
emplace_back
(
g
);
set
.
insert
(
key
);
}
}
*
subgraphs
=
result
;
}
void
GraphPatternDetecter
::
RemoveOverlappedMatch
(
std
::
vector
<
subgraph_t
>*
subgraphs
)
{
std
::
vector
<
subgraph_t
>
result
;
std
::
unordered_set
<
Node
*>
node_set
;
for
(
const
auto
&
subgraph
:
*
subgraphs
)
{
bool
valid
=
true
;
for
(
auto
&
item
:
subgraph
)
{
if
(
node_set
.
count
(
item
.
second
))
{
valid
=
false
;
break
;
}
}
if
(
valid
)
{
for
(
auto
&
item
:
subgraph
)
{
node_set
.
insert
(
item
.
second
);
}
result
.
push_back
(
subgraph
);
}
}
*
subgraphs
=
result
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph_pattern_detecter.h
0 → 100644
浏览文件 @
5d2834fc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_TESTING
#include <gtest/gtest_prod.h>
#endif
#include <numeric>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// Some basic torminolygies:
// - PDPattern: a pattern defined as a data flow graph.
// - PDNode: the node in the pattern, each PDNode represents an `ir::Node`
// that meets some conditions defined in `PDNode.teller`.
// - A pattern is defined with PDNodes with edges.
// Pattern detector node. This node helps to build a pattern.
struct
PDNode
{
// tell whether an ir::Node* is a candidation for a PDNode.
using
teller_t
=
std
::
function
<
bool
(
Node
*
)
>
;
PDNode
(
teller_t
&&
teller
,
const
std
::
string
&
name
=
""
)
:
teller_
(
teller
),
name_
(
name
)
{
PADDLE_ENFORCE
(
teller_
!=
nullptr
,
"invalid teller functer is set."
);
}
PDNode
(
PDNode
&&
other
)
=
default
;
std
::
vector
<
PDNode
*>
inlinks
;
std
::
vector
<
PDNode
*>
outlinks
;
bool
Tell
(
Node
*
node
)
const
{
PADDLE_ENFORCE
(
teller_
!=
nullptr
,
"teller should be set for a PDNode"
);
return
teller_
(
node
);
}
const
std
::
string
&
name
()
const
{
return
name_
;
}
PDNode
(
const
PDNode
&
)
=
delete
;
PDNode
&
operator
=
(
const
PDNode
&
)
=
delete
;
private:
teller_t
teller_
;
std
::
string
name_
;
};
/*
* A pattern in a graph, which defined with PDNode and edges. Most graph
* patterns can be divided into PDNodes and link relations between them.
*
* For example, the FC fusion need to filter the MUL and ELEMENTWISE_ADD
* operators from the computation graph, the MUL's output should have only one
* consumer which is the ELEMENTWISE_ADD.
* This pattern can be defined as with the following pseudo codes
*
* // Create two operator PDNodes.
* MUL = PDPattern.NewNode()
* ELE = PDPattern.NewNode()
* // Create the variable PDNodes.
* MUL_out = PDPattern.NewNode()
* // Add teller to define some rules that help to filter the target Nodes.
* MUL.teller = lambda(node): node->IsOp() && node->Op()->Type == "mul";
* ELE.teller = lambda(node): \
* node->IsOp() && node->Op()->Type == "elementwise_add";
* MUL_out.teller = lambda(node): node->IsVar() && (MUL in node->inputs)
* && (ELE in node->outputs)
*
* One can add more specific tellers for PDNodes or edges, both the Operator
* and Variable Nodes can be ruled in PDNode.teller.
*
* PDPattern can record the general patterns, such as the pattern represents
* - Op in CPU -> Op in GPU -> Op in CPU, to findout the IO abnormal place.
* - Ops whose inputs and outputs share the same variables
*/
class
PDPattern
{
public:
using
edge_t
=
std
::
pair
<
PDNode
*
,
PDNode
*>
;
void
AddEdge
(
PDNode
*
a
,
PDNode
*
b
);
PDNode
*
NewNode
(
PDNode
::
teller_t
&&
teller
,
const
std
::
string
&
name
=
""
);
const
std
::
vector
<
std
::
unique_ptr
<
PDNode
>>&
nodes
()
const
{
return
nodes_
;
}
const
std
::
vector
<
edge_t
>&
edges
()
const
{
return
edges_
;
}
private:
#ifdef PADDLE_WITH_TESTING
FRIEND_TEST
(
PDPattern
,
AddEdge
);
FRIEND_TEST
(
PDPattern
,
NewNode
);
#endif
std
::
vector
<
std
::
unique_ptr
<
PDNode
>>
nodes_
;
std
::
vector
<
edge_t
>
edges_
;
};
/*
* GraphPatternDetecter helps to detect the specific patterns in the graph.
* Input a pattern, output a list of the matched subgraphs/nodes.
* This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.).
*
* The algorithm has three phases:
* 1. Mark the nodes that match the defined PDNodes in a PDPattern,
* 2. Extend a PDNode to subgraphs by deducing the connection relation defined
* in PAPattern(the edges),
* 3. Get the filtered subgraphs and treat them with a pre-defined handler.
*
* Usage:
* // Create a detector
* GraphPatternDetecter detector;
* // Define the detector's pattern, by adding PDNode and define the edges.
* auto* node0 = detector.mutable_pattern().AddNode(...)
* auto* node1 = detector.mutable_pattern().AddNode(...)
* node0->teller = some lambda.
* node1->teller = some lambda.
* detector.mutable_pattern().AddEdge(node0, node1);
* // Create an handler, to define the behavior of treating the filtered
* // subgraphs that comply with the patterns.
* GraphPatternDetecter::handle_t handler = some labmda
* // Execute the detector.
* detector(&graph, handler);
*/
class
GraphPatternDetecter
{
public:
using
subgraph_t
=
std
::
unordered_map
<
PDNode
*
,
Node
*>
;
// Operate on the detected pattern.
using
handle_t
=
std
::
function
<
void
(
const
subgraph_t
&
/*hitted pattern*/
,
Graph
*
)
>
;
void
operator
()(
Graph
*
graph
,
handle_t
handler
);
const
PDPattern
&
pattern
()
const
{
return
pattern_
;
}
PDPattern
*
mutable_pattern
()
{
return
&
pattern_
;
}
private:
// Mark the nodes that fits the pattern.
bool
MarkPDNodesInGraph
(
const
ir
::
Graph
&
graph
);
// Detect all the pattern and output the hit records.
std
::
vector
<
subgraph_t
>
DetectPatterns
();
// Remove duplicate patterns.
void
UniquePatterns
(
std
::
vector
<
subgraph_t
>*
subgraphs
);
// Remove overlapped match subgraphs, when overlapped, keep the previous one.
void
RemoveOverlappedMatch
(
std
::
vector
<
subgraph_t
>*
subgraphs
);
#ifdef PADDLE_WITH_TESTING
FRIEND_TEST
(
GraphPatternDetecter
,
MarkPDNodesInGraph
);
FRIEND_TEST
(
GraphPatternDetecter
,
DetectPatterns
);
#endif
private:
using
hit_rcd_t
=
std
::
pair
<
Node
*
/*node in graph*/
,
PDNode
*
/*node in pattern*/
>
;
PDPattern
pattern_
;
std
::
vector
<
hit_rcd_t
>
marked_records_
;
std
::
unordered_map
<
const
PDNode
*
,
std
::
unordered_set
<
Node
*>>
pdnodes2nodes_
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc
0 → 100644
浏览文件 @
5d2834fc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
#include <gtest/gtest.h>
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
BuildGraph
(
Graph
*
g
)
{
ir
::
Node
*
o1
=
g
->
CreateEmptyNode
(
"op1"
,
Node
::
Type
::
kOperation
);
ir
::
Node
*
o2
=
g
->
CreateEmptyNode
(
"op2"
,
Node
::
Type
::
kOperation
);
ir
::
Node
*
o3
=
g
->
CreateEmptyNode
(
"op3"
,
Node
::
Type
::
kOperation
);
ir
::
Node
*
o4
=
g
->
CreateEmptyNode
(
"op4"
,
Node
::
Type
::
kOperation
);
ir
::
Node
*
o5
=
g
->
CreateEmptyNode
(
"op5"
,
Node
::
Type
::
kOperation
);
ir
::
Node
*
v1
=
g
->
CreateEmptyNode
(
"var1"
,
Node
::
Type
::
kVariable
);
ir
::
Node
*
v2
=
g
->
CreateEmptyNode
(
"var2"
,
Node
::
Type
::
kVariable
);
ir
::
Node
*
v3
=
g
->
CreateEmptyNode
(
"var3"
,
Node
::
Type
::
kVariable
);
ir
::
Node
*
v4
=
g
->
CreateEmptyNode
(
"var4"
,
Node
::
Type
::
kVariable
);
// o1->v1->o2
o1
->
outputs
.
push_back
(
v1
);
o2
->
inputs
.
push_back
(
v1
);
v1
->
inputs
.
push_back
(
o1
);
v1
->
outputs
.
push_back
(
o2
);
// o2->v2->o3
// o2->v2->o4
o2
->
outputs
.
push_back
(
v2
);
o3
->
inputs
.
push_back
(
v2
);
o4
->
inputs
.
push_back
(
v2
);
v2
->
inputs
.
push_back
(
o2
);
v2
->
outputs
.
push_back
(
o3
);
v2
->
outputs
.
push_back
(
o4
);
// o2->v3->o5
o2
->
outputs
.
push_back
(
v3
);
o5
->
inputs
.
push_back
(
v3
);
v3
->
inputs
.
push_back
(
o2
);
v3
->
outputs
.
push_back
(
o5
);
// o3-v4->o5
o3
->
outputs
.
push_back
(
v4
);
o5
->
inputs
.
push_back
(
v4
);
v4
->
inputs
.
push_back
(
o3
);
v4
->
outputs
.
push_back
(
o5
);
}
TEST
(
PDPattern
,
NewNode
)
{
PDPattern
x
;
auto
*
n
=
x
.
NewNode
([](
Node
*
x
)
{
return
true
;
});
ASSERT_TRUE
(
n
);
ASSERT_EQ
(
x
.
nodes_
.
size
(),
1UL
);
}
TEST
(
PDPattern
,
AddEdge
)
{
PDPattern
x
;
auto
*
a
=
x
.
NewNode
([](
Node
*
x
)
{
return
true
;
});
auto
*
b
=
x
.
NewNode
([](
Node
*
x
)
{
return
true
;
});
ASSERT_TRUE
(
a
);
ASSERT_TRUE
(
b
);
x
.
AddEdge
(
a
,
b
);
ASSERT_EQ
(
x
.
nodes_
.
size
(),
2UL
);
ASSERT_EQ
(
x
.
edges_
.
size
(),
1UL
);
ASSERT_EQ
(
x
.
edges_
.
front
().
first
,
a
);
ASSERT_EQ
(
x
.
edges_
.
front
().
second
,
b
);
ASSERT_EQ
(
x
.
nodes
().
size
(),
2UL
);
ASSERT_EQ
(
x
.
edges
().
size
(),
1UL
);
ASSERT_EQ
(
x
.
edges
().
front
().
first
,
a
);
ASSERT_EQ
(
x
.
edges
().
front
().
second
,
b
);
}
TEST
(
GraphPatternDetecter
,
MarkPDNodesInGraph
)
{
GraphPatternDetecter
x
;
// mark o2, o3, v2
// The pattern is a graph:
// o2(a node named o2) -> v2(a node named v2)
// v2 -> o3(a node named o3)
auto
*
o2
=
x
.
pattern_
.
NewNode
([](
Node
*
node
)
{
// The teller can be any condition, such as op type, or variable's shape.
return
node
&&
node
->
Name
()
==
"op2"
&&
node
->
IsOp
();
});
auto
*
o3
=
x
.
pattern_
.
NewNode
([](
Node
*
node
)
{
// The teller can be any condition, such as op type, or variable's shape.
return
node
&&
node
->
Name
()
==
"op3"
&&
node
->
IsOp
();
});
auto
*
v2
=
x
.
pattern_
.
NewNode
([](
Node
*
node
)
{
// The teller can be any condition, such as op type, or variable's shape.
return
node
&&
node
->
Name
()
==
"var2"
&&
node
->
IsVar
();
});
ASSERT_FALSE
(
o2
->
Tell
(
nullptr
));
ASSERT_FALSE
(
o3
->
Tell
(
nullptr
));
ASSERT_FALSE
(
v2
->
Tell
(
nullptr
));
x
.
pattern_
.
AddEdge
(
o2
,
v2
);
x
.
pattern_
.
AddEdge
(
v2
,
o3
);
ASSERT_EQ
(
x
.
pattern_
.
edges
().
size
(),
2UL
);
ASSERT_EQ
(
x
.
pattern_
.
edges
()[
0
].
first
,
o2
);
ASSERT_EQ
(
x
.
pattern_
.
edges
()[
0
].
second
,
v2
);
ASSERT_EQ
(
x
.
pattern_
.
edges
()[
1
].
first
,
v2
);
ASSERT_EQ
(
x
.
pattern_
.
edges
()[
1
].
second
,
o3
);
ProgramDesc
program
;
Graph
graph
(
program
);
BuildGraph
(
&
graph
);
x
.
MarkPDNodesInGraph
(
graph
);
ASSERT_EQ
(
x
.
pdnodes2nodes_
.
size
(),
3UL
);
auto
subgraphs
=
x
.
DetectPatterns
();
ASSERT_EQ
(
subgraphs
.
size
(),
1UL
);
}
TEST
(
GraphPatternDetecter
,
MultiSubgraph
)
{
ProgramDesc
program
;
Graph
graph
(
program
);
BuildGraph
(
&
graph
);
GraphPatternDetecter
x
;
// The pattern is a graph:
// op -> var
auto
*
any_op
=
x
.
mutable_pattern
()
->
NewNode
(
[](
Node
*
node
)
{
return
node
->
IsOp
()
&&
(
node
->
Name
()
==
"op2"
||
node
->
Name
()
==
"op3"
);
},
"OP0"
);
auto
*
any_var
=
x
.
mutable_pattern
()
->
NewNode
(
[](
Node
*
node
)
{
return
node
->
IsVar
();
},
"VAR"
);
auto
*
any_op1
=
x
.
mutable_pattern
()
->
NewNode
(
[](
Node
*
node
)
{
return
node
->
IsOp
();
},
"OP1"
);
x
.
mutable_pattern
()
->
AddEdge
(
any_op
,
any_var
);
x
.
mutable_pattern
()
->
AddEdge
(
any_var
,
any_op1
);
int
count
=
0
;
GraphPatternDetecter
::
handle_t
handle
=
[
&
](
const
GraphPatternDetecter
::
subgraph_t
&
s
,
Graph
*
g
)
{
LOG
(
INFO
)
<<
"Detect "
<<
s
.
at
(
any_op
)
->
Name
()
<<
" -> "
<<
s
.
at
(
any_var
)
->
Name
()
<<
" -> "
<<
s
.
at
(
any_op1
)
->
Name
();
count
++
;
};
x
(
&
graph
,
handle
);
// 1. Detect op3 -> var4 -> op5
// 2. Detect op2 -> var2 -> op3
// 3. Detect op2 -> var2 -> op4
// 4. Detect op2 -> var3 -> op5
// But 2 and 3 and 4 overlapped, so keep 2, so the final choices are 1 and 2
ASSERT_GE
(
count
,
1UL
);
ASSERT_LE
(
count
,
2UL
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph_traits.cc
0 → 100644
浏览文件 @
5d2834fc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
//
// NodesDFSIterator
//
NodesDFSIterator
::
NodesDFSIterator
(
const
std
::
vector
<
Node
*>
&
source
)
{
for
(
auto
*
x
:
source
)
stack_
.
push
(
x
);
}
NodesDFSIterator
::
NodesDFSIterator
(
NodesDFSIterator
&&
other
)
noexcept
:
stack_
(
std
::
move
(
other
.
stack_
)),
visited_
(
std
::
move
(
other
.
visited_
))
{}
NodesDFSIterator
::
NodesDFSIterator
(
const
NodesDFSIterator
&
other
)
:
stack_
(
other
.
stack_
),
visited_
(
other
.
visited_
)
{}
Node
&
NodesDFSIterator
::
operator
*
()
{
PADDLE_ENFORCE
(
!
stack_
.
empty
());
return
*
stack_
.
top
();
}
NodesDFSIterator
&
NodesDFSIterator
::
operator
++
()
{
PADDLE_ENFORCE
(
!
stack_
.
empty
(),
"the iterator exceeds range"
);
visited_
.
insert
(
stack_
.
top
());
auto
*
cur
=
stack_
.
top
();
stack_
.
pop
();
for
(
auto
*
x
:
cur
->
outputs
)
{
if
(
!
visited_
.
count
(
x
))
{
stack_
.
push
(
x
);
}
}
return
*
this
;
}
bool
NodesDFSIterator
::
operator
==
(
const
NodesDFSIterator
&
other
)
{
if
(
stack_
.
empty
())
return
other
.
stack_
.
empty
();
if
((
!
stack_
.
empty
())
&&
(
!
other
.
stack_
.
empty
()))
{
return
stack_
.
top
()
==
other
.
stack_
.
top
();
}
return
false
;
}
NodesDFSIterator
&
NodesDFSIterator
::
operator
=
(
const
NodesDFSIterator
&
other
)
{
stack_
=
other
.
stack_
;
visited_
=
other
.
visited_
;
return
*
this
;
}
Node
*
NodesDFSIterator
::
operator
->
()
{
return
stack_
.
top
();
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph_traits.h
0 → 100644
浏览文件 @
5d2834fc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stack>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
template
<
typename
IteratorT
>
class
iterator_range
{
IteratorT
begin_
,
end_
;
public:
template
<
typename
Container
>
explicit
iterator_range
(
Container
&&
c
)
:
begin_
(
c
.
begin
()),
end_
(
c
.
end
())
{}
iterator_range
(
const
IteratorT
&
begin
,
const
IteratorT
&
end
)
:
begin_
(
begin
),
end_
(
end
)
{}
const
IteratorT
&
begin
()
const
{
return
begin_
;
}
const
IteratorT
&
end
()
const
{
return
end_
;
}
};
// DFS iterator on nodes.
struct
NodesDFSIterator
:
public
std
::
iterator
<
std
::
forward_iterator_tag
,
Node
*>
{
NodesDFSIterator
()
=
default
;
explicit
NodesDFSIterator
(
const
std
::
vector
<
Node
*>
&
source
);
NodesDFSIterator
(
NodesDFSIterator
&&
other
)
noexcept
;
NodesDFSIterator
(
const
NodesDFSIterator
&
other
);
Node
&
operator
*
();
NodesDFSIterator
&
operator
++
();
// TODO(Superjomn) current implementation just compare the first
// element, need to compare the graph and all the elements in the queue and
// set.
NodesDFSIterator
&
operator
=
(
const
NodesDFSIterator
&
other
);
bool
operator
==
(
const
NodesDFSIterator
&
other
);
bool
operator
!=
(
const
NodesDFSIterator
&
other
)
{
return
!
(
*
this
==
other
);
}
Node
*
operator
->
();
private:
std
::
stack
<
Node
*>
stack_
;
std
::
unordered_set
<
Node
*>
visited_
;
};
/*
* GraphTraits contains some graph traversal algorithms.
*
* Usage:
*
*/
struct
GraphTraits
{
static
iterator_range
<
NodesDFSIterator
>
DFS
(
const
Graph
&
g
)
{
auto
start_points
=
ExtractStartPoints
(
g
);
NodesDFSIterator
x
(
start_points
);
return
iterator_range
<
NodesDFSIterator
>
(
NodesDFSIterator
(
start_points
),
NodesDFSIterator
());
}
private:
// The nodes those have no input will be treated as start points.
static
std
::
vector
<
Node
*>
ExtractStartPoints
(
const
Graph
&
g
)
{
std
::
vector
<
Node
*>
result
;
for
(
auto
*
node
:
g
.
Nodes
())
{
if
(
node
->
inputs
.
empty
())
{
result
.
push_back
(
node
);
}
}
return
result
;
}
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/node.h
浏览文件 @
5d2834fc
...
...
@@ -58,6 +58,9 @@ class Node {
return
op_desc_
;
}
bool
IsOp
()
const
{
return
type_
==
Type
::
kOperation
;
}
bool
IsVar
()
const
{
return
type_
==
Type
::
kVariable
;
}
std
::
vector
<
Node
*>
inputs
;
std
::
vector
<
Node
*>
outputs
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录