Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
167523e7
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
167523e7
编写于
7月 28, 2021
作者:
J
jiangcheng
提交者:
GitHub
7月 28, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
graph_to_program topology sort (#33949)
See
https://github.com/PaddlePaddle/Paddle/pull/33949
for details
上级
f1654de6
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
762 addition
and
33 deletion
+762
-33
paddle/fluid/framework/ir/graph.cc
paddle/fluid/framework/ir/graph.cc
+33
-7
paddle/fluid/framework/ir/graph.h
paddle/fluid/framework/ir/graph.h
+27
-5
paddle/fluid/framework/ir/graph_helper.cc
paddle/fluid/framework/ir/graph_helper.cc
+80
-0
paddle/fluid/framework/ir/graph_helper.h
paddle/fluid/framework/ir/graph_helper.h
+2
-0
paddle/fluid/framework/ir/graph_to_program_pass.cc
paddle/fluid/framework/ir/graph_to_program_pass.cc
+87
-15
paddle/fluid/framework/ir/graph_to_program_pass.h
paddle/fluid/framework/ir/graph_to_program_pass.h
+3
-0
paddle/fluid/framework/ir/graph_to_program_pass_test.cc
paddle/fluid/framework/ir/graph_to_program_pass_test.cc
+382
-0
paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc
...k/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc
+11
-0
paddle/fluid/framework/ir/node.cc
paddle/fluid/framework/ir/node.cc
+1
-1
paddle/fluid/framework/ir/node.h
paddle/fluid/framework/ir/node.h
+109
-5
paddle/fluid/framework/ir/node_test.cc
paddle/fluid/framework/ir/node_test.cc
+27
-0
未找到文件。
paddle/fluid/framework/ir/graph.cc
浏览文件 @
167523e7
...
...
@@ -56,10 +56,12 @@ Graph::Graph(const ProgramDesc &program, const int64_t start_op_index,
// sub_graph.
std
::
unique_ptr
<
Graph
>
first_sub_graph
=
std
::
make_unique
<
Graph
>
(
program_
.
Block
(
0
),
this
,
start_op_index
,
end_op_index
);
first_sub_graph
->
block_id_
=
0
;
sub_graphs_
.
push_back
(
std
::
move
(
first_sub_graph
));
for
(
size_t
idx
=
1
;
idx
<
program_
.
Size
();
++
idx
)
{
std
::
unique_ptr
<
Graph
>
sub_graph
=
std
::
make_unique
<
Graph
>
(
program_
.
Block
(
idx
),
this
);
sub_graph
->
block_id_
=
idx
;
sub_graphs_
.
push_back
(
std
::
move
(
sub_graph
));
}
}
else
{
...
...
@@ -90,14 +92,32 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
Graph
::
InitFromBlock
(
const
BlockDesc
&
block
,
const
int64_t
start_op_index
,
const
int64_t
end_op_index
)
{
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars
;
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
VarDesc
*
,
int
>>
name_to_desc_block_id
;
const
BlockDesc
*
block_var_visible
=
&
block
;
while
(
block_var_visible
!=
nullptr
)
{
for
(
auto
*
var
:
block_var_visible
->
AllVars
())
{
name_to_desc_block_id
.
emplace
(
var
->
Name
(),
std
::
make_pair
(
var
,
block_var_visible
->
ID
()));
}
const
BlockDesc
*
forward_block
=
block_var_visible
->
ForwardBlock
();
if
(
forward_block
!=
nullptr
)
{
for
(
auto
*
var
:
forward_block
->
AllVars
())
{
name_to_desc_block_id
.
emplace
(
var
->
Name
(),
std
::
make_pair
(
var
,
forward_block
->
ID
()));
}
}
block_var_visible
=
block_var_visible
->
ParentBlock
();
}
// var nodes for each var name, will have multiple versions in SSA
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
var_nodes
;
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
not_visited_vars
;
for
(
auto
*
var
:
block
.
AllVars
())
{
all
_vars
.
emplace
(
var
->
Name
(),
var
);
not_visited
_vars
.
emplace
(
var
->
Name
(),
var
);
}
auto
not_visited_vars
=
all_vars
;
int
desc_order
=
0
;
auto
all_ops
=
block
.
AllOps
();
PADDLE_ENFORCE_LE
(
end_op_index
,
all_ops
.
size
(),
...
...
@@ -109,6 +129,8 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
auto
*
op
=
all_ops
[
i
];
VLOG
(
3
)
<<
"create OpNode by "
<<
op
->
Type
();
ir
::
Node
*
node
=
CreateOpNode
(
op
);
node
->
SetDescOrder
(
desc_order
);
++
desc_order
;
// For input args, reuse the same var name if it was created before.
// Otherwise, create a new one.
for
(
auto
&
each_var_name
:
op
->
InputArgumentNames
())
{
...
...
@@ -116,8 +138,9 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
ir
::
Node
*
var
=
nullptr
;
if
(
var_nodes
.
find
(
each_var_name
)
!=
var_nodes
.
end
())
{
var
=
var_nodes
.
at
(
each_var_name
).
back
();
}
else
if
(
all_vars
.
count
(
each_var_name
)
!=
0
)
{
var
=
CreateVarNode
(
all_vars
.
at
(
each_var_name
));
}
else
if
(
name_to_desc_block_id
.
count
(
each_var_name
)
!=
0
)
{
auto
desc_and_block_id
=
name_to_desc_block_id
.
at
(
each_var_name
);
var
=
CreateVarNode
(
desc_and_block_id
.
first
,
desc_and_block_id
.
second
);
var_nodes
[
each_var_name
].
push_back
(
var
);
}
else
{
// Operation input var can be optional (dispensable). Which means
...
...
@@ -143,8 +166,9 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
}
ir
::
Node
*
var
=
nullptr
;
if
(
all_vars
.
count
(
each_var_name
)
!=
0
)
{
var
=
CreateVarNode
(
all_vars
.
at
(
each_var_name
));
if
(
name_to_desc_block_id
.
count
(
each_var_name
)
!=
0
)
{
auto
desc_and_block_id
=
name_to_desc_block_id
.
at
(
each_var_name
);
var
=
CreateVarNode
(
desc_and_block_id
.
first
,
desc_and_block_id
.
second
);
}
else
{
// Operation output vars can be @EMPTY@. For example, while_grad
// can have multi @EMPTY@ outputs with no VarDesc.
...
...
@@ -270,6 +294,7 @@ std::shared_ptr<Graph> Graph::Clone() {
auto
cloned_graph
=
std
::
make_shared
<
Graph
>
(
this
->
program_
);
cloned_graph
->
ReleaseNodes
();
cloned_graph
->
num_node_created_
=
0
;
cloned_graph
->
block_id_
=
this
->
block_id_
;
std
::
unordered_map
<
ir
::
Node
*
,
ir
::
Node
*>
origin_to_cloned
;
for
(
auto
*
n
:
this
->
node_set_
)
{
PADDLE_ENFORCE_NOT_NULL
(
n
,
platform
::
errors
::
InvalidArgument
(
...
...
@@ -313,6 +338,7 @@ std::unique_ptr<Graph> Graph::CloneSubGraph(const size_t idx) {
std
::
make_unique
<
Graph
>
(
this
->
program_
.
Block
(
idx
),
this
);
cloned_sub_graph
->
ReleaseNodes
();
cloned_sub_graph
->
num_node_created_
=
0
;
cloned_sub_graph
->
block_id_
=
idx
;
std
::
unordered_map
<
ir
::
Node
*
,
ir
::
Node
*>
origin_to_cloned
;
for
(
auto
*
n
:
this
->
sub_graphs_
.
at
(
idx
)
->
Nodes
())
{
PADDLE_ENFORCE_NOT_NULL
(
n
,
platform
::
errors
::
InvalidArgument
(
...
...
paddle/fluid/framework/ir/graph.h
浏览文件 @
167523e7
...
...
@@ -104,7 +104,14 @@ class Graph {
attr_dels_
.
clear
();
}
bool
IsConstructedByPartialProgram
()
const
{
return
is_partial_
;
}
bool
IsConstructedByPartialProgram
()
const
{
if
(
FLAGS_convert_all_blocks
)
{
if
(
IsMainGraph
())
{
return
GetSubGraph
(
0
)
->
IsConstructedByPartialProgram
();
}
}
return
is_partial_
;
}
bool
Has
(
const
std
::
string
&
attr_name
)
const
{
if
(
FLAGS_convert_all_blocks
)
{
...
...
@@ -210,7 +217,7 @@ class Graph {
}
// Create a normal variable with non-null VarDesc.
ir
::
Node
*
CreateVarNode
(
VarDesc
*
var_desc
)
{
ir
::
Node
*
CreateVarNode
(
VarDesc
*
var_desc
,
int
block_id
=
-
1
)
{
if
(
FLAGS_convert_all_blocks
)
{
if
(
IsMainGraph
())
{
return
GetSubGraph
(
0
)
->
CreateVarNode
(
var_desc
);
...
...
@@ -219,7 +226,8 @@ class Graph {
PADDLE_ENFORCE_NOT_NULL
(
var_desc
,
platform
::
errors
::
InvalidArgument
(
"The VarDesc used to create variable node is null."
));
auto
*
x
=
AddNode
(
new
ir
::
Node
(
var_desc
));
auto
*
x
=
AddNode
(
new
ir
::
Node
(
var_desc
,
block_id
==
-
1
?
block_id_
:
block_id
));
x
->
SetId
(
num_node_created_
++
);
return
x
;
}
...
...
@@ -252,7 +260,7 @@ class Graph {
const
std
::
string
name
=
string
::
Sprintf
(
"%s@%llu"
,
static_cast
<
const
char
*>
(
ir
::
Node
::
kControlDepVarName
),
num_node_created_
);
auto
*
x
=
AddNode
(
new
ir
::
Node
(
name
,
ir
::
Node
::
Type
::
kVariable
));
auto
*
x
=
AddNode
(
new
ir
::
Node
(
name
,
ir
::
Node
::
Type
::
kVariable
,
block_id_
));
x
->
SetId
(
num_node_created_
++
);
return
x
;
}
...
...
@@ -265,7 +273,7 @@ class Graph {
return
GetSubGraph
(
0
)
->
CreateEmptyNode
(
name
,
type
);
}
}
auto
*
x
=
AddNode
(
new
ir
::
Node
(
name
,
type
));
auto
*
x
=
AddNode
(
new
ir
::
Node
(
name
,
type
,
block_id_
));
x
->
SetId
(
num_node_created_
++
);
return
x
;
}
...
...
@@ -365,6 +373,15 @@ class Graph {
return
sub_graphs_
.
at
(
idx
).
get
();
}
int
GetBlockId
()
const
{
if
(
FLAGS_convert_all_blocks
)
{
if
(
IsMainGraph
())
{
return
GetSubGraph
(
0
)
->
block_id_
;
}
}
return
block_id_
;
}
size_t
SubGraphsSize
()
const
{
PADDLE_ENFORCE_EQ
(
this
->
IsMainGraph
(),
true
,
...
...
@@ -394,6 +411,9 @@ class Graph {
PADDLE_ENFORCE_EQ
(
this
->
IsMainGraph
(),
true
,
platform
::
errors
::
InvalidArgument
(
"This graph is not main_graph"
));
PADDLE_ENFORCE_EQ
(
sub_graphs_
.
size
(),
sub_graph
->
block_id_
,
platform
::
errors
::
InvalidArgument
(
"sub_graph idx is not equal to block_id_"
));
sub_graphs_
.
push_back
(
std
::
move
(
sub_graph
));
}
...
...
@@ -416,6 +436,8 @@ class Graph {
// parts: forward graph and backward graph, which can be executed
// independently.
bool
is_partial_
{
false
};
// The block this SubGraph belongs to.
int
block_id_
{
0
};
};
bool
IsControlDepVar
(
const
ir
::
Node
&
var
);
...
...
paddle/fluid/framework/ir/graph_helper.cc
浏览文件 @
167523e7
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph_helper.h"
#include <queue>
#include <stack>
DEFINE_string
(
print_sub_graph_dir
,
""
,
...
...
@@ -395,6 +396,85 @@ std::vector<Node *> TopologyVarientSort(const Graph &graph,
}
}
class
DescOrderComparator
{
public:
bool
operator
()(
const
Node
*
n1
,
const
Node
*
n2
)
{
return
(
n1
->
DescOrder
()
>
n2
->
DescOrder
())
||
((
n1
->
DescOrder
()
==
n2
->
DescOrder
())
&&
(
n1
->
ToString
()
>
n2
->
ToString
()));
}
};
std
::
vector
<
ir
::
Node
*>
TopologySortGraphByDescOrder
(
const
Graph
&
graph
)
{
std
::
vector
<
ir
::
Node
*>
sorted_ops
;
std
::
priority_queue
<
Node
*
,
std
::
vector
<
Node
*>
,
DescOrderComparator
>
q
;
std
::
unordered_map
<
Node
*
,
std
::
unordered_set
<
Node
*>>
in_ops
;
std
::
unordered_map
<
Node
*
,
std
::
unordered_set
<
Node
*>>
out_ops
;
// ensure all op node in 'in_ops' and 'out_ops'
for
(
const
auto
&
n
:
graph
.
Nodes
())
{
if
(
!
n
->
IsOp
())
continue
;
in_ops
.
emplace
(
n
,
std
::
unordered_set
<
Node
*>
());
out_ops
.
emplace
(
n
,
std
::
unordered_set
<
Node
*>
());
}
// record all op's input op and output op
for
(
const
auto
&
n
:
graph
.
Nodes
())
{
if
(
!
n
->
IsOp
())
continue
;
// traverse all input op
for
(
const
auto
&
var
:
n
->
inputs
)
{
for
(
const
auto
&
in
:
var
->
inputs
)
{
// use at instead of [] to prevent no unrecorded op node
in_ops
.
at
(
n
).
insert
(
in
);
out_ops
.
at
(
in
).
insert
(
n
);
}
}
}
// find topology entrance
for
(
const
auto
&
n
:
graph
.
Nodes
())
{
if
(
!
n
->
IsOp
())
continue
;
if
(
in_ops
.
at
(
n
).
empty
())
{
q
.
push
(
n
);
}
}
// topological sorting
while
(
!
q
.
empty
())
{
// Do not get by reference!!! The element will pop later.
const
auto
cur_op
=
q
.
top
();
q
.
pop
();
sorted_ops
.
push_back
(
cur_op
);
for
(
const
auto
&
out
:
out_ops
.
at
(
cur_op
))
{
PADDLE_ENFORCE_GT
(
in_ops
.
at
(
out
).
count
(
cur_op
),
0
,
platform
::
errors
::
InvalidArgument
(
"We find %s in %s's output list, "
"but cannot find %s in %s's input list. "
"Please ensure graph completely."
,
out
->
Name
().
c_str
(),
cur_op
->
Name
().
c_str
(),
cur_op
->
Name
().
c_str
(),
out
->
Name
().
c_str
()));
in_ops
.
at
(
out
).
erase
(
cur_op
);
// push if in-degree is 0
if
(
in_ops
.
at
(
out
).
empty
())
{
q
.
push
(
out
);
}
}
}
PADDLE_ENFORCE_EQ
(
sorted_ops
.
size
(),
in_ops
.
size
(),
platform
::
errors
::
InvalidArgument
(
"Topological sorting incompletely, "
"only sorted %zd op but total %zd."
,
sorted_ops
.
size
(),
in_ops
.
size
()));
return
sorted_ops
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph_helper.h
浏览文件 @
167523e7
...
...
@@ -87,6 +87,8 @@ std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
return
ret
;
}
std
::
vector
<
ir
::
Node
*>
TopologySortGraphByDescOrder
(
const
Graph
&
graph
);
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph_to_program_pass.cc
浏览文件 @
167523e7
...
...
@@ -14,7 +14,13 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <gflags/gflags.h>
#include <algorithm>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
DECLARE_bool
(
convert_all_blocks
);
namespace
paddle
{
namespace
framework
{
...
...
@@ -27,13 +33,10 @@ namespace framework {
namespace
ir
{
void
GraphToProgramPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
// Remove the unneeded variables after memory optimization.
std
::
unordered_set
<
std
::
string
>
vars2remove
;
if
(
graph
->
Has
(
kGraphToProgramVarsToRemove
))
{
vars2remove
=
graph
->
Get
<
std
::
unordered_set
<
std
::
string
>>
(
kGraphToProgramVarsToRemove
);
VLOG
(
2
)
<<
"graph to program remove "
<<
vars2remove
.
size
()
<<
" nodes"
;
}
PADDLE_ENFORCE_EQ
(
graph
->
IsMainGraph
(),
true
,
platform
::
errors
::
InvalidArgument
(
"This graph is a sub_graph, "
"and can't convert to program individually"
));
ProgramDesc
&
program
=
Get
<
ProgramDesc
>
(
"program"
);
...
...
@@ -42,12 +45,79 @@ void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const {
auto
block
=
program_pb
->
mutable_blocks
(
kRootBlockIndex
);
block
->
set_idx
(
kRootBlockIndex
);
if
(
FLAGS_convert_all_blocks
)
{
GraphToBlock
(
graph
->
GetSubGraph
(
kRootBlockIndex
),
block
);
VLOG
(
3
)
<<
"Graph to program need convert "
<<
graph
->
SubGraphsSize
()
<<
" sub graph"
;
for
(
size_t
idx
=
0
;
idx
<
graph
->
SubGraphsSize
();
++
idx
)
{
// avoid kRootBlockIndex not 0
if
(
idx
==
kRootBlockIndex
)
continue
;
block
=
program_pb
->
add_blocks
();
block
->
set_idx
(
idx
);
GraphToBlock
(
graph
->
GetSubGraph
(
idx
),
block
);
}
}
else
{
GraphToBlock
(
graph
,
block
);
}
program
.
CopyFrom
(
*
program_pb
);
}
OpDesc
*
ReplaceScaleLossGradOp
(
ir
::
Node
*
node
,
OpDesc
*
desc
)
{
desc
->
SetType
(
"fill_constant"
);
desc
->
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
(
static_cast
<
int
>
(
OpRole
::
kBackward
)
|
static_cast
<
int
>
(
OpRole
::
kLoss
)));
desc
->
SetAttr
(
"value"
,
1.0
f
);
std
::
vector
<
std
::
string
>
output_names
;
for
(
auto
out
:
node
->
outputs
)
{
output_names
.
emplace_back
(
out
->
Name
());
}
desc
->
SetOutput
(
"Out"
,
output_names
);
return
desc
;
}
std
::
vector
<
OpDesc
>*
GetGraphOpDesc
(
const
std
::
vector
<
ir
::
Node
*>&
nodes
,
std
::
vector
<
OpDesc
>*
ops
)
{
for
(
ir
::
Node
*
n
:
nodes
)
{
// if node is not Op, skip
if
(
!
n
->
IsOp
())
continue
;
// create fill_constant op
if
(
n
->
Name
()
==
"scale_loss_grad"
)
{
ops
->
emplace_back
();
auto
&
desc
=
ops
->
back
();
ReplaceScaleLossGradOp
(
n
,
&
desc
);
}
else
if
(
n
->
Op
())
{
ops
->
emplace_back
(
*
n
->
Op
());
}
else
{
// delete no OpDesc op
}
}
return
ops
;
}
void
GraphToProgramPass
::
GraphToBlock
(
const
Graph
*
graph
,
proto
::
BlockDesc
*
block
)
const
{
// Remove the unneeded variables after memory optimization.
std
::
unordered_set
<
std
::
string
>
vars2remove
;
if
(
graph
->
Has
(
kGraphToProgramVarsToRemove
))
{
vars2remove
=
graph
->
Get
<
std
::
unordered_set
<
std
::
string
>>
(
kGraphToProgramVarsToRemove
);
VLOG
(
2
)
<<
"graph (id: "
<<
block
->
idx
()
<<
") to program remove "
<<
vars2remove
.
size
()
<<
" nodes"
;
}
block
->
clear_vars
();
std
::
unordered_set
<
std
::
string
>
visited_vars
;
for
(
ir
::
Node
*
n
:
graph
->
Nodes
())
{
if
(
n
->
IsVar
())
{
if
(
n
->
Var
()
&&
visited_vars
.
count
(
n
->
Var
()
->
Name
())
==
0
&&
!
vars2remove
.
count
(
n
->
Var
()
->
Name
()))
{
!
vars2remove
.
count
(
n
->
Var
()
->
Name
())
&&
n
->
GetVarNodeBlockId
()
==
graph
->
GetBlockId
())
{
visited_vars
.
insert
(
n
->
Var
()
->
Name
());
block
->
add_vars
()
->
MergeFrom
(
*
n
->
Var
()
->
Proto
());
}
...
...
@@ -62,16 +132,18 @@ void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const {
nodes
=
TopologyVarientSort
(
*
graph
,
static_cast
<
framework
::
ir
::
SortKind
>
(
sort_kind
));
}
else
{
nodes
=
TopologySortOperations
(
*
graph
);
if
(
FLAGS_convert_all_blocks
)
{
nodes
=
TopologySortGraphByDescOrder
(
*
graph
);
}
else
{
nodes
=
TopologySortOperations
(
*
graph
);
}
}
for
(
ir
::
Node
*
n
:
nodes
)
{
if
(
!
n
->
Op
())
continue
;
block
->
add_ops
()
->
MergeFrom
(
*
n
->
Op
()
->
Proto
());
std
::
vector
<
OpDesc
>
ops
;
GetGraphOpDesc
(
nodes
,
&
ops
)
;
for
(
auto
&
op
:
ops
)
{
block
->
add_ops
()
->
MergeFrom
(
*
op
.
Proto
());
}
program
.
CopyFrom
(
*
program_pb
);
}
}
// namespace ir
...
...
paddle/fluid/framework/ir/graph_to_program_pass.h
浏览文件 @
167523e7
...
...
@@ -29,6 +29,9 @@ const char kGraphToProgramSortKind[] = "__graph_to_program_sort_kind__";
class
GraphToProgramPass
:
public
Pass
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
void
GraphToBlock
(
const
Graph
*
graph
,
proto
::
BlockDesc
*
block
)
const
;
};
}
// namespace ir
...
...
paddle/fluid/framework/ir/graph_to_program_pass_test.cc
浏览文件 @
167523e7
...
...
@@ -14,8 +14,14 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <algorithm>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -103,6 +109,382 @@ TEST(GraphToProgramPass, Basic) {
EXPECT_TRUE
(
vars
.
find
(
"var2"
)
!=
vars
.
end
());
EXPECT_TRUE
(
vars
.
find
(
"var3"
)
!=
vars
.
end
());
}
void
BuildProgramWithMultiBlock
(
ProgramDesc
*
program
)
{
auto
*
global_block
=
program
->
MutableBlock
(
0
);
auto
*
mul_1_x
=
global_block
->
Var
(
"Mul_1_X"
);
mul_1_x
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
mul_1_x
->
SetLoDLevel
(
0
);
mul_1_x
->
SetDataType
(
proto
::
VarType
::
FP32
);
mul_1_x
->
SetShape
({
1000
,
784
});
auto
*
mul_1_y
=
global_block
->
Var
(
"Mul_1_Y"
);
mul_1_y
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
mul_1_y
->
SetLoDLevel
(
0
);
mul_1_y
->
SetDataType
(
proto
::
VarType
::
FP32
);
mul_1_y
->
SetShape
({
784
,
100
});
auto
*
mul_1_out
=
global_block
->
Var
(
"Mul_1_Out"
);
mul_1_out
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
auto
*
mul_op_1
=
global_block
->
AppendOp
();
mul_op_1
->
SetType
(
"mul"
);
mul_op_1
->
SetInput
(
"X"
,
{
mul_1_x
->
Name
()});
mul_op_1
->
SetInput
(
"Y"
,
{
mul_1_y
->
Name
()});
mul_op_1
->
SetOutput
(
"Y"
,
{
mul_1_out
->
Name
()});
// building cond op such as less_than
auto
*
less_than_op_1
=
global_block
->
AppendOp
();
less_than_op_1
->
SetType
(
"less_than"
);
auto
*
less_than_1_x
=
global_block
->
Var
(
"Less_than_1_X"
);
less_than_1_x
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
less_than_1_x
->
SetLoDLevel
(
0
);
less_than_1_x
->
SetDataType
(
proto
::
VarType
::
FP32
);
less_than_1_x
->
SetShape
({
1
});
auto
*
less_than_1_y
=
global_block
->
Var
(
"Less_than_1_Y"
);
less_than_1_y
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
less_than_1_y
->
SetLoDLevel
(
0
);
less_than_1_y
->
SetDataType
(
proto
::
VarType
::
FP32
);
less_than_1_y
->
SetShape
({
1
});
auto
*
less_than_1_out
=
global_block
->
Var
(
"Less_than_1_Out"
);
less_than_1_out
->
SetType
(
proto
::
VarType
::
BOOL
);
less_than_op_1
->
SetInput
(
"X"
,
{
less_than_1_x
->
Name
()});
less_than_op_1
->
SetInput
(
"Y"
,
{
less_than_1_y
->
Name
()});
less_than_op_1
->
SetOutput
(
"Out"
,
{
less_than_1_out
->
Name
()});
BlockDesc
*
sub_block
=
program
->
AppendBlock
(
*
global_block
);
std
::
vector
<
BlockDesc
*>
sub_blocks
;
sub_blocks
.
push_back
(
sub_block
);
BlockDesc
*
sub_block2
=
program
->
AppendBlock
(
*
sub_block
);
// for testing nested case.
sub_blocks
.
push_back
(
sub_block2
);
// building while op in sub_block
auto
*
while_op
=
global_block
->
AppendOp
();
while_op
->
SetType
(
"while"
);
while_op
->
SetAttr
(
"sub_block"
,
sub_blocks
[
0
]);
auto
*
while_x
=
global_block
->
Var
(
"While_X"
);
while_x
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
while_x
->
SetLoDLevel
(
0
);
while_x
->
SetDataType
(
proto
::
VarType
::
FP32
);
while_x
->
SetShape
({
1
});
while_op
->
SetInput
(
"kX"
,
{
while_x
->
Name
()});
while_op
->
SetInput
(
"kCondition"
,
{
less_than_1_out
->
Name
()});
auto
*
while_out
=
global_block
->
Var
(
"While_Out"
);
while_out
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
while_out
->
SetLoDLevel
(
0
);
while_out
->
SetDataType
(
proto
::
VarType
::
FP32
);
while_out
->
SetShape
({
1
});
auto
*
steps
=
global_block
->
Var
(
"StepScopes"
);
while_op
->
SetOutput
(
"kOutputs"
,
{
while_out
->
Name
()});
while_op
->
SetOutput
(
"kStepScopes"
,
{
steps
->
Name
()});
auto
*
mul_2_x
=
global_block
->
Var
(
"Mul_2_X"
);
mul_2_x
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
mul_2_x
->
SetLoDLevel
(
0
);
mul_2_x
->
SetDataType
(
proto
::
VarType
::
FP32
);
mul_2_x
->
SetShape
({
1000
,
784
});
auto
*
mul_2_y
=
global_block
->
Var
(
"Mul_2_Y"
);
mul_2_y
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
mul_2_y
->
SetLoDLevel
(
0
);
mul_2_y
->
SetDataType
(
proto
::
VarType
::
FP32
);
mul_2_y
->
SetShape
({
784
,
100
});
auto
*
mul_op_2
=
sub_blocks
[
0
]
->
AppendOp
();
mul_op_2
->
SetType
(
"mul"
);
mul_op_2
->
SetInput
(
"X"
,
{
mul_2_x
->
Name
()});
mul_op_2
->
SetInput
(
"Y"
,
{
mul_2_y
->
Name
()});
auto
*
mul_2_out
=
global_block
->
Var
(
"Mul_2_Out"
);
mul_2_out
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
mul_op_2
->
SetOutput
(
"Y"
,
{
mul_2_out
->
Name
()});
auto
*
less_than_op_2
=
sub_blocks
[
0
]
->
AppendOp
();
less_than_op_2
->
SetType
(
"less_than"
);
auto
*
less_than_2_x
=
global_block
->
Var
(
"Less_than_2_X"
);
less_than_2_x
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
less_than_2_x
->
SetLoDLevel
(
0
);
less_than_2_x
->
SetDataType
(
proto
::
VarType
::
FP32
);
less_than_2_x
->
SetShape
({
1
});
auto
*
less_than_2_y
=
global_block
->
Var
(
"Less_than_2_Y"
);
less_than_2_y
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
less_than_2_y
->
SetLoDLevel
(
0
);
less_than_2_y
->
SetDataType
(
proto
::
VarType
::
FP32
);
less_than_2_y
->
SetShape
({
1
});
less_than_op_2
->
SetInput
(
"X"
,
{
less_than_2_x
->
Name
()});
less_than_op_2
->
SetInput
(
"Y"
,
{
less_than_2_y
->
Name
()});
auto
*
less_than_2_out
=
global_block
->
Var
(
"Less_than_2_Out"
);
less_than_2_out
->
SetType
(
proto
::
VarType
::
BOOL
);
less_than_op_2
->
SetOutput
(
"Out"
,
{
less_than_2_out
->
Name
()});
auto
*
cond_op
=
sub_blocks
[
0
]
->
AppendOp
();
cond_op
->
SetType
(
"conditional_block"
);
cond_op
->
SetAttr
(
"sub_block"
,
sub_blocks
[
1
]);
auto
*
cond_x
=
sub_blocks
[
0
]
->
Var
(
"Cond_X"
);
cond_x
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
cond_x
->
SetLoDLevel
(
0
);
cond_x
->
SetDataType
(
proto
::
VarType
::
FP32
);
cond_x
->
SetShape
({
1
});
cond_op
->
SetInput
(
"kInputs"
,
{
cond_x
->
Name
()});
cond_op
->
SetInput
(
"kCondition"
,
{
less_than_2_out
->
Name
()});
auto
*
cond_out
=
sub_blocks
[
0
]
->
Var
(
"Cond_Out"
);
cond_out
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
cond_out
->
SetLoDLevel
(
0
);
cond_out
->
SetDataType
(
proto
::
VarType
::
FP32
);
cond_out
->
SetShape
({
1
});
auto
*
scope
=
sub_blocks
[
0
]
->
Var
(
"Scope"
);
scope
->
SetType
(
proto
::
VarType
::
STEP_SCOPES
);
cond_op
->
SetOutput
(
"kOutputs"
,
{
cond_out
->
Name
()});
cond_op
->
SetOutput
(
"kScope"
,
{
scope
->
Name
()});
auto
*
mul_3_x
=
global_block
->
Var
(
"Mul_3_X"
);
mul_3_x
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
mul_3_x
->
SetLoDLevel
(
0
);
mul_3_x
->
SetDataType
(
proto
::
VarType
::
FP32
);
mul_3_x
->
SetShape
({
1000
,
784
});
auto
*
mul_3_y
=
global_block
->
Var
(
"Mul_3_Y"
);
mul_3_y
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
mul_3_y
->
SetLoDLevel
(
0
);
mul_3_y
->
SetDataType
(
proto
::
VarType
::
FP32
);
mul_3_y
->
SetShape
({
784
,
100
});
auto
*
mul_3_out
=
global_block
->
Var
(
"Mul_3_Out"
);
mul_3_out
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
auto
*
mul_op_3
=
sub_blocks
[
1
]
->
AppendOp
();
mul_op_3
->
SetType
(
"mul"
);
mul_op_3
->
SetInput
(
"X"
,
{
mul_3_x
->
Name
()});
mul_op_3
->
SetInput
(
"Y"
,
{
mul_3_y
->
Name
()});
mul_op_3
->
SetOutput
(
"Y"
,
{
mul_3_out
->
Name
()});
}
bool
VarComparator
(
const
VarDesc
*
a
,
const
VarDesc
*
b
)
{
return
a
->
Name
()
<
b
->
Name
();
}
void
CheckBlockVarsEqual
(
const
BlockDesc
&
before_block
,
const
BlockDesc
&
after_block
)
{
auto
before_vars
=
before_block
.
AllVars
();
auto
after_vars
=
after_block
.
AllVars
();
EXPECT_EQ
(
before_vars
.
size
(),
after_vars
.
size
());
// var's order is unimportant
std
::
sort
(
before_vars
.
begin
(),
before_vars
.
end
(),
VarComparator
);
std
::
sort
(
after_vars
.
begin
(),
after_vars
.
end
(),
VarComparator
);
for
(
size_t
var_idx
=
0
;
var_idx
<
before_vars
.
size
();
++
var_idx
)
{
const
auto
&
before_var
=
before_vars
.
at
(
var_idx
);
const
auto
&
after_var
=
after_vars
.
at
(
var_idx
);
EXPECT_EQ
(
before_var
->
Name
(),
after_var
->
Name
());
EXPECT_EQ
(
before_var
->
GetType
(),
after_var
->
GetType
());
}
}
void
CheckOpInputsEqual
(
const
OpDesc
*
before_op
,
const
OpDesc
*
after_op
)
{
const
auto
&
before_inputs
=
before_op
->
InputNames
();
const
auto
&
after_inputs
=
after_op
->
InputNames
();
EXPECT_EQ
(
before_inputs
.
size
(),
after_inputs
.
size
());
for
(
size_t
in_idx
=
0
;
in_idx
<
before_inputs
.
size
();
++
in_idx
)
{
const
auto
&
before_in_arg
=
before_inputs
[
in_idx
];
const
auto
&
after_in_arg
=
after_inputs
[
in_idx
];
EXPECT_EQ
(
before_in_arg
,
after_in_arg
);
const
auto
&
before_in_vars
=
before_op
->
Input
(
before_in_arg
);
const
auto
&
after_in_vars
=
after_op
->
Input
(
after_in_arg
);
EXPECT_EQ
(
before_in_vars
,
after_in_vars
);
}
}
void
CheckOpOutputsEqual
(
const
OpDesc
*
before_op
,
const
OpDesc
*
after_op
)
{
const
auto
&
before_outputs
=
before_op
->
OutputNames
();
const
auto
&
after_outputs
=
after_op
->
OutputNames
();
EXPECT_EQ
(
before_outputs
.
size
(),
after_outputs
.
size
());
for
(
size_t
out_idx
=
0
;
out_idx
<
before_outputs
.
size
();
++
out_idx
)
{
const
auto
&
before_out_arg
=
before_outputs
[
out_idx
];
const
auto
&
after_out_arg
=
after_outputs
[
out_idx
];
EXPECT_EQ
(
before_out_arg
,
after_out_arg
);
const
auto
&
before_out_vars
=
before_op
->
Output
(
before_out_arg
);
const
auto
&
after_out_vars
=
after_op
->
Output
(
after_out_arg
);
EXPECT_EQ
(
before_out_vars
,
after_out_vars
);
}
}
void
CheckOpAttrsEqual
(
const
OpDesc
*
before_op
,
const
OpDesc
*
after_op
)
{
const
auto
&
before_attrs
=
before_op
->
AttrNames
();
const
auto
&
after_attrs
=
after_op
->
AttrNames
();
EXPECT_EQ
(
before_attrs
.
size
(),
after_attrs
.
size
());
for
(
size_t
attr_idx
=
0
;
attr_idx
<
before_attrs
.
size
();
++
attr_idx
)
{
const
auto
&
before_attr
=
before_attrs
[
attr_idx
];
const
auto
&
after_attr
=
after_attrs
[
attr_idx
];
EXPECT_EQ
(
before_attr
,
after_attr
);
EXPECT_EQ
(
before_op
->
GetAttrType
(
before_attr
),
after_op
->
GetAttrType
(
after_attr
));
}
}
void
CheckBlockOpsEqual
(
const
BlockDesc
&
before_block
,
const
BlockDesc
&
after_block
)
{
EXPECT_EQ
(
before_block
.
OpSize
(),
after_block
.
OpSize
());
// op's order must be the same
for
(
size_t
op_idx
=
0
;
op_idx
<
before_block
.
OpSize
();
++
op_idx
)
{
const
auto
&
before_op
=
before_block
.
Op
(
op_idx
);
const
auto
&
after_op
=
after_block
.
Op
(
op_idx
);
EXPECT_EQ
(
before_op
->
Type
(),
after_op
->
Type
());
// Step4.2.1 : check each op's input
CheckOpInputsEqual
(
before_op
,
after_op
);
// Step4.2.2 : check each op's output
CheckOpOutputsEqual
(
before_op
,
after_op
);
// Step4.2.3 : check each op's attribute
CheckOpAttrsEqual
(
before_op
,
after_op
);
}
}
TEST
(
GraphToProgramPass
,
MultiBlock
)
{
// Set FLAGS_convert_all_blocks to true to make sure this test works.
bool
flag_temp
=
FLAGS_convert_all_blocks
;
FLAGS_convert_all_blocks
=
true
;
// Step1: Build a program with multi block
ProgramDesc
before_prog
;
BuildProgramWithMultiBlock
(
&
before_prog
);
// Step2: Convert program into graph
std
::
unique_ptr
<
Graph
>
g
(
new
ir
::
Graph
(
before_prog
));
// Step3 : Convert graph back to program
auto
pass
=
paddle
::
framework
::
ir
::
PassRegistry
::
Instance
().
Get
(
"graph_to_program_pass"
);
ProgramDesc
after_prog
;
pass
->
SetNotOwned
<
paddle
::
framework
::
ProgramDesc
>
(
"program"
,
&
after_prog
);
pass
->
Apply
(
g
.
get
());
// Step4 : Check tow program equal
EXPECT_EQ
(
before_prog
.
Size
(),
after_prog
.
Size
());
for
(
size_t
block_idx
=
0
;
block_idx
<
before_prog
.
Size
();
++
block_idx
)
{
const
auto
&
before_block
=
before_prog
.
Block
(
block_idx
);
const
auto
&
after_block
=
after_prog
.
Block
(
block_idx
);
EXPECT_EQ
(
before_block
.
ID
(),
after_block
.
ID
());
// Step4.1 : check each block's var
CheckBlockVarsEqual
(
before_block
,
after_block
);
// Step4.2 : check each block's op
CheckBlockOpsEqual
(
before_block
,
after_block
);
}
// Recover FLAGS_convert_all_blocks.
FLAGS_convert_all_blocks
=
flag_temp
;
}
void
BuildProgramWithScaleLossGrad
(
Graph
*
g
)
{
OpDesc
op1
;
op1
.
SetType
(
"op1"
);
OpDesc
op2
;
op2
.
SetType
(
"op2"
);
OpDesc
op3
;
op3
.
SetType
(
"op3"
);
OpDesc
op4
;
op4
.
SetType
(
"op4"
);
VarDesc
var1
(
"var1"
);
VarDesc
var2
(
"var2"
);
ir
::
Node
*
o1
=
g
->
CreateOpNode
(
&
op1
);
ir
::
Node
*
o2
=
g
->
CreateOpNode
(
&
op2
);
ir
::
Node
*
o3
=
g
->
CreateEmptyNode
(
"scale_loss_grad"
,
ir
::
Node
::
Type
::
kOperation
);
ir
::
Node
*
o4
=
g
->
CreateEmptyNode
(
"scale_loss_grad"
,
ir
::
Node
::
Type
::
kOperation
);
ir
::
Node
*
v1
=
g
->
CreateVarNode
(
&
var1
);
ir
::
Node
*
v2
=
g
->
CreateVarNode
(
&
var2
);
// o1->v1->o2
o1
->
outputs
.
push_back
(
v1
);
o2
->
inputs
.
push_back
(
v1
);
v1
->
inputs
.
push_back
(
o1
);
v1
->
outputs
.
push_back
(
o2
);
// o3->v1
o3
->
outputs
.
push_back
(
v1
);
v1
->
inputs
.
push_back
(
o1
);
v1
->
inputs
.
push_back
(
o3
);
// o4->v2
o4
->
outputs
.
push_back
(
v2
);
v2
->
inputs
.
push_back
(
o4
);
}
TEST
(
GraphToProgramPass
,
ReplaceScaleLossGrad
)
{
// Step1: Build a program with multi block
ProgramDesc
before_prog
;
Graph
before_graph
(
before_prog
);
BuildProgramWithScaleLossGrad
(
&
before_graph
);
// Step2 : Convert graph back to program
auto
pass
=
paddle
::
framework
::
ir
::
PassRegistry
::
Instance
().
Get
(
"graph_to_program_pass"
);
ProgramDesc
after_prog
;
pass
->
SetNotOwned
<
paddle
::
framework
::
ProgramDesc
>
(
"program"
,
&
after_prog
);
pass
->
Apply
(
&
before_graph
);
// Step3 : statistics scale_loss_grad and fill_constant number
int
scale_node_num
=
0
,
fill_node_num
=
0
;
const
auto
&
before_nodes_set
=
before_graph
.
Nodes
();
for
(
const
auto
&
n
:
before_nodes_set
)
{
if
(
n
->
Name
()
==
"scale_loss_grad"
)
{
++
scale_node_num
;
}
else
if
(
n
->
Name
()
==
"fill_constant"
)
{
++
fill_node_num
;
}
}
int
scale_op_num
=
0
,
fill_op_num
=
0
;
const
auto
&
block
=
after_prog
.
Block
(
0
);
for
(
const
auto
&
op
:
block
.
AllOps
())
{
if
(
op
->
Type
()
==
"fill_constant"
)
{
++
fill_op_num
;
}
else
if
(
op
->
Type
()
==
"scale_loss_grad"
)
{
++
scale_op_num
;
}
}
// Check pass OK
EXPECT_EQ
(
scale_op_num
,
0
);
EXPECT_EQ
(
scale_node_num
+
fill_node_num
,
fill_op_num
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
...
...
paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc
浏览文件 @
167523e7
...
...
@@ -26,6 +26,13 @@ using OpVariant = operators::OpVariant;
class
WhileOpEagerDeletionPass
:
public
ir
::
Pass
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
{
if
(
!
graph
->
IsMainGraph
())
{
// TODO(zhhsplendid): the WhileOpEagerDeletionPass is based on old Graph,
// which only applies to the main block graph. The new Eager Deletion
// Technical can be added after we write new while_op based on SubGraph
// instead of SubBlock
return
;
}
auto
all_ops
=
ir
::
FilterByNodeWrapper
<
details
::
OpHandleBase
>
(
*
graph
);
// Find all while_op and while_grad_op. In case of @to_static, graph
...
...
@@ -47,6 +54,7 @@ class WhileOpEagerDeletionPass : public ir::Pass {
}
}
if
(
graph
->
IsConstructedByPartialProgram
())
{
VLOG
(
4
)
<<
"Is Paritial Program"
;
PADDLE_ENFORCE_LE
(
target_ops
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
...
...
@@ -69,8 +77,11 @@ class WhileOpEagerDeletionPass : public ir::Pass {
}
for
(
auto
&
ops_pair
:
target_ops
)
{
VLOG
(
4
)
<<
"Scope Idx = "
<<
ops_pair
.
first
;
auto
&
while_ops
=
ops_pair
.
second
.
first
;
VLOG
(
4
)
<<
"while_ops.size() = "
<<
while_ops
.
size
();
auto
&
while_grad_ops
=
ops_pair
.
second
.
second
;
VLOG
(
4
)
<<
"while_grad_ops.size() = "
<<
while_grad_ops
.
size
();
operators
::
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
graph
->
OriginProgram
(),
while_ops
,
while_grad_ops
);
}
...
...
paddle/fluid/framework/ir/node.cc
浏览文件 @
167523e7
...
...
@@ -30,7 +30,7 @@ std::unique_ptr<Node> CreateNodeForTest(const std::string &name,
}
std
::
unique_ptr
<
Node
>
CreateNodeForTest
(
VarDesc
*
var_desc
)
{
return
std
::
unique_ptr
<
Node
>
(
new
Node
(
var_desc
));
return
std
::
unique_ptr
<
Node
>
(
new
Node
(
var_desc
,
0
));
}
std
::
unique_ptr
<
Node
>
CreateNodeForTest
(
OpDesc
*
op_desc
)
{
...
...
paddle/fluid/framework/ir/node.h
浏览文件 @
167523e7
...
...
@@ -136,9 +136,98 @@ class Node {
var_desc_
->
SetName
(
new_name
);
}
int
DescOrder
()
const
{
return
desc_order_
;
}
int
GetVarNodeBlockId
()
const
{
PADDLE_ENFORCE_EQ
(
type_
==
Type
::
kVariable
&&
var_desc_
,
true
,
platform
::
errors
::
InvalidArgument
(
"Node must be type of variable."
));
return
block_id_
;
}
const
std
::
string
ToString
()
const
{
if
(
IsOp
())
{
std
::
string
op_str
(
Name
());
const
auto
&
op
=
Op
();
if
(
op
==
nullptr
)
{
// Node is an Op but hasn't OpDesc (often create by CreateEmptyNode),
// like ScaleLossGradOp, it's type is OpHandle, which created by Pass
// and then inserted into graph.
// For OpHandle, we have to use Node's input and output for sorting.
std
::
vector
<
Node
*>
sorted_inputs
(
inputs
);
std
::
vector
<
Node
*>
sorted_outputs
(
outputs
);
auto
comparator
=
[](
Node
*
a
,
Node
*
b
)
{
return
a
->
Name
()
>
b
->
Name
();
};
std
::
stable_sort
(
sorted_inputs
.
begin
(),
sorted_inputs
.
end
(),
comparator
);
std
::
stable_sort
(
sorted_outputs
.
begin
(),
sorted_outputs
.
end
(),
comparator
);
std
::
string
out_str
=
"{"
;
std
::
string
pre_str
=
""
;
for
(
const
auto
&
output
:
sorted_outputs
)
{
out_str
.
append
(
pre_str
+
output
->
Name
());
pre_str
=
", "
;
}
out_str
.
append
(
"} = "
);
std
::
string
in_str
=
"("
;
pre_str
=
""
;
for
(
const
auto
&
input
:
sorted_inputs
)
{
in_str
.
append
(
pre_str
+
input
->
Name
());
pre_str
=
", "
;
}
in_str
.
append
(
")"
);
op_str
=
out_str
+
op_str
+
in_str
;
}
else
{
// A normal Op, has OpDesc, create from ProgramDesc
std
::
string
out_str
=
"{"
;
std
::
string
outer_pre_str
=
""
;
for
(
const
auto
&
output
:
op
->
OutputNames
())
{
out_str
.
append
(
outer_pre_str
+
output
+
"=["
);
std
::
string
inner_pre_str
=
""
;
for
(
const
auto
&
arg
:
op
->
Output
(
output
))
{
out_str
.
append
(
inner_pre_str
+
arg
);
inner_pre_str
=
" ,"
;
}
outer_pre_str
=
", "
;
out_str
.
append
(
"]"
);
}
out_str
.
append
(
"} = "
);
std
::
string
in_str
=
"("
;
outer_pre_str
=
""
;
for
(
const
auto
&
input
:
op
->
InputNames
())
{
in_str
.
append
(
outer_pre_str
+
input
+
"=["
);
std
::
string
inner_pre_str
=
""
;
for
(
const
auto
&
arg
:
op
->
Input
(
input
))
{
in_str
.
append
(
inner_pre_str
+
arg
);
inner_pre_str
=
" ,"
;
}
outer_pre_str
=
" ,"
;
in_str
.
append
(
"]"
);
}
in_str
.
append
(
")"
);
op_str
=
out_str
+
op_str
+
in_str
;
}
return
op_str
;
}
return
Name
();
}
std
::
vector
<
Node
*>
inputs
;
std
::
vector
<
Node
*>
outputs
;
// Because NO_DESC_ORDER is a constexpr number,
// no one can change it, meanwhile, we need
// check whether the DescOrder invalid sometime,
// so expose it is a good idea
static
constexpr
int
NO_DESC_ORDER
=
INT_MAX
;
protected:
std
::
string
name_
;
std
::
unique_ptr
<
VarDesc
>
var_desc_
;
...
...
@@ -146,30 +235,45 @@ class Node {
Type
type_
;
int
id_
;
int
desc_order_
;
int
block_id_
{
-
1
};
private:
// ID can only set by a Graph.
void
SetId
(
int
id
)
{
id_
=
id
;
}
// desc_order can only set by a Graph when constructing a Graph from a
// BlockDesc.
void
SetDescOrder
(
int
desc_order
)
{
desc_order_
=
desc_order
;
}
friend
class
Graph
;
friend
std
::
unique_ptr
<
Node
>
CreateNodeForTest
(
const
std
::
string
&
name
,
Node
::
Type
type
);
friend
std
::
unique_ptr
<
Node
>
CreateNodeForTest
(
VarDesc
*
var_desc
);
friend
std
::
unique_ptr
<
Node
>
CreateNodeForTest
(
OpDesc
*
op_desc
);
explicit
Node
(
const
std
::
string
&
name
,
Type
type
)
:
name_
(
name
),
var_desc_
(
nullptr
),
op_desc_
(
nullptr
),
type_
(
type
)
{}
explicit
Node
(
const
std
::
string
&
name
,
Type
type
,
int
block_id
=
0
)
:
name_
(
name
),
var_desc_
(
nullptr
),
op_desc_
(
nullptr
),
type_
(
type
),
desc_order_
(
NO_DESC_ORDER
),
block_id_
(
block_id
)
{}
explicit
Node
(
VarDesc
*
var_desc
)
explicit
Node
(
VarDesc
*
var_desc
,
int
block_id
)
:
name_
(
var_desc
->
Name
()),
var_desc_
(
new
VarDesc
(
*
var_desc
)),
op_desc_
(
nullptr
),
type_
(
Type
::
kVariable
)
{}
type_
(
Type
::
kVariable
),
desc_order_
(
NO_DESC_ORDER
),
block_id_
(
block_id
)
{}
explicit
Node
(
OpDesc
*
op_desc
)
:
name_
(
op_desc
->
Type
()),
var_desc_
(
nullptr
),
op_desc_
(
new
OpDesc
(
*
op_desc
,
op_desc
->
Block
())),
type_
(
Type
::
kOperation
)
{}
type_
(
Type
::
kOperation
),
desc_order_
(
NO_DESC_ORDER
)
{}
Node
()
=
delete
;
...
...
paddle/fluid/framework/ir/node_test.cc
浏览文件 @
167523e7
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/node.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/var_desc.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -75,6 +76,32 @@ TEST(NodeTest, Basic) {
EXPECT_FALSE
(
alive2
);
}
TEST
(
NodeTest
,
ToString
)
{
VarDesc
var_desc
(
"n2"
);
OpDesc
op_desc
;
op_desc
.
SetType
(
"test_op"
);
op_desc
.
SetInput
(
"X"
,
{
"x1"
,
"x2"
,
"x3"
});
op_desc
.
SetOutput
(
"Y"
,
{
"y1"
,
"y2"
});
std
::
unique_ptr
<
Node
>
n1
(
CreateNodeForTest
(
"n1"
,
Node
::
Type
::
kVariable
));
std
::
unique_ptr
<
Node
>
n2
(
CreateNodeForTest
(
&
var_desc
));
std
::
unique_ptr
<
Node
>
n3
(
CreateNodeForTest
(
"n3"
,
Node
::
Type
::
kOperation
));
std
::
unique_ptr
<
Node
>
n4
(
CreateNodeForTest
(
&
op_desc
));
EXPECT_EQ
(
n1
->
ToString
(),
"n1"
);
EXPECT_EQ
(
n2
->
ToString
(),
"n2"
);
EXPECT_EQ
(
n3
->
Op
(),
nullptr
);
EXPECT_EQ
(
n3
->
ToString
(),
"{} = n3()"
);
EXPECT_NE
(
n4
->
Op
(),
nullptr
);
EXPECT_EQ
(
n4
->
ToString
(),
"{Y=[y1 ,y2]} = test_op(X=[x1 ,x2 ,x3])"
);
n3
->
inputs
.
push_back
(
n1
.
get
());
n3
->
outputs
.
push_back
(
n2
.
get
());
EXPECT_EQ
(
n3
->
Op
(),
nullptr
);
EXPECT_EQ
(
n3
->
ToString
(),
"{n2} = n3(n1)"
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录