Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
167523e7
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
167523e7
编写于
7月 28, 2021
作者:
J
jiangcheng
提交者:
GitHub
7月 28, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
graph_to_program topology sort (#33949)
See
https://github.com/PaddlePaddle/Paddle/pull/33949
for details
上级
f1654de6
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
762 addition
and
33 deletion
+762
-33
paddle/fluid/framework/ir/graph.cc
paddle/fluid/framework/ir/graph.cc
+33
-7
paddle/fluid/framework/ir/graph.h
paddle/fluid/framework/ir/graph.h
+27
-5
paddle/fluid/framework/ir/graph_helper.cc
paddle/fluid/framework/ir/graph_helper.cc
+80
-0
paddle/fluid/framework/ir/graph_helper.h
paddle/fluid/framework/ir/graph_helper.h
+2
-0
paddle/fluid/framework/ir/graph_to_program_pass.cc
paddle/fluid/framework/ir/graph_to_program_pass.cc
+87
-15
paddle/fluid/framework/ir/graph_to_program_pass.h
paddle/fluid/framework/ir/graph_to_program_pass.h
+3
-0
paddle/fluid/framework/ir/graph_to_program_pass_test.cc
paddle/fluid/framework/ir/graph_to_program_pass_test.cc
+382
-0
paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc
...k/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc
+11
-0
paddle/fluid/framework/ir/node.cc
paddle/fluid/framework/ir/node.cc
+1
-1
paddle/fluid/framework/ir/node.h
paddle/fluid/framework/ir/node.h
+109
-5
paddle/fluid/framework/ir/node_test.cc
paddle/fluid/framework/ir/node_test.cc
+27
-0
未找到文件。
paddle/fluid/framework/ir/graph.cc
浏览文件 @
167523e7
...
@@ -56,10 +56,12 @@ Graph::Graph(const ProgramDesc &program, const int64_t start_op_index,
...
@@ -56,10 +56,12 @@ Graph::Graph(const ProgramDesc &program, const int64_t start_op_index,
// sub_graph.
// sub_graph.
std
::
unique_ptr
<
Graph
>
first_sub_graph
=
std
::
make_unique
<
Graph
>
(
std
::
unique_ptr
<
Graph
>
first_sub_graph
=
std
::
make_unique
<
Graph
>
(
program_
.
Block
(
0
),
this
,
start_op_index
,
end_op_index
);
program_
.
Block
(
0
),
this
,
start_op_index
,
end_op_index
);
first_sub_graph
->
block_id_
=
0
;
sub_graphs_
.
push_back
(
std
::
move
(
first_sub_graph
));
sub_graphs_
.
push_back
(
std
::
move
(
first_sub_graph
));
for
(
size_t
idx
=
1
;
idx
<
program_
.
Size
();
++
idx
)
{
for
(
size_t
idx
=
1
;
idx
<
program_
.
Size
();
++
idx
)
{
std
::
unique_ptr
<
Graph
>
sub_graph
=
std
::
unique_ptr
<
Graph
>
sub_graph
=
std
::
make_unique
<
Graph
>
(
program_
.
Block
(
idx
),
this
);
std
::
make_unique
<
Graph
>
(
program_
.
Block
(
idx
),
this
);
sub_graph
->
block_id_
=
idx
;
sub_graphs_
.
push_back
(
std
::
move
(
sub_graph
));
sub_graphs_
.
push_back
(
std
::
move
(
sub_graph
));
}
}
}
else
{
}
else
{
...
@@ -90,14 +92,32 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
...
@@ -90,14 +92,32 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
Graph
::
InitFromBlock
(
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
Graph
::
InitFromBlock
(
const
BlockDesc
&
block
,
const
int64_t
start_op_index
,
const
BlockDesc
&
block
,
const
int64_t
start_op_index
,
const
int64_t
end_op_index
)
{
const
int64_t
end_op_index
)
{
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars
;
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
VarDesc
*
,
int
>>
name_to_desc_block_id
;
const
BlockDesc
*
block_var_visible
=
&
block
;
while
(
block_var_visible
!=
nullptr
)
{
for
(
auto
*
var
:
block_var_visible
->
AllVars
())
{
name_to_desc_block_id
.
emplace
(
var
->
Name
(),
std
::
make_pair
(
var
,
block_var_visible
->
ID
()));
}
const
BlockDesc
*
forward_block
=
block_var_visible
->
ForwardBlock
();
if
(
forward_block
!=
nullptr
)
{
for
(
auto
*
var
:
forward_block
->
AllVars
())
{
name_to_desc_block_id
.
emplace
(
var
->
Name
(),
std
::
make_pair
(
var
,
forward_block
->
ID
()));
}
}
block_var_visible
=
block_var_visible
->
ParentBlock
();
}
// var nodes for each var name, will have multiple versions in SSA
// var nodes for each var name, will have multiple versions in SSA
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
var_nodes
;
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
var_nodes
;
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
not_visited_vars
;
for
(
auto
*
var
:
block
.
AllVars
())
{
for
(
auto
*
var
:
block
.
AllVars
())
{
all
_vars
.
emplace
(
var
->
Name
(),
var
);
not_visited
_vars
.
emplace
(
var
->
Name
(),
var
);
}
}
auto
not_visited_vars
=
all_vars
;
int
desc_order
=
0
;
auto
all_ops
=
block
.
AllOps
();
auto
all_ops
=
block
.
AllOps
();
PADDLE_ENFORCE_LE
(
PADDLE_ENFORCE_LE
(
end_op_index
,
all_ops
.
size
(),
end_op_index
,
all_ops
.
size
(),
...
@@ -109,6 +129,8 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
...
@@ -109,6 +129,8 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
auto
*
op
=
all_ops
[
i
];
auto
*
op
=
all_ops
[
i
];
VLOG
(
3
)
<<
"create OpNode by "
<<
op
->
Type
();
VLOG
(
3
)
<<
"create OpNode by "
<<
op
->
Type
();
ir
::
Node
*
node
=
CreateOpNode
(
op
);
ir
::
Node
*
node
=
CreateOpNode
(
op
);
node
->
SetDescOrder
(
desc_order
);
++
desc_order
;
// For input args, reuse the same var name if it was created before.
// For input args, reuse the same var name if it was created before.
// Otherwise, create a new one.
// Otherwise, create a new one.
for
(
auto
&
each_var_name
:
op
->
InputArgumentNames
())
{
for
(
auto
&
each_var_name
:
op
->
InputArgumentNames
())
{
...
@@ -116,8 +138,9 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
...
@@ -116,8 +138,9 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
ir
::
Node
*
var
=
nullptr
;
ir
::
Node
*
var
=
nullptr
;
if
(
var_nodes
.
find
(
each_var_name
)
!=
var_nodes
.
end
())
{
if
(
var_nodes
.
find
(
each_var_name
)
!=
var_nodes
.
end
())
{
var
=
var_nodes
.
at
(
each_var_name
).
back
();
var
=
var_nodes
.
at
(
each_var_name
).
back
();
}
else
if
(
all_vars
.
count
(
each_var_name
)
!=
0
)
{
}
else
if
(
name_to_desc_block_id
.
count
(
each_var_name
)
!=
0
)
{
var
=
CreateVarNode
(
all_vars
.
at
(
each_var_name
));
auto
desc_and_block_id
=
name_to_desc_block_id
.
at
(
each_var_name
);
var
=
CreateVarNode
(
desc_and_block_id
.
first
,
desc_and_block_id
.
second
);
var_nodes
[
each_var_name
].
push_back
(
var
);
var_nodes
[
each_var_name
].
push_back
(
var
);
}
else
{
}
else
{
// Operation input var can be optional (dispensable). Which means
// Operation input var can be optional (dispensable). Which means
...
@@ -143,8 +166,9 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
...
@@ -143,8 +166,9 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
}
}
ir
::
Node
*
var
=
nullptr
;
ir
::
Node
*
var
=
nullptr
;
if
(
all_vars
.
count
(
each_var_name
)
!=
0
)
{
if
(
name_to_desc_block_id
.
count
(
each_var_name
)
!=
0
)
{
var
=
CreateVarNode
(
all_vars
.
at
(
each_var_name
));
auto
desc_and_block_id
=
name_to_desc_block_id
.
at
(
each_var_name
);
var
=
CreateVarNode
(
desc_and_block_id
.
first
,
desc_and_block_id
.
second
);
}
else
{
}
else
{
// Operation output vars can be @EMPTY@. For example, while_grad
// Operation output vars can be @EMPTY@. For example, while_grad
// can have multi @EMPTY@ outputs with no VarDesc.
// can have multi @EMPTY@ outputs with no VarDesc.
...
@@ -270,6 +294,7 @@ std::shared_ptr<Graph> Graph::Clone() {
...
@@ -270,6 +294,7 @@ std::shared_ptr<Graph> Graph::Clone() {
auto
cloned_graph
=
std
::
make_shared
<
Graph
>
(
this
->
program_
);
auto
cloned_graph
=
std
::
make_shared
<
Graph
>
(
this
->
program_
);
cloned_graph
->
ReleaseNodes
();
cloned_graph
->
ReleaseNodes
();
cloned_graph
->
num_node_created_
=
0
;
cloned_graph
->
num_node_created_
=
0
;
cloned_graph
->
block_id_
=
this
->
block_id_
;
std
::
unordered_map
<
ir
::
Node
*
,
ir
::
Node
*>
origin_to_cloned
;
std
::
unordered_map
<
ir
::
Node
*
,
ir
::
Node
*>
origin_to_cloned
;
for
(
auto
*
n
:
this
->
node_set_
)
{
for
(
auto
*
n
:
this
->
node_set_
)
{
PADDLE_ENFORCE_NOT_NULL
(
n
,
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_NOT_NULL
(
n
,
platform
::
errors
::
InvalidArgument
(
...
@@ -313,6 +338,7 @@ std::unique_ptr<Graph> Graph::CloneSubGraph(const size_t idx) {
...
@@ -313,6 +338,7 @@ std::unique_ptr<Graph> Graph::CloneSubGraph(const size_t idx) {
std
::
make_unique
<
Graph
>
(
this
->
program_
.
Block
(
idx
),
this
);
std
::
make_unique
<
Graph
>
(
this
->
program_
.
Block
(
idx
),
this
);
cloned_sub_graph
->
ReleaseNodes
();
cloned_sub_graph
->
ReleaseNodes
();
cloned_sub_graph
->
num_node_created_
=
0
;
cloned_sub_graph
->
num_node_created_
=
0
;
cloned_sub_graph
->
block_id_
=
idx
;
std
::
unordered_map
<
ir
::
Node
*
,
ir
::
Node
*>
origin_to_cloned
;
std
::
unordered_map
<
ir
::
Node
*
,
ir
::
Node
*>
origin_to_cloned
;
for
(
auto
*
n
:
this
->
sub_graphs_
.
at
(
idx
)
->
Nodes
())
{
for
(
auto
*
n
:
this
->
sub_graphs_
.
at
(
idx
)
->
Nodes
())
{
PADDLE_ENFORCE_NOT_NULL
(
n
,
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_NOT_NULL
(
n
,
platform
::
errors
::
InvalidArgument
(
...
...
paddle/fluid/framework/ir/graph.h
浏览文件 @
167523e7
...
@@ -104,7 +104,14 @@ class Graph {
...
@@ -104,7 +104,14 @@ class Graph {
attr_dels_
.
clear
();
attr_dels_
.
clear
();
}
}
bool
IsConstructedByPartialProgram
()
const
{
return
is_partial_
;
}
bool
IsConstructedByPartialProgram
()
const
{
if
(
FLAGS_convert_all_blocks
)
{
if
(
IsMainGraph
())
{
return
GetSubGraph
(
0
)
->
IsConstructedByPartialProgram
();
}
}
return
is_partial_
;
}
bool
Has
(
const
std
::
string
&
attr_name
)
const
{
bool
Has
(
const
std
::
string
&
attr_name
)
const
{
if
(
FLAGS_convert_all_blocks
)
{
if
(
FLAGS_convert_all_blocks
)
{
...
@@ -210,7 +217,7 @@ class Graph {
...
@@ -210,7 +217,7 @@ class Graph {
}
}
// Create a normal variable with non-null VarDesc.
// Create a normal variable with non-null VarDesc.
ir
::
Node
*
CreateVarNode
(
VarDesc
*
var_desc
)
{
ir
::
Node
*
CreateVarNode
(
VarDesc
*
var_desc
,
int
block_id
=
-
1
)
{
if
(
FLAGS_convert_all_blocks
)
{
if
(
FLAGS_convert_all_blocks
)
{
if
(
IsMainGraph
())
{
if
(
IsMainGraph
())
{
return
GetSubGraph
(
0
)
->
CreateVarNode
(
var_desc
);
return
GetSubGraph
(
0
)
->
CreateVarNode
(
var_desc
);
...
@@ -219,7 +226,8 @@ class Graph {
...
@@ -219,7 +226,8 @@ class Graph {
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
var_desc
,
platform
::
errors
::
InvalidArgument
(
var_desc
,
platform
::
errors
::
InvalidArgument
(
"The VarDesc used to create variable node is null."
));
"The VarDesc used to create variable node is null."
));
auto
*
x
=
AddNode
(
new
ir
::
Node
(
var_desc
));
auto
*
x
=
AddNode
(
new
ir
::
Node
(
var_desc
,
block_id
==
-
1
?
block_id_
:
block_id
));
x
->
SetId
(
num_node_created_
++
);
x
->
SetId
(
num_node_created_
++
);
return
x
;
return
x
;
}
}
...
@@ -252,7 +260,7 @@ class Graph {
...
@@ -252,7 +260,7 @@ class Graph {
const
std
::
string
name
=
string
::
Sprintf
(
const
std
::
string
name
=
string
::
Sprintf
(
"%s@%llu"
,
static_cast
<
const
char
*>
(
ir
::
Node
::
kControlDepVarName
),
"%s@%llu"
,
static_cast
<
const
char
*>
(
ir
::
Node
::
kControlDepVarName
),
num_node_created_
);
num_node_created_
);
auto
*
x
=
AddNode
(
new
ir
::
Node
(
name
,
ir
::
Node
::
Type
::
kVariable
));
auto
*
x
=
AddNode
(
new
ir
::
Node
(
name
,
ir
::
Node
::
Type
::
kVariable
,
block_id_
));
x
->
SetId
(
num_node_created_
++
);
x
->
SetId
(
num_node_created_
++
);
return
x
;
return
x
;
}
}
...
@@ -265,7 +273,7 @@ class Graph {
...
@@ -265,7 +273,7 @@ class Graph {
return
GetSubGraph
(
0
)
->
CreateEmptyNode
(
name
,
type
);
return
GetSubGraph
(
0
)
->
CreateEmptyNode
(
name
,
type
);
}
}
}
}
auto
*
x
=
AddNode
(
new
ir
::
Node
(
name
,
type
));
auto
*
x
=
AddNode
(
new
ir
::
Node
(
name
,
type
,
block_id_
));
x
->
SetId
(
num_node_created_
++
);
x
->
SetId
(
num_node_created_
++
);
return
x
;
return
x
;
}
}
...
@@ -365,6 +373,15 @@ class Graph {
...
@@ -365,6 +373,15 @@ class Graph {
return
sub_graphs_
.
at
(
idx
).
get
();
return
sub_graphs_
.
at
(
idx
).
get
();
}
}
int
GetBlockId
()
const
{
if
(
FLAGS_convert_all_blocks
)
{
if
(
IsMainGraph
())
{
return
GetSubGraph
(
0
)
->
block_id_
;
}
}
return
block_id_
;
}
size_t
SubGraphsSize
()
const
{
size_t
SubGraphsSize
()
const
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
this
->
IsMainGraph
(),
true
,
this
->
IsMainGraph
(),
true
,
...
@@ -394,6 +411,9 @@ class Graph {
...
@@ -394,6 +411,9 @@ class Graph {
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
this
->
IsMainGraph
(),
true
,
this
->
IsMainGraph
(),
true
,
platform
::
errors
::
InvalidArgument
(
"This graph is not main_graph"
));
platform
::
errors
::
InvalidArgument
(
"This graph is not main_graph"
));
PADDLE_ENFORCE_EQ
(
sub_graphs_
.
size
(),
sub_graph
->
block_id_
,
platform
::
errors
::
InvalidArgument
(
"sub_graph idx is not equal to block_id_"
));
sub_graphs_
.
push_back
(
std
::
move
(
sub_graph
));
sub_graphs_
.
push_back
(
std
::
move
(
sub_graph
));
}
}
...
@@ -416,6 +436,8 @@ class Graph {
...
@@ -416,6 +436,8 @@ class Graph {
// parts: forward graph and backward graph, which can be executed
// parts: forward graph and backward graph, which can be executed
// independently.
// independently.
bool
is_partial_
{
false
};
bool
is_partial_
{
false
};
// The block this SubGraph belongs to.
int
block_id_
{
0
};
};
};
bool
IsControlDepVar
(
const
ir
::
Node
&
var
);
bool
IsControlDepVar
(
const
ir
::
Node
&
var
);
...
...
paddle/fluid/framework/ir/graph_helper.cc
浏览文件 @
167523e7
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include <queue>
#include <stack>
#include <stack>
DEFINE_string
(
print_sub_graph_dir
,
""
,
DEFINE_string
(
print_sub_graph_dir
,
""
,
...
@@ -395,6 +396,85 @@ std::vector<Node *> TopologyVarientSort(const Graph &graph,
...
@@ -395,6 +396,85 @@ std::vector<Node *> TopologyVarientSort(const Graph &graph,
}
}
}
}
class
DescOrderComparator
{
public:
bool
operator
()(
const
Node
*
n1
,
const
Node
*
n2
)
{
return
(
n1
->
DescOrder
()
>
n2
->
DescOrder
())
||
((
n1
->
DescOrder
()
==
n2
->
DescOrder
())
&&
(
n1
->
ToString
()
>
n2
->
ToString
()));
}
};
std
::
vector
<
ir
::
Node
*>
TopologySortGraphByDescOrder
(
const
Graph
&
graph
)
{
std
::
vector
<
ir
::
Node
*>
sorted_ops
;
std
::
priority_queue
<
Node
*
,
std
::
vector
<
Node
*>
,
DescOrderComparator
>
q
;
std
::
unordered_map
<
Node
*
,
std
::
unordered_set
<
Node
*>>
in_ops
;
std
::
unordered_map
<
Node
*
,
std
::
unordered_set
<
Node
*>>
out_ops
;
// ensure all op node in 'in_ops' and 'out_ops'
for
(
const
auto
&
n
:
graph
.
Nodes
())
{
if
(
!
n
->
IsOp
())
continue
;
in_ops
.
emplace
(
n
,
std
::
unordered_set
<
Node
*>
());
out_ops
.
emplace
(
n
,
std
::
unordered_set
<
Node
*>
());
}
// record all op's input op and output op
for
(
const
auto
&
n
:
graph
.
Nodes
())
{
if
(
!
n
->
IsOp
())
continue
;
// traverse all input op
for
(
const
auto
&
var
:
n
->
inputs
)
{
for
(
const
auto
&
in
:
var
->
inputs
)
{
// use at instead of [] to prevent no unrecorded op node
in_ops
.
at
(
n
).
insert
(
in
);
out_ops
.
at
(
in
).
insert
(
n
);
}
}
}
// find topology entrance
for
(
const
auto
&
n
:
graph
.
Nodes
())
{
if
(
!
n
->
IsOp
())
continue
;
if
(
in_ops
.
at
(
n
).
empty
())
{
q
.
push
(
n
);
}
}
// topological sorting
while
(
!
q
.
empty
())
{
// Do not get by reference!!! The element will pop later.
const
auto
cur_op
=
q
.
top
();
q
.
pop
();
sorted_ops
.
push_back
(
cur_op
);
for
(
const
auto
&
out
:
out_ops
.
at
(
cur_op
))
{
PADDLE_ENFORCE_GT
(
in_ops
.
at
(
out
).
count
(
cur_op
),
0
,
platform
::
errors
::
InvalidArgument
(
"We find %s in %s's output list, "
"but cannot find %s in %s's input list. "
"Please ensure graph completely."
,
out
->
Name
().
c_str
(),
cur_op
->
Name
().
c_str
(),
cur_op
->
Name
().
c_str
(),
out
->
Name
().
c_str
()));
in_ops
.
at
(
out
).
erase
(
cur_op
);
// push if in-degree is 0
if
(
in_ops
.
at
(
out
).
empty
())
{
q
.
push
(
out
);
}
}
}
PADDLE_ENFORCE_EQ
(
sorted_ops
.
size
(),
in_ops
.
size
(),
platform
::
errors
::
InvalidArgument
(
"Topological sorting incompletely, "
"only sorted %zd op but total %zd."
,
sorted_ops
.
size
(),
in_ops
.
size
()));
return
sorted_ops
;
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/graph_helper.h
浏览文件 @
167523e7
...
@@ -87,6 +87,8 @@ std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
...
@@ -87,6 +87,8 @@ std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
return
ret
;
return
ret
;
}
}
std
::
vector
<
ir
::
Node
*>
TopologySortGraphByDescOrder
(
const
Graph
&
graph
);
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/graph_to_program_pass.cc
浏览文件 @
167523e7
...
@@ -14,7 +14,13 @@ limitations under the License. */
...
@@ -14,7 +14,13 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <gflags/gflags.h>
#include <algorithm>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
DECLARE_bool
(
convert_all_blocks
);
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -27,13 +33,10 @@ namespace framework {
...
@@ -27,13 +33,10 @@ namespace framework {
namespace
ir
{
namespace
ir
{
void
GraphToProgramPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
void
GraphToProgramPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
// Remove the unneeded variables after memory optimization.
PADDLE_ENFORCE_EQ
(
graph
->
IsMainGraph
(),
true
,
std
::
unordered_set
<
std
::
string
>
vars2remove
;
platform
::
errors
::
InvalidArgument
(
if
(
graph
->
Has
(
kGraphToProgramVarsToRemove
))
{
"This graph is a sub_graph, "
vars2remove
=
graph
->
Get
<
std
::
unordered_set
<
std
::
string
>>
(
"and can't convert to program individually"
));
kGraphToProgramVarsToRemove
);
VLOG
(
2
)
<<
"graph to program remove "
<<
vars2remove
.
size
()
<<
" nodes"
;
}
ProgramDesc
&
program
=
Get
<
ProgramDesc
>
(
"program"
);
ProgramDesc
&
program
=
Get
<
ProgramDesc
>
(
"program"
);
...
@@ -42,12 +45,79 @@ void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const {
...
@@ -42,12 +45,79 @@ void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const {
auto
block
=
program_pb
->
mutable_blocks
(
kRootBlockIndex
);
auto
block
=
program_pb
->
mutable_blocks
(
kRootBlockIndex
);
block
->
set_idx
(
kRootBlockIndex
);
block
->
set_idx
(
kRootBlockIndex
);
if
(
FLAGS_convert_all_blocks
)
{
GraphToBlock
(
graph
->
GetSubGraph
(
kRootBlockIndex
),
block
);
VLOG
(
3
)
<<
"Graph to program need convert "
<<
graph
->
SubGraphsSize
()
<<
" sub graph"
;
for
(
size_t
idx
=
0
;
idx
<
graph
->
SubGraphsSize
();
++
idx
)
{
// avoid kRootBlockIndex not 0
if
(
idx
==
kRootBlockIndex
)
continue
;
block
=
program_pb
->
add_blocks
();
block
->
set_idx
(
idx
);
GraphToBlock
(
graph
->
GetSubGraph
(
idx
),
block
);
}
}
else
{
GraphToBlock
(
graph
,
block
);
}
program
.
CopyFrom
(
*
program_pb
);
}
OpDesc
*
ReplaceScaleLossGradOp
(
ir
::
Node
*
node
,
OpDesc
*
desc
)
{
desc
->
SetType
(
"fill_constant"
);
desc
->
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
(
static_cast
<
int
>
(
OpRole
::
kBackward
)
|
static_cast
<
int
>
(
OpRole
::
kLoss
)));
desc
->
SetAttr
(
"value"
,
1.0
f
);
std
::
vector
<
std
::
string
>
output_names
;
for
(
auto
out
:
node
->
outputs
)
{
output_names
.
emplace_back
(
out
->
Name
());
}
desc
->
SetOutput
(
"Out"
,
output_names
);
return
desc
;
}
std
::
vector
<
OpDesc
>*
GetGraphOpDesc
(
const
std
::
vector
<
ir
::
Node
*>&
nodes
,
std
::
vector
<
OpDesc
>*
ops
)
{
for
(
ir
::
Node
*
n
:
nodes
)
{
// if node is not Op, skip
if
(
!
n
->
IsOp
())
continue
;
// create fill_constant op
if
(
n
->
Name
()
==
"scale_loss_grad"
)
{
ops
->
emplace_back
();
auto
&
desc
=
ops
->
back
();
ReplaceScaleLossGradOp
(
n
,
&
desc
);
}
else
if
(
n
->
Op
())
{
ops
->
emplace_back
(
*
n
->
Op
());
}
else
{
// delete no OpDesc op
}
}
return
ops
;
}
void
GraphToProgramPass
::
GraphToBlock
(
const
Graph
*
graph
,
proto
::
BlockDesc
*
block
)
const
{
// Remove the unneeded variables after memory optimization.
std
::
unordered_set
<
std
::
string
>
vars2remove
;
if
(
graph
->
Has
(
kGraphToProgramVarsToRemove
))
{
vars2remove
=
graph
->
Get
<
std
::
unordered_set
<
std
::
string
>>
(
kGraphToProgramVarsToRemove
);
VLOG
(
2
)
<<
"graph (id: "
<<
block
->
idx
()
<<
") to program remove "
<<
vars2remove
.
size
()
<<
" nodes"
;
}
block
->
clear_vars
();
block
->
clear_vars
();
std
::
unordered_set
<
std
::
string
>
visited_vars
;
std
::
unordered_set
<
std
::
string
>
visited_vars
;
for
(
ir
::
Node
*
n
:
graph
->
Nodes
())
{
for
(
ir
::
Node
*
n
:
graph
->
Nodes
())
{
if
(
n
->
IsVar
())
{
if
(
n
->
IsVar
())
{
if
(
n
->
Var
()
&&
visited_vars
.
count
(
n
->
Var
()
->
Name
())
==
0
&&
if
(
n
->
Var
()
&&
visited_vars
.
count
(
n
->
Var
()
->
Name
())
==
0
&&
!
vars2remove
.
count
(
n
->
Var
()
->
Name
()))
{
!
vars2remove
.
count
(
n
->
Var
()
->
Name
())
&&
n
->
GetVarNodeBlockId
()
==
graph
->
GetBlockId
())
{
visited_vars
.
insert
(
n
->
Var
()
->
Name
());
visited_vars
.
insert
(
n
->
Var
()
->
Name
());
block
->
add_vars
()
->
MergeFrom
(
*
n
->
Var
()
->
Proto
());
block
->
add_vars
()
->
MergeFrom
(
*
n
->
Var
()
->
Proto
());
}
}
...
@@ -61,17 +131,19 @@ void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const {
...
@@ -61,17 +131,19 @@ void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const {
int
sort_kind
=
Get
<
int
>
(
kGraphToProgramSortKind
);
int
sort_kind
=
Get
<
int
>
(
kGraphToProgramSortKind
);
nodes
=
TopologyVarientSort
(
nodes
=
TopologyVarientSort
(
*
graph
,
static_cast
<
framework
::
ir
::
SortKind
>
(
sort_kind
));
*
graph
,
static_cast
<
framework
::
ir
::
SortKind
>
(
sort_kind
));
}
else
{
if
(
FLAGS_convert_all_blocks
)
{
nodes
=
TopologySortGraphByDescOrder
(
*
graph
);
}
else
{
}
else
{
nodes
=
TopologySortOperations
(
*
graph
);
nodes
=
TopologySortOperations
(
*
graph
);
}
}
for
(
ir
::
Node
*
n
:
nodes
)
{
if
(
!
n
->
Op
())
continue
;
block
->
add_ops
()
->
MergeFrom
(
*
n
->
Op
()
->
Proto
());
}
}
program
.
CopyFrom
(
*
program_pb
);
std
::
vector
<
OpDesc
>
ops
;
GetGraphOpDesc
(
nodes
,
&
ops
);
for
(
auto
&
op
:
ops
)
{
block
->
add_ops
()
->
MergeFrom
(
*
op
.
Proto
());
}
}
}
}
// namespace ir
}
// namespace ir
...
...
paddle/fluid/framework/ir/graph_to_program_pass.h
浏览文件 @
167523e7
...
@@ -29,6 +29,9 @@ const char kGraphToProgramSortKind[] = "__graph_to_program_sort_kind__";
...
@@ -29,6 +29,9 @@ const char kGraphToProgramSortKind[] = "__graph_to_program_sort_kind__";
class
GraphToProgramPass
:
public
Pass
{
class
GraphToProgramPass
:
public
Pass
{
protected:
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
void
GraphToBlock
(
const
Graph
*
graph
,
proto
::
BlockDesc
*
block
)
const
;
};
};
}
// namespace ir
}
// namespace ir
...
...
paddle/fluid/framework/ir/graph_to_program_pass_test.cc
浏览文件 @
167523e7
...
@@ -14,8 +14,14 @@ limitations under the License. */
...
@@ -14,8 +14,14 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <algorithm>
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -103,6 +109,382 @@ TEST(GraphToProgramPass, Basic) {
...
@@ -103,6 +109,382 @@ TEST(GraphToProgramPass, Basic) {
EXPECT_TRUE
(
vars
.
find
(
"var2"
)
!=
vars
.
end
());
EXPECT_TRUE
(
vars
.
find
(
"var2"
)
!=
vars
.
end
());
EXPECT_TRUE
(
vars
.
find
(
"var3"
)
!=
vars
.
end
());
EXPECT_TRUE
(
vars
.
find
(
"var3"
)
!=
vars
.
end
());
}
}
void
BuildProgramWithMultiBlock
(
ProgramDesc
*
program
)
{
auto
*
global_block
=
program
->
MutableBlock
(
0
);
auto
*
mul_1_x
=
global_block
->
Var
(
"Mul_1_X"
);
mul_1_x
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
mul_1_x
->
SetLoDLevel
(
0
);
mul_1_x
->
SetDataType
(
proto
::
VarType
::
FP32
);
mul_1_x
->
SetShape
({
1000
,
784
});
auto
*
mul_1_y
=
global_block
->
Var
(
"Mul_1_Y"
);
mul_1_y
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
mul_1_y
->
SetLoDLevel
(
0
);
mul_1_y
->
SetDataType
(
proto
::
VarType
::
FP32
);
mul_1_y
->
SetShape
({
784
,
100
});
auto
*
mul_1_out
=
global_block
->
Var
(
"Mul_1_Out"
);
mul_1_out
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
auto
*
mul_op_1
=
global_block
->
AppendOp
();
mul_op_1
->
SetType
(
"mul"
);
mul_op_1
->
SetInput
(
"X"
,
{
mul_1_x
->
Name
()});
mul_op_1
->
SetInput
(
"Y"
,
{
mul_1_y
->
Name
()});
mul_op_1
->
SetOutput
(
"Y"
,
{
mul_1_out
->
Name
()});
// building cond op such as less_than
auto
*
less_than_op_1
=
global_block
->
AppendOp
();
less_than_op_1
->
SetType
(
"less_than"
);
auto
*
less_than_1_x
=
global_block
->
Var
(
"Less_than_1_X"
);
less_than_1_x
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
less_than_1_x
->
SetLoDLevel
(
0
);
less_than_1_x
->
SetDataType
(
proto
::
VarType
::
FP32
);
less_than_1_x
->
SetShape
({
1
});
auto
*
less_than_1_y
=
global_block
->
Var
(
"Less_than_1_Y"
);
less_than_1_y
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
less_than_1_y
->
SetLoDLevel
(
0
);
less_than_1_y
->
SetDataType
(
proto
::
VarType
::
FP32
);
less_than_1_y
->
SetShape
({
1
});
auto
*
less_than_1_out
=
global_block
->
Var
(
"Less_than_1_Out"
);
less_than_1_out
->
SetType
(
proto
::
VarType
::
BOOL
);
less_than_op_1
->
SetInput
(
"X"
,
{
less_than_1_x
->
Name
()});
less_than_op_1
->
SetInput
(
"Y"
,
{
less_than_1_y
->
Name
()});
less_than_op_1
->
SetOutput
(
"Out"
,
{
less_than_1_out
->
Name
()});
BlockDesc
*
sub_block
=
program
->
AppendBlock
(
*
global_block
);
std
::
vector
<
BlockDesc
*>
sub_blocks
;
sub_blocks
.
push_back
(
sub_block
);
BlockDesc
*
sub_block2
=
program
->
AppendBlock
(
*
sub_block
);
// for testing nested case.
sub_blocks
.
push_back
(
sub_block2
);
// building while op in sub_block
auto
*
while_op
=
global_block
->
AppendOp
();
while_op
->
SetType
(
"while"
);
while_op
->
SetAttr
(
"sub_block"
,
sub_blocks
[
0
]);
auto
*
while_x
=
global_block
->
Var
(
"While_X"
);
while_x
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
while_x
->
SetLoDLevel
(
0
);
while_x
->
SetDataType
(
proto
::
VarType
::
FP32
);
while_x
->
SetShape
({
1
});
while_op
->
SetInput
(
"kX"
,
{
while_x
->
Name
()});
while_op
->
SetInput
(
"kCondition"
,
{
less_than_1_out
->
Name
()});
auto
*
while_out
=
global_block
->
Var
(
"While_Out"
);
while_out
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
while_out
->
SetLoDLevel
(
0
);
while_out
->
SetDataType
(
proto
::
VarType
::
FP32
);
while_out
->
SetShape
({
1
});
auto
*
steps
=
global_block
->
Var
(
"StepScopes"
);
while_op
->
SetOutput
(
"kOutputs"
,
{
while_out
->
Name
()});
while_op
->
SetOutput
(
"kStepScopes"
,
{
steps
->
Name
()});
auto
*
mul_2_x
=
global_block
->
Var
(
"Mul_2_X"
);
mul_2_x
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
mul_2_x
->
SetLoDLevel
(
0
);
mul_2_x
->
SetDataType
(
proto
::
VarType
::
FP32
);
mul_2_x
->
SetShape
({
1000
,
784
});
auto
*
mul_2_y
=
global_block
->
Var
(
"Mul_2_Y"
);
mul_2_y
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
mul_2_y
->
SetLoDLevel
(
0
);
mul_2_y
->
SetDataType
(
proto
::
VarType
::
FP32
);
mul_2_y
->
SetShape
({
784
,
100
});
auto
*
mul_op_2
=
sub_blocks
[
0
]
->
AppendOp
();
mul_op_2
->
SetType
(
"mul"
);
mul_op_2
->
SetInput
(
"X"
,
{
mul_2_x
->
Name
()});
mul_op_2
->
SetInput
(
"Y"
,
{
mul_2_y
->
Name
()});
auto
*
mul_2_out
=
global_block
->
Var
(
"Mul_2_Out"
);
mul_2_out
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
mul_op_2
->
SetOutput
(
"Y"
,
{
mul_2_out
->
Name
()});
auto
*
less_than_op_2
=
sub_blocks
[
0
]
->
AppendOp
();
less_than_op_2
->
SetType
(
"less_than"
);
auto
*
less_than_2_x
=
global_block
->
Var
(
"Less_than_2_X"
);
less_than_2_x
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
less_than_2_x
->
SetLoDLevel
(
0
);
less_than_2_x
->
SetDataType
(
proto
::
VarType
::
FP32
);
less_than_2_x
->
SetShape
({
1
});
auto
*
less_than_2_y
=
global_block
->
Var
(
"Less_than_2_Y"
);
less_than_2_y
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
less_than_2_y
->
SetLoDLevel
(
0
);
less_than_2_y
->
SetDataType
(
proto
::
VarType
::
FP32
);
less_than_2_y
->
SetShape
({
1
});
less_than_op_2
->
SetInput
(
"X"
,
{
less_than_2_x
->
Name
()});
less_than_op_2
->
SetInput
(
"Y"
,
{
less_than_2_y
->
Name
()});
auto
*
less_than_2_out
=
global_block
->
Var
(
"Less_than_2_Out"
);
less_than_2_out
->
SetType
(
proto
::
VarType
::
BOOL
);
less_than_op_2
->
SetOutput
(
"Out"
,
{
less_than_2_out
->
Name
()});
auto
*
cond_op
=
sub_blocks
[
0
]
->
AppendOp
();
cond_op
->
SetType
(
"conditional_block"
);
cond_op
->
SetAttr
(
"sub_block"
,
sub_blocks
[
1
]);
auto
*
cond_x
=
sub_blocks
[
0
]
->
Var
(
"Cond_X"
);
cond_x
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
cond_x
->
SetLoDLevel
(
0
);
cond_x
->
SetDataType
(
proto
::
VarType
::
FP32
);
cond_x
->
SetShape
({
1
});
cond_op
->
SetInput
(
"kInputs"
,
{
cond_x
->
Name
()});
cond_op
->
SetInput
(
"kCondition"
,
{
less_than_2_out
->
Name
()});
auto
*
cond_out
=
sub_blocks
[
0
]
->
Var
(
"Cond_Out"
);
cond_out
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
cond_out
->
SetLoDLevel
(
0
);
cond_out
->
SetDataType
(
proto
::
VarType
::
FP32
);
cond_out
->
SetShape
({
1
});
auto
*
scope
=
sub_blocks
[
0
]
->
Var
(
"Scope"
);
scope
->
SetType
(
proto
::
VarType
::
STEP_SCOPES
);
cond_op
->
SetOutput
(
"kOutputs"
,
{
cond_out
->
Name
()});
cond_op
->
SetOutput
(
"kScope"
,
{
scope
->
Name
()});
auto
*
mul_3_x
=
global_block
->
Var
(
"Mul_3_X"
);
mul_3_x
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
mul_3_x
->
SetLoDLevel
(
0
);
mul_3_x
->
SetDataType
(
proto
::
VarType
::
FP32
);
mul_3_x
->
SetShape
({
1000
,
784
});
auto
*
mul_3_y
=
global_block
->
Var
(
"Mul_3_Y"
);
mul_3_y
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
mul_3_y
->
SetLoDLevel
(
0
);
mul_3_y
->
SetDataType
(
proto
::
VarType
::
FP32
);
mul_3_y
->
SetShape
({
784
,
100
});
auto
*
mul_3_out
=
global_block
->
Var
(
"Mul_3_Out"
);
mul_3_out
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
auto
*
mul_op_3
=
sub_blocks
[
1
]
->
AppendOp
();
mul_op_3
->
SetType
(
"mul"
);
mul_op_3
->
SetInput
(
"X"
,
{
mul_3_x
->
Name
()});
mul_op_3
->
SetInput
(
"Y"
,
{
mul_3_y
->
Name
()});
mul_op_3
->
SetOutput
(
"Y"
,
{
mul_3_out
->
Name
()});
}
bool
VarComparator
(
const
VarDesc
*
a
,
const
VarDesc
*
b
)
{
return
a
->
Name
()
<
b
->
Name
();
}
void
CheckBlockVarsEqual
(
const
BlockDesc
&
before_block
,
const
BlockDesc
&
after_block
)
{
auto
before_vars
=
before_block
.
AllVars
();
auto
after_vars
=
after_block
.
AllVars
();
EXPECT_EQ
(
before_vars
.
size
(),
after_vars
.
size
());
// var's order is unimportant
std
::
sort
(
before_vars
.
begin
(),
before_vars
.
end
(),
VarComparator
);
std
::
sort
(
after_vars
.
begin
(),
after_vars
.
end
(),
VarComparator
);
for
(
size_t
var_idx
=
0
;
var_idx
<
before_vars
.
size
();
++
var_idx
)
{
const
auto
&
before_var
=
before_vars
.
at
(
var_idx
);
const
auto
&
after_var
=
after_vars
.
at
(
var_idx
);
EXPECT_EQ
(
before_var
->
Name
(),
after_var
->
Name
());
EXPECT_EQ
(
before_var
->
GetType
(),
after_var
->
GetType
());
}
}
void
CheckOpInputsEqual
(
const
OpDesc
*
before_op
,
const
OpDesc
*
after_op
)
{
const
auto
&
before_inputs
=
before_op
->
InputNames
();
const
auto
&
after_inputs
=
after_op
->
InputNames
();
EXPECT_EQ
(
before_inputs
.
size
(),
after_inputs
.
size
());
for
(
size_t
in_idx
=
0
;
in_idx
<
before_inputs
.
size
();
++
in_idx
)
{
const
auto
&
before_in_arg
=
before_inputs
[
in_idx
];
const
auto
&
after_in_arg
=
after_inputs
[
in_idx
];
EXPECT_EQ
(
before_in_arg
,
after_in_arg
);
const
auto
&
before_in_vars
=
before_op
->
Input
(
before_in_arg
);
const
auto
&
after_in_vars
=
after_op
->
Input
(
after_in_arg
);
EXPECT_EQ
(
before_in_vars
,
after_in_vars
);
}
}
void
CheckOpOutputsEqual
(
const
OpDesc
*
before_op
,
const
OpDesc
*
after_op
)
{
const
auto
&
before_outputs
=
before_op
->
OutputNames
();
const
auto
&
after_outputs
=
after_op
->
OutputNames
();
EXPECT_EQ
(
before_outputs
.
size
(),
after_outputs
.
size
());
for
(
size_t
out_idx
=
0
;
out_idx
<
before_outputs
.
size
();
++
out_idx
)
{
const
auto
&
before_out_arg
=
before_outputs
[
out_idx
];
const
auto
&
after_out_arg
=
after_outputs
[
out_idx
];
EXPECT_EQ
(
before_out_arg
,
after_out_arg
);
const
auto
&
before_out_vars
=
before_op
->
Output
(
before_out_arg
);
const
auto
&
after_out_vars
=
after_op
->
Output
(
after_out_arg
);
EXPECT_EQ
(
before_out_vars
,
after_out_vars
);
}
}
void
CheckOpAttrsEqual
(
const
OpDesc
*
before_op
,
const
OpDesc
*
after_op
)
{
const
auto
&
before_attrs
=
before_op
->
AttrNames
();
const
auto
&
after_attrs
=
after_op
->
AttrNames
();
EXPECT_EQ
(
before_attrs
.
size
(),
after_attrs
.
size
());
for
(
size_t
attr_idx
=
0
;
attr_idx
<
before_attrs
.
size
();
++
attr_idx
)
{
const
auto
&
before_attr
=
before_attrs
[
attr_idx
];
const
auto
&
after_attr
=
after_attrs
[
attr_idx
];
EXPECT_EQ
(
before_attr
,
after_attr
);
EXPECT_EQ
(
before_op
->
GetAttrType
(
before_attr
),
after_op
->
GetAttrType
(
after_attr
));
}
}
void
CheckBlockOpsEqual
(
const
BlockDesc
&
before_block
,
const
BlockDesc
&
after_block
)
{
EXPECT_EQ
(
before_block
.
OpSize
(),
after_block
.
OpSize
());
// op's order must be the same
for
(
size_t
op_idx
=
0
;
op_idx
<
before_block
.
OpSize
();
++
op_idx
)
{
const
auto
&
before_op
=
before_block
.
Op
(
op_idx
);
const
auto
&
after_op
=
after_block
.
Op
(
op_idx
);
EXPECT_EQ
(
before_op
->
Type
(),
after_op
->
Type
());
// Step4.2.1 : check each op's input
CheckOpInputsEqual
(
before_op
,
after_op
);
// Step4.2.2 : check each op's output
CheckOpOutputsEqual
(
before_op
,
after_op
);
// Step4.2.3 : check each op's attribute
CheckOpAttrsEqual
(
before_op
,
after_op
);
}
}
// Round-trip test: a multi-block program is converted to a graph and back
// through "graph_to_program_pass"; the resulting program must have the same
// blocks, variables and ops as the original.
TEST(GraphToProgramPass, MultiBlock) {
  // Set FLAGS_convert_all_blocks to true to make sure this test works.
  // The guard restores the previous value on every exit path (normal return,
  // the early return of a failed ASSERT_*, or an exception escaping the
  // body), so the flag never leaks into tests that run afterwards.
  struct ConvertAllBlocksGuard {
    ConvertAllBlocksGuard() : saved_(FLAGS_convert_all_blocks) {
      FLAGS_convert_all_blocks = true;
    }
    ~ConvertAllBlocksGuard() { FLAGS_convert_all_blocks = saved_; }
    const bool saved_;
  } convert_all_blocks_guard;

  // Step1: Build a program with multi block
  ProgramDesc before_prog;
  BuildProgramWithMultiBlock(&before_prog);

  // Step2: Convert program into graph
  std::unique_ptr<Graph> g(new ir::Graph(before_prog));

  // Step3 : Convert graph back to program
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "graph_to_program_pass");

  ProgramDesc after_prog;
  pass->SetNotOwned<paddle::framework::ProgramDesc>("program", &after_prog);
  pass->Apply(g.get());

  // Step4 : Check two program equal
  // Fatal on purpose: the loop below asks after_prog for every block index of
  // before_prog, which is only valid when both have the same block count.
  ASSERT_EQ(before_prog.Size(), after_prog.Size());

  for (size_t block_idx = 0; block_idx < before_prog.Size(); ++block_idx) {
    const auto& before_block = before_prog.Block(block_idx);
    const auto& after_block = after_prog.Block(block_idx);

    EXPECT_EQ(before_block.ID(), after_block.ID());

    // Step4.1 : check each block's var
    CheckBlockVarsEqual(before_block, after_block);

    // Step4.2 : check each block's op
    CheckBlockOpsEqual(before_block, after_block);
  }
}
// Populates `g` with a small hand-wired graph used by the
// ReplaceScaleLossGrad test:
//
//   o1 -> v1 -> o2
//   o3 -> v1
//   o4 -> v2
//
// o1 and o2 are op nodes created from OpDescs ("op1", "op2"). o3 and o4 are
// operation nodes created with CreateEmptyNode and named "scale_loss_grad",
// i.e. op nodes that are not created from an OpDesc. v1 and v2 are variable
// nodes created from VarDescs ("var1", "var2").
//
// g: graph that receives the nodes and edges; must not be null.
void BuildProgramWithScaleLossGrad(Graph* g) {
  OpDesc op1;
  op1.SetType("op1");
  OpDesc op2;
  op2.SetType("op2");

  VarDesc var1("var1");
  VarDesc var2("var2");

  ir::Node* o1 = g->CreateOpNode(&op1);
  ir::Node* o2 = g->CreateOpNode(&op2);
  ir::Node* o3 =
      g->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation);
  ir::Node* o4 =
      g->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation);
  ir::Node* v1 = g->CreateVarNode(&var1);
  ir::Node* v2 = g->CreateVarNode(&var2);

  // o1->v1->o2
  o1->outputs.push_back(v1);
  o2->inputs.push_back(v1);
  v1->inputs.push_back(o1);
  v1->outputs.push_back(o2);
  // o3->v1 (o1 is already registered as an input of v1 above, so only o3 is
  // added here; registering o1 a second time would create a duplicate edge)
  o3->outputs.push_back(v1);
  v1->inputs.push_back(o3);
  // o4->v2
  o4->outputs.push_back(v2);
  v2->inputs.push_back(o4);
}
// Checks how "graph_to_program_pass" treats "scale_loss_grad" nodes: the
// resulting program must contain no "scale_loss_grad" op, and its number of
// "fill_constant" ops must equal the number of "scale_loss_grad" plus
// "fill_constant" nodes found in the graph after the pass has run.
TEST(GraphToProgramPass, ReplaceScaleLossGrad) {
  // Step1 : build a graph that contains "scale_loss_grad" operation nodes
  ProgramDesc before_prog;
  Graph before_graph(before_prog);
  BuildProgramWithScaleLossGrad(&before_graph);

  // Step2 : run the pass that converts the graph back into a program
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "graph_to_program_pass");

  ProgramDesc after_prog;
  pass->SetNotOwned<paddle::framework::ProgramDesc>("program", &after_prog);
  pass->Apply(&before_graph);

  // Step3 : count scale_loss_grad / fill_constant nodes left in the graph
  int scale_node_num = 0;
  int fill_node_num = 0;
  const auto& graph_nodes = before_graph.Nodes();
  for (const auto& node : graph_nodes) {
    if (node->Name() == "scale_loss_grad") {
      ++scale_node_num;
      continue;
    }
    if (node->Name() == "fill_constant") {
      ++fill_node_num;
    }
  }

  // ... and the ops of the same two types in block 0 of the new program
  int scale_op_num = 0;
  int fill_op_num = 0;
  const auto& main_block = after_prog.Block(0);
  for (const auto& op_desc : main_block.AllOps()) {
    if (op_desc->Type() == "fill_constant") {
      ++fill_op_num;
      continue;
    }
    if (op_desc->Type() == "scale_loss_grad") {
      ++scale_op_num;
    }
  }

  // Check pass OK
  EXPECT_EQ(scale_op_num, 0);
  EXPECT_EQ(scale_node_num + fill_node_num, fill_op_num);
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc
浏览文件 @
167523e7
...
@@ -26,6 +26,13 @@ using OpVariant = operators::OpVariant;
...
@@ -26,6 +26,13 @@ using OpVariant = operators::OpVariant;
class
WhileOpEagerDeletionPass
:
public
ir
::
Pass
{
class
WhileOpEagerDeletionPass
:
public
ir
::
Pass
{
protected:
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
{
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
{
if
(
!
graph
->
IsMainGraph
())
{
// TODO(zhhsplendid): the WhileOpEagerDeletionPass is based on old Graph,
// which only applies to the main block graph. The new Eager Deletion
// Technical can be added after we write new while_op based on SubGraph
// instead of SubBlock
return
;
}
auto
all_ops
=
ir
::
FilterByNodeWrapper
<
details
::
OpHandleBase
>
(
*
graph
);
auto
all_ops
=
ir
::
FilterByNodeWrapper
<
details
::
OpHandleBase
>
(
*
graph
);
// Find all while_op and while_grad_op. In case of @to_static, graph
// Find all while_op and while_grad_op. In case of @to_static, graph
...
@@ -47,6 +54,7 @@ class WhileOpEagerDeletionPass : public ir::Pass {
...
@@ -47,6 +54,7 @@ class WhileOpEagerDeletionPass : public ir::Pass {
}
}
}
}
if
(
graph
->
IsConstructedByPartialProgram
())
{
if
(
graph
->
IsConstructedByPartialProgram
())
{
VLOG
(
4
)
<<
"Is Paritial Program"
;
PADDLE_ENFORCE_LE
(
PADDLE_ENFORCE_LE
(
target_ops
.
size
(),
1
,
target_ops
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
...
@@ -69,8 +77,11 @@ class WhileOpEagerDeletionPass : public ir::Pass {
...
@@ -69,8 +77,11 @@ class WhileOpEagerDeletionPass : public ir::Pass {
}
}
for
(
auto
&
ops_pair
:
target_ops
)
{
for
(
auto
&
ops_pair
:
target_ops
)
{
VLOG
(
4
)
<<
"Scope Idx = "
<<
ops_pair
.
first
;
auto
&
while_ops
=
ops_pair
.
second
.
first
;
auto
&
while_ops
=
ops_pair
.
second
.
first
;
VLOG
(
4
)
<<
"while_ops.size() = "
<<
while_ops
.
size
();
auto
&
while_grad_ops
=
ops_pair
.
second
.
second
;
auto
&
while_grad_ops
=
ops_pair
.
second
.
second
;
VLOG
(
4
)
<<
"while_grad_ops.size() = "
<<
while_grad_ops
.
size
();
operators
::
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
operators
::
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
graph
->
OriginProgram
(),
while_ops
,
while_grad_ops
);
graph
->
OriginProgram
(),
while_ops
,
while_grad_ops
);
}
}
...
...
paddle/fluid/framework/ir/node.cc
浏览文件 @
167523e7
...
@@ -30,7 +30,7 @@ std::unique_ptr<Node> CreateNodeForTest(const std::string &name,
...
@@ -30,7 +30,7 @@ std::unique_ptr<Node> CreateNodeForTest(const std::string &name,
}
}
std
::
unique_ptr
<
Node
>
CreateNodeForTest
(
VarDesc
*
var_desc
)
{
std
::
unique_ptr
<
Node
>
CreateNodeForTest
(
VarDesc
*
var_desc
)
{
return
std
::
unique_ptr
<
Node
>
(
new
Node
(
var_desc
));
return
std
::
unique_ptr
<
Node
>
(
new
Node
(
var_desc
,
0
));
}
}
std
::
unique_ptr
<
Node
>
CreateNodeForTest
(
OpDesc
*
op_desc
)
{
std
::
unique_ptr
<
Node
>
CreateNodeForTest
(
OpDesc
*
op_desc
)
{
...
...
paddle/fluid/framework/ir/node.h
浏览文件 @
167523e7
...
@@ -136,9 +136,98 @@ class Node {
...
@@ -136,9 +136,98 @@ class Node {
var_desc_
->
SetName
(
new_name
);
var_desc_
->
SetName
(
new_name
);
}
}
// Returns the value stored in desc_order_. The constructors of this class
// initialize it to NO_DESC_ORDER; SetDescOrder() overwrites it.
int DescOrder() const { return desc_order_; }
// Returns block_id_ of a variable node.
// The node must have type_ == Type::kVariable and a non-null var_desc_;
// PADDLE_ENFORCE_EQ is given an InvalidArgument error for the case where
// that condition does not hold.
int GetVarNodeBlockId() const {
  PADDLE_ENFORCE_EQ(
      type_ == Type::kVariable && var_desc_, true,
      platform::errors::InvalidArgument("Node must be type of variable."));
  return block_id_;
}
const
std
::
string
ToString
()
const
{
if
(
IsOp
())
{
std
::
string
op_str
(
Name
());
const
auto
&
op
=
Op
();
if
(
op
==
nullptr
)
{
// Node is an Op but hasn't OpDesc (often create by CreateEmptyNode),
// like ScaleLossGradOp, it's type is OpHandle, which created by Pass
// and then inserted into graph.
// For OpHandle, we have to use Node's input and output for sorting.
std
::
vector
<
Node
*>
sorted_inputs
(
inputs
);
std
::
vector
<
Node
*>
sorted_outputs
(
outputs
);
auto
comparator
=
[](
Node
*
a
,
Node
*
b
)
{
return
a
->
Name
()
>
b
->
Name
();
};
std
::
stable_sort
(
sorted_inputs
.
begin
(),
sorted_inputs
.
end
(),
comparator
);
std
::
stable_sort
(
sorted_outputs
.
begin
(),
sorted_outputs
.
end
(),
comparator
);
std
::
string
out_str
=
"{"
;
std
::
string
pre_str
=
""
;
for
(
const
auto
&
output
:
sorted_outputs
)
{
out_str
.
append
(
pre_str
+
output
->
Name
());
pre_str
=
", "
;
}
out_str
.
append
(
"} = "
);
std
::
string
in_str
=
"("
;
pre_str
=
""
;
for
(
const
auto
&
input
:
sorted_inputs
)
{
in_str
.
append
(
pre_str
+
input
->
Name
());
pre_str
=
", "
;
}
in_str
.
append
(
")"
);
op_str
=
out_str
+
op_str
+
in_str
;
}
else
{
// A normal Op, has OpDesc, create from ProgramDesc
std
::
string
out_str
=
"{"
;
std
::
string
outer_pre_str
=
""
;
for
(
const
auto
&
output
:
op
->
OutputNames
())
{
out_str
.
append
(
outer_pre_str
+
output
+
"=["
);
std
::
string
inner_pre_str
=
""
;
for
(
const
auto
&
arg
:
op
->
Output
(
output
))
{
out_str
.
append
(
inner_pre_str
+
arg
);
inner_pre_str
=
" ,"
;
}
outer_pre_str
=
", "
;
out_str
.
append
(
"]"
);
}
out_str
.
append
(
"} = "
);
std
::
string
in_str
=
"("
;
outer_pre_str
=
""
;
for
(
const
auto
&
input
:
op
->
InputNames
())
{
in_str
.
append
(
outer_pre_str
+
input
+
"=["
);
std
::
string
inner_pre_str
=
""
;
for
(
const
auto
&
arg
:
op
->
Input
(
input
))
{
in_str
.
append
(
inner_pre_str
+
arg
);
inner_pre_str
=
" ,"
;
}
outer_pre_str
=
" ,"
;
in_str
.
append
(
"]"
);
}
in_str
.
append
(
")"
);
op_str
=
out_str
+
op_str
+
in_str
;
}
return
op_str
;
}
return
Name
();
}
std
::
vector
<
Node
*>
inputs
;
std
::
vector
<
Node
*>
inputs
;
std
::
vector
<
Node
*>
outputs
;
std
::
vector
<
Node
*>
outputs
;
// Because NO_DESC_ORDER is a constexpr number,
// no one can change it, meanwhile, we need
// check whether the DescOrder invalid sometime,
// so expose it is a good idea
static
constexpr
int
NO_DESC_ORDER
=
INT_MAX
;
protected:
protected:
std
::
string
name_
;
std
::
string
name_
;
std
::
unique_ptr
<
VarDesc
>
var_desc_
;
std
::
unique_ptr
<
VarDesc
>
var_desc_
;
...
@@ -146,30 +235,45 @@ class Node {
...
@@ -146,30 +235,45 @@ class Node {
Type
type_
;
Type
type_
;
int
id_
;
int
id_
;
int
desc_order_
;
int
block_id_
{
-
1
};
private:
private:
// ID can only set by a Graph.
// ID can only set by a Graph.
void
SetId
(
int
id
)
{
id_
=
id
;
}
void
SetId
(
int
id
)
{
id_
=
id
;
}
// desc_order can only set by a Graph when constructing a Graph from a
// BlockDesc.
void
SetDescOrder
(
int
desc_order
)
{
desc_order_
=
desc_order
;
}
friend
class
Graph
;
friend
class
Graph
;
friend
std
::
unique_ptr
<
Node
>
CreateNodeForTest
(
const
std
::
string
&
name
,
friend
std
::
unique_ptr
<
Node
>
CreateNodeForTest
(
const
std
::
string
&
name
,
Node
::
Type
type
);
Node
::
Type
type
);
friend
std
::
unique_ptr
<
Node
>
CreateNodeForTest
(
VarDesc
*
var_desc
);
friend
std
::
unique_ptr
<
Node
>
CreateNodeForTest
(
VarDesc
*
var_desc
);
friend
std
::
unique_ptr
<
Node
>
CreateNodeForTest
(
OpDesc
*
op_desc
);
friend
std
::
unique_ptr
<
Node
>
CreateNodeForTest
(
OpDesc
*
op_desc
);
explicit
Node
(
const
std
::
string
&
name
,
Type
type
)
explicit
Node
(
const
std
::
string
&
name
,
Type
type
,
int
block_id
=
0
)
:
name_
(
name
),
var_desc_
(
nullptr
),
op_desc_
(
nullptr
),
type_
(
type
)
{}
:
name_
(
name
),
var_desc_
(
nullptr
),
op_desc_
(
nullptr
),
type_
(
type
),
desc_order_
(
NO_DESC_ORDER
),
block_id_
(
block_id
)
{}
explicit
Node
(
VarDesc
*
var_desc
)
explicit
Node
(
VarDesc
*
var_desc
,
int
block_id
)
:
name_
(
var_desc
->
Name
()),
:
name_
(
var_desc
->
Name
()),
var_desc_
(
new
VarDesc
(
*
var_desc
)),
var_desc_
(
new
VarDesc
(
*
var_desc
)),
op_desc_
(
nullptr
),
op_desc_
(
nullptr
),
type_
(
Type
::
kVariable
)
{}
type_
(
Type
::
kVariable
),
desc_order_
(
NO_DESC_ORDER
),
block_id_
(
block_id
)
{}
explicit
Node
(
OpDesc
*
op_desc
)
explicit
Node
(
OpDesc
*
op_desc
)
:
name_
(
op_desc
->
Type
()),
:
name_
(
op_desc
->
Type
()),
var_desc_
(
nullptr
),
var_desc_
(
nullptr
),
op_desc_
(
new
OpDesc
(
*
op_desc
,
op_desc
->
Block
())),
op_desc_
(
new
OpDesc
(
*
op_desc
,
op_desc
->
Block
())),
type_
(
Type
::
kOperation
)
{}
type_
(
Type
::
kOperation
),
desc_order_
(
NO_DESC_ORDER
)
{}
Node
()
=
delete
;
Node
()
=
delete
;
...
...
paddle/fluid/framework/ir/node_test.cc
浏览文件 @
167523e7
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/node.h"
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/var_desc.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -75,6 +76,32 @@ TEST(NodeTest, Basic) {
...
@@ -75,6 +76,32 @@ TEST(NodeTest, Basic) {
EXPECT_FALSE
(
alive2
);
EXPECT_FALSE
(
alive2
);
}
}
// Exercises Node::ToString() on four kinds of node: a variable created by
// name, a variable created from a VarDesc, an operation without an OpDesc and
// an operation created from an OpDesc.
TEST(NodeTest, ToString) {
  VarDesc var_desc("n2");
  OpDesc op_desc;
  op_desc.SetType("test_op");
  op_desc.SetInput("X", {"x1", "x2", "x3"});
  op_desc.SetOutput("Y", {"y1", "y2"});

  std::unique_ptr<Node> named_var(
      CreateNodeForTest("n1", Node::Type::kVariable));
  std::unique_ptr<Node> desc_var(CreateNodeForTest(&var_desc));
  std::unique_ptr<Node> bare_op(
      CreateNodeForTest("n3", Node::Type::kOperation));
  std::unique_ptr<Node> desc_op(CreateNodeForTest(&op_desc));

  // Variable nodes print as their plain name.
  EXPECT_EQ(named_var->ToString(), "n1");
  EXPECT_EQ(desc_var->ToString(), "n2");

  // An op without OpDesc and without edges prints empty in/out lists.
  EXPECT_EQ(bare_op->Op(), nullptr);
  EXPECT_EQ(bare_op->ToString(), "{} = n3()");

  // An op with OpDesc prints its slots and their arguments.
  EXPECT_NE(desc_op->Op(), nullptr);
  EXPECT_EQ(desc_op->ToString(), "{Y=[y1 ,y2]} = test_op(X=[x1 ,x2 ,x3])");

  // Once edges are attached, the op without OpDesc prints its neighbours.
  bare_op->inputs.push_back(named_var.get());
  bare_op->outputs.push_back(desc_var.get());
  EXPECT_EQ(bare_op->Op(), nullptr);
  EXPECT_EQ(bare_op->ToString(), "{n2} = n3(n1)");
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录