PaddlePaddle / Paddle
Commit d6d3e6af
Authored Jan 28, 2019 by dzhwinter

add more skip strategy

Parent: 2739096e
Showing 10 changed files with 425 additions and 93 deletions (+425 −93)
paddle/fluid/framework/details/graph_print_pass.cc                  +64  −1
paddle/fluid/framework/details/graph_print_pass.h                   +2   −0
paddle/fluid/framework/details/graph_print_pass_test.cc             +111 −0
paddle/fluid/framework/details/inplace_op_pass.cc                   +170 −78
paddle/fluid/framework/details/inplace_op_pass.h                    +21  −1
paddle/fluid/framework/ir/graph_helper.cc                           +25  −6
paddle/fluid/framework/ir/graph_helper.h                            +5   −0
paddle/fluid/framework/ir/graph_helper_test.cc                      +11  −0
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py  +4   −5
python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py         +12  −2
paddle/fluid/framework/details/graph_print_pass.cc

@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/details/graph_print_pass.h"
 #include <string>
 #include <vector>
+#include "paddle/fluid/framework/ir/graph_helper.h"

 namespace paddle {
 namespace framework {
@@ -54,6 +55,11 @@ class GraphvizOp : public GraphvizNode {
   }
 }

+  template <typename Callback>
+  void AddCustomEdge(const Callback& cb) {
+    stream_ << cb() << std::endl;
+  }
+
  private:
   std::ostringstream stream_;
 };
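The templated callback lets callers attach arbitrary extra DOT text to an op node without widening the class interface; the pass uses it below to paint cycle edges red. A minimal standalone sketch of the same buffering pattern (EdgeBuffer is a hypothetical stand-in, not part of the pass):

#include <iostream>
#include <sstream>
#include <string>

// Stand-in for GraphvizOp's private buffer: AddCustomEdge evaluates the
// callback right away and appends its text to a per-op stream that the
// printer dumps at the end.
class EdgeBuffer {
 public:
  template <typename Callback>
  void AddCustomEdge(const Callback& cb) {
    stream_ << cb() << std::endl;
  }
  std::string str() const { return stream_.str(); }

 private:
  std::ostringstream stream_;
};

int main() {
  EdgeBuffer buf;
  const std::string prev_op = "op_0";
  const std::string next_op = "op_1";
  const std::string kCircleEdge = "[color=red,penwidth=3.0]";
  buf.AddCustomEdge([&]() -> std::string {
    return prev_op + "->" + next_op + kCircleEdge;
  });
  std::cout << buf.str();  // prints: op_0->op_1[color=red,penwidth=3.0]
  return 0;
}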
@@ -68,12 +74,47 @@ std::vector<T*> FilterByNodeWrapper(const Container& con) {
   return ret;
 }

+// bool DetectCircleRecursive(const std::map<ir::Node*,
+// std::unordered_set<ir::Node*>>, std::unordered_set<ir::Node*>* visited,
+// std::unordered_set<ir::Node*> *in_trace,
+// std::vector<std::vector<ir::Node*>>* circles) {
+//   if (visited->find(node) == visited->end()) {
+//     visited->insert(node);
+//     in_trace->insert(node);
+//     for (ir::Node *in : adj_list.at(node)) {
+//       if (visited->find(in) == visited->end() &&
+//           HasCircleHelper(in, adj_list, visited, in_trace)) {
+//         return true;
+//       } else if (in_trace->find(in) != in_trace->end()) {
+//         circles->push_back(in_trace);
+//         return true;
+//       }
+//     }
+//   }
+//   in_trace->erase(node);
+//   return false;
+// }
+
+// bool DetectCircle(const std::map<ir::Node*, std::unordered_set<ir::Node*>>&
+// adj_list, std::vector<std::vector<ir::Node*>>* circles) {
+//   std::unordered_set<ir::Node *> visited;
+//   std::unordered_set<ir::Node *> in_trace;
+//   bool has_circle = false;
+//   for (auto& adj : adj_list) {
+//     has_circle &= DetectCircleRecursive(adj, adj_list, &visited, &in_trace,
+//     circles);
+//   }
+//   return has_circle;
+// }
+
 std::unordered_map<ir::Node*, int> SSAGraphPrinterImpl::ToGraphvizNode(
     const ir::Graph& graph) const {
   // Convert to GraphvizNode format
   auto& graphviz_nodes = graph.Get<GraphvizNodes>(kGraphviz);
   graphviz_nodes.clear();
   std::unordered_map<ir::Node*, int> vars;
+  std::unordered_map<ir::Node*, GraphvizOp*> ops;
   int var_id = 0;
   int op_id = 0;
   for (auto& node : graph.Nodes()) {

@@ -81,11 +122,33 @@ std::unordered_map<ir::Node*, int> SSAGraphPrinterImpl::ToGraphvizNode(
       graphviz_nodes.emplace(new GraphvizVar(node, var_id));
       vars.emplace(std::make_pair(node, var_id++));
     } else if (node->IsOp()) {
-      graphviz_nodes.emplace(new GraphvizOp(node, op_id++));
+      std::unique_ptr<GraphvizOp> op(new GraphvizOp(node, op_id++));
+      ops[node] = op.get();
+      graphviz_nodes.emplace(std::move(op));
+      // graphviz_nodes.emplace(new GraphvizOp(node, op_id++));
+      // ops.emplace(std::make_pair(node, graphviz_nodes.back().get()));
     } else {
       PADDLE_THROW("Unknown op type");
     }
   }
+
+  // Detect circle. Draw circle in different lines
+  std::vector<std::vector<ir::Node*>> circles;
+  const std::string kCircleEdge = "[color=red,penwidth=3.0]";
+  if (ir::FindCircleSubGraph(graph, &circles)) {
+    VLOG(3) << "Graph has circle! circles count : " << circles.size();
+    for (auto& circle : circles) {
+      for (size_t i = 0; i < circle.size() - 1; ++i) {
+        GraphvizOp* prev = ops[circle[i]];
+        GraphvizOp* next = ops[circle[i + 1]];
+        std::string prev_op = "op_" + std::to_string(prev->Id());
+        std::string next_op = "op_" + std::to_string(next->Id());
+        prev->AddCustomEdge([&]() -> std::string {
+          return prev_op + "->" + next_op + kCircleEdge;
+        });
+      }
+    }
+  }
   return vars;
 }
paddle/fluid/framework/details/graph_print_pass.h

@@ -31,6 +31,8 @@ class GraphvizNode {
   GraphvizNode(ir::Node* n, const int& i) : node_(n), id_(i) {}
   virtual ~GraphvizNode() = default;
+
+  int Id() const { return id_; }

  protected:
   ir::Node* node_;
   int id_;
paddle/fluid/framework/details/graph_print_pass_test.cc

@@ -19,6 +19,9 @@ REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
                   paddle::framework::SumOpMaker);
 REGISTER_OPERATOR(split, paddle::framework::DummyOp,
                   paddle::framework::SplitOpMaker);
+REGISTER_OPERATOR(assign, paddle::framework::DummyOp,
+                  paddle::framework::AssignOpMaker,
+                  paddle::framework::DummyVarTypeInference);

 /*
   a @ b

@@ -54,6 +57,12 @@ inline static ProgramDesc FillProgramDesc() {
     op->SetInput("X", {"d", "e"});
     op->SetOutput("Out", {"d"});
   }
+  {
+    auto* op = prog.MutableBlock(0)->AppendOp();
+    op->SetType("assign");
+    op->SetInput("X", {"d"});
+    op->SetOutput("Out", {"d"});
+  }
   return prog;
 }

@@ -74,6 +83,108 @@ TEST(SSAGraphPrinter, Normal) {
   printer->Print(*graph, *fout);
 }

+using ir::Graph;
+using ir::Node;
+
+void BuildCircleGraph(Graph* g) {
+  ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
+  ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
+
+  o1->outputs.push_back(v1);
+  o1->inputs.push_back(v1);
+  v1->inputs.push_back(o1);
+  v1->outputs.push_back(o1);
+}
+
+void BuildCircleGraph2(Graph* g) {
+  ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
+  ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
+  ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
+  ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
+
+  o1->outputs.push_back(v1);
+  o2->inputs.push_back(v1);
+  v1->inputs.push_back(o1);
+  v1->outputs.push_back(o2);
+
+  o2->outputs.push_back(v2);
+  o1->inputs.push_back(v2);
+  v2->inputs.push_back(o2);
+  v2->outputs.push_back(o1);
+}
+
+void BuildNoCircleGraph(Graph* g) {
+  ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
+  ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
+  ir::Node* o3 = g->CreateEmptyNode("op3", Node::Type::kOperation);
+  ir::Node* o4 = g->CreateEmptyNode("op4", Node::Type::kOperation);
+  ir::Node* o5 = g->CreateEmptyNode("op5", Node::Type::kOperation);
+  ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
+  ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
+  ir::Node* v3 = g->CreateEmptyNode("var3", Node::Type::kVariable);
+  ir::Node* v4 = g->CreateEmptyNode("var4", Node::Type::kVariable);
+
+  // o1->v1->o2
+  o1->outputs.push_back(v1);
+  o2->inputs.push_back(v1);
+  v1->inputs.push_back(o1);
+  v1->outputs.push_back(o2);
+  // o2->v2->o3
+  // o2->v2->o4
+  o2->outputs.push_back(v2);
+  o3->inputs.push_back(v2);
+  o4->inputs.push_back(v2);
+  v2->inputs.push_back(o2);
+  v2->outputs.push_back(o3);
+  v2->outputs.push_back(o4);
+  // o2->v3->o5
+  o2->outputs.push_back(v3);
+  o5->inputs.push_back(v3);
+  v3->inputs.push_back(o2);
+  v3->outputs.push_back(o5);
+  // o3->v4->o5
+  o3->outputs.push_back(v4);
+  o5->inputs.push_back(v4);
+  v4->inputs.push_back(o3);
+  v4->outputs.push_back(o5);
+  // o2->v3->o1
+  v3->outputs.push_back(o1);
+  o1->inputs.push_back(v3);
+}
+
+TEST(SSAGraphPrinter, SimpleCircle) {
+  ProgramDesc prog;
+  Graph graph(prog);
+  BuildCircleGraph(&graph);
+  ASSERT_TRUE(HasCircle(graph));
+
+  graph.Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
+  std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
+
+  // redirect debug graph to a file.
+  constexpr char graph_path[] = "graph_print_pass_simple_circle.txt";
+  std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
+  PADDLE_ENFORCE(fout->good());
+  printer->Print(graph, *fout);
+}
+
+TEST(SSAGraphPrinter, ComplexCircle) {
+  ProgramDesc prog;
+  Graph graph(prog);
+  BuildCircleGraph2(&graph);
+  ASSERT_TRUE(HasCircle(graph));
+
+  graph.Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
+  std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
+
+  // redirect debug graph to a file.
+  constexpr char graph_path[] = "graph_print_pass_complex_circle.txt";
+  std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
+  PADDLE_ENFORCE(fout->good());
+  printer->Print(graph, *fout);
+}
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/details/inplace_op_pass.cc

@@ -23,6 +23,7 @@
 #include <vector>
 #include "paddle/fluid/framework/details/graph_print_pass.h"
 #include "paddle/fluid/framework/details/memory_optimize_pass.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/op_info.h"

 // NOTE(dzhwinter): inplace means one op output variable reuse the input space.
@@ -39,16 +40,20 @@
 // auto* out_ptr = out->mutable_data<T>(ctx.GetPlace());
 // out_ptr[0] = 0;  // input contect is overwrited.

-// For backward compacity. if enable_inplace_whitelist is turn on.
+// NOTE(dzhwinter):
+// Only for backward compacity and stable. if enable_inplace_whitelist is turn
+// on.
 // only the ops in whitelist will be use inplace strategy.
 // if not, all the op will be inplaced if it registered with InplaceClass
-DEFINE_bool(enable_inplace_whitelist, true,
+DEFINE_bool(enable_inplace_whitelist, false,
            "If this option turns on, only these op in whitelist can be inplaced."
            "If it turns off, all of the running op can be candidate of inplaced op."
            "Such as scale, elementwise_add"
            "By default, it's turned on");
+
+DECLARE_string(memory_optimize_debug);

 // clang-format off
 const std::string kInplacedOpWhiteList[] = {  // NOLINT
     "sigmoid",
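Flipping the default to false makes whitelisting opt-in: unless the flag is set, every op registered with an inplace inferer becomes a candidate. A small sketch of the gate this flag drives (the real check appears in ApplyImpl below; SkippedByWhitelist is a hypothetical helper, not the pass's API):

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Stand-ins for FLAGS_enable_inplace_whitelist and the pass's whitelist_;
// both names here are illustrative.
static bool enable_inplace_whitelist = false;
static const std::unordered_set<std::string> whitelist = {
    "sigmoid", "scale", "elementwise_add"};

// Mirrors the gate in ApplyImpl: when the flag is on, only whitelisted ops
// may be inplaced; with the new default (false) every registered op is a
// candidate.
bool SkippedByWhitelist(const std::string& op_name) {
  return enable_inplace_whitelist && whitelist.count(op_name) == 0;
}

int main() {
  const std::vector<std::string> ops = {"sigmoid", "conv2d"};
  for (const auto& op : ops) {
    std::cout << op << ": "
              << (SkippedByWhitelist(op) ? "skipped" : "candidate") << "\n";
  }
  return 0;
}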
@@ -77,63 +82,6 @@ namespace paddle {
 namespace framework {
 namespace details {

-static inline std::string NodeDebugString(ir::Node* var) {
-  std::ostringstream os;
-  if (var->IsCtrlVar()) {
-    os << "kControlDepVarName"
-       << " ";
-  } else if (var->IsOp()) {
-    os << "kOperation"
-       << " " << var->Name();
-    PADDLE_ENFORCE(var->Op() != nullptr && var->Op()->Type() == var->Name());
-  } else if (var->IsVar()) {
-    os << "kVariable"
-       << " " << var->Name();
-    PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name());
-  } else {
-    PADDLE_THROW("Unknown node type.");
-  }
-  return os.str();
-}
-
-static inline std::string OpDebugString(ir::Node* var) {
-  ir::Node* op = var;
-  if (var->IsVar()) op = var->inputs.at(0);
-  std::stringstream os;
-  os << op->Name() << " : ";
-  os << "Input ";
-  VLOG(3) << op->Name();
-  for (auto* var : op->inputs) {
-    if (var->IsVar() && !var->IsCtrlVar()) {
-      PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name(),
-                     "unmatched desc and var");
-      // os << var << ":" << var->Name() << " ";
-      os << var->Name() << " ";
-    }
-  }
-  os << "Output ";
-  VLOG(3) << op->Name();
-  for (auto* var : op->outputs) {
-    VLOG(3) << var;
-    VLOG(3) << var->Name();
-    if (!var->IsVar()) {
-      VLOG(3) << "error";
-    }
-    // VLOG(3) << var->Var()->Name();
-    if (var->IsVar() && !var->IsCtrlVar()) {
-      PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name(),
-                     "unmatched desc and var");
-      // os << var << ":" << var->Name() << " ";
-      os << var->Name() << " ";
-    }
-    if (var->Name() == "fc_10.tmp_0") {
-      VLOG(3) << NodeDebugString(var);
-    }
-  }
-  return os.str();
-}
-
 static inline ir::Node* GetNextCascadeInplacedVar(ir::Node* var) {
   // if next op is inplaced, then return the output var
   // otherwise return nullptr
@@ -218,6 +166,10 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
   InitSSAGraphNodes();

+  std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
+  constexpr char graph_path1[] = "ir_graph_before_inplaced.txt";
+  std::unique_ptr<std::ostream> fout1(new std::ofstream(graph_path1));
+  PADDLE_ENFORCE(fout1->good());
+  printer->Print(*graph, *fout1);
+
   for (auto* op : view_.AllOps()) {
     if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
@@ -230,9 +182,6 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
   std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
   PADDLE_ENFORCE(fout->good());
   printer->Print(*graph, *fout);
-  // for(auto* op : view_.AllOps()) {
-  //   VLOG(3) << OpDebugString(op);
-  // }
   return graph;
 }
@@ -250,6 +199,92 @@ void InplacePass::InplaceModifyDesc(const std::string& var,
   }
 }

+const SSANodeVector InplacePass::TryInplaceModifyVar(const std::string& var,
+                                                     const std::string& cache_var,
+                                                     const size_t& idx,
+                                                     ir::Graph* graph) const {
+  PADDLE_ENFORCE(var_nodes_[var].size() >= 1 &&
+                 var_nodes_[var].at(0)->Var() != nullptr);
+  std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var()));
+  var_desc->SetName(cache_var);
+  SSANodeVector swap_nodes;
+
+  for (size_t i = idx; i < view_.AllOps().size(); ++i) {
+    auto* op = view_.AllOps()[i];
+
+    // redirect the input to the latest version of cache_var
+    for (auto* node : op->inputs) {
+      if (node->Name() == var) {
+        ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
+
+        // swap node to cache_node
+        cache_node->outputs.insert(cache_node->outputs.end(),
+                                   node->outputs.begin(), node->outputs.end());
+        PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp());
+        auto* prev_op = node->inputs[0];
+        std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node,
+                     cache_node);
+        cache_node->inputs.emplace_back(prev_op);
+        for (auto* next_op : node->outputs) {
+          std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
+                       cache_node);
+        }
+        swap_nodes[node].emplace_back(cache_node);
+      }
+    }
+
+    for (auto* node : op->outputs) {
+      if (node->Name() == var) {
+        ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
+        var_nodes_[cache_var].emplace_back(cache_node);
+
+        // swap node to cache node
+        cache_node->outputs.insert(cache_node->outputs.end(),
+                                   node->outputs.begin(), node->outputs.end());
+        cache_node->inputs.emplace_back(op);
+        std::replace(op->outputs.begin(), op->outputs.end(), node, cache_node);
+        for (auto* next_op : node->outputs) {
+          std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
+                       cache_node);
+        }
+        swap_nodes[node].emplace_back(cache_node);
+      }
+    }
+  }
+  return swap_nodes;
+}
+
+void InplacePass::CommitModify(const SSANodeVector& swap_nodes,
+                               ir::Graph* graph) const {
+  for (auto& pair : swap_nodes) {
+    auto* node = pair.first;
+    const std::string var = node->Name();
+    for (auto* cache_node : pair.second) {
+      const std::string cache_var = cache_node->Name();
+      var_nodes_[cache_var].emplace_back(cache_node);
+    }
+    auto& nodes = var_nodes_.at(var);
+    nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
+    graph->RemoveNode(node);
+  }
+}
+
+void InplacePass::WithDrawModify(const SSANodeVector& nodes,
+                                 ir::Graph* graph) const {
+  for (auto& pair : nodes) {
+    auto* node = pair.first;
+    const std::string var = node->Name();
+    for (auto* cache_node : pair.second) {
+      const std::string cache_var = cache_node->Name();
+      auto* prev_op = node->inputs[0];
+      std::replace(prev_op->outputs.begin(), prev_op->outputs.end(),
+                   cache_node, node);
+      for (auto* next_op : node->outputs) {
+        std::replace(next_op->inputs.begin(), next_op->inputs.end(),
+                     cache_node, node);
+      }
+      graph->RemoveNode(cache_node);
+    }
+  }
+}
+
 void InplacePass::InplaceModifyVar(const std::string& var,
                                    const std::string& cache_var,
                                    const size_t& idx, ir::Graph* graph) const {
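TryInplaceModifyVar does the actual graph surgery: for each occurrence of var it creates a node named after cache_var and splices it into every adjacency list the old node appeared in. A condensed sketch of that splice, using a simplified Node struct rather than the ir::Node API:

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Simplified stand-in for ir::Node: just a name plus adjacency lists.
struct Node {
  std::string name;
  std::vector<Node*> inputs;
  std::vector<Node*> outputs;
};

// Splice `cache` into the graph wherever `var` was connected, mirroring the
// std::replace calls in TryInplaceModifyVar.
void SwapVarNode(Node* var, Node* cache) {
  cache->outputs.insert(cache->outputs.end(), var->outputs.begin(),
                        var->outputs.end());
  Node* prev_op = var->inputs.at(0);
  std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), var, cache);
  cache->inputs.emplace_back(prev_op);
  for (Node* next_op : var->outputs) {
    std::replace(next_op->inputs.begin(), next_op->inputs.end(), var, cache);
  }
}

int main() {
  Node op1{"op1", {}, {}}, var{"x", {}, {}}, op2{"op2", {}, {}},
      cache{"y", {}, {}};
  op1.outputs = {&var};
  var.inputs = {&op1};
  var.outputs = {&op2};
  op2.inputs = {&var};

  SwapVarNode(&var, &cache);
  assert(op1.outputs.at(0) == &cache);  // producer now feeds the cache node
  assert(op2.inputs.at(0) == &cache);   // consumer reads the cache node
  assert(cache.inputs.at(0) == &op1);
  return 0;
}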
@@ -318,7 +353,7 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
                                           ir::Graph* graph) const {
   PADDLE_ENFORCE(op->Op() != nullptr && op->Op()->Block() != nullptr,
                  "op_desc is nullptr");
-  // 3 pre-requirments need to meet if the op want to inplaced.
+  // 4 pre-requirments need to meet if the op want to inplaced.
   // 1. infer_inplace_ is registered.
   auto* op_desc = op->Op();
   auto& infer_inplace =
@@ -333,36 +368,68 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
   auto& all_ops = view_.AllOps();
   auto cursor = std::find(all_ops.begin(), all_ops.end(), op);
   size_t idx = std::distance(all_ops.begin(), cursor);
+  VLOG(3) << op->Name() << idx;

   for (auto& pair : in_to_outs) {
     auto& in_var_name = pair.first;
     auto& out_var_name = pair.second;
     auto* in_node = view_.GetNodeByName(in_var_name, op->inputs);
     auto* out_node = view_.GetNodeByName(out_var_name, op->outputs);

     // 2. there is no external pending op on the input node
     if (view_.PendingOpsOnVar(in_node).size() > 1) {
-      VLOG(3) << string::Sprintf(
-          "!!! %s input has external dependency, can not inplaced, %s => %s "
-          "skiped",
-          op->Name(), out_var_name, in_var_name);
+      VLOG(4) << string::Sprintf(
+          "Skiped pair %s => %s. %s input has external dependency."
+          "inplace such pair will overwrite the memory.",
+          out_var_name, in_var_name, op->Name());
       continue;
     }

     // 3. if output reuse input inplaced, the dependency group is not changed.
     // For detail, check
     // the function description in "OutConnectInputByCtrlVar"
     if (view_.OutConnectInputByCtrlVar(in_node, out_node)) {
-      VLOG(3) << string::Sprintf(
-          "!!! %s input output connect by ctrl var, cannot inplaced, %s => %s "
-          "skiped",
-          op->Name(), out_var_name, in_var_name);
+      VLOG(4) << string::Sprintf(
+          "Skiped pair %s => %s. %s input and output connect by ctrl var."
+          "inplace such pair will generate a circle.",
+          out_var_name, in_var_name, op->Name());
       continue;
     }

+    // 4. if output has been memory optimize by python(fluid.memory_optmize()).
+    // this candidate can not be inplaced. Will be deprecated in the future.
+    if (view_.ReusedInPythonMemOpt(out_node->Name())) {
+      VLOG(4) << string::Sprintf(
+          "Skiped %s => %s reused previous memory block in python memory "
+          "optmize,"
+          "it inplace may generate a circle",
+          out_var_name, in_var_name, op->Name());
+      continue;
+    }
+
+    // Debug Interface. Which would be skipped by the pass.
+    if (out_node->Name() == FLAGS_memory_optimize_debug) {
+      VLOG(3) << "Skiped var by force. FLAGS_memory_optimize_debug="
+              << out_node->Name();
+      continue;
+    }
+
+    auto swap_nodes = TryInplaceModifyVar(out_var_name, in_var_name, idx, graph);
+
+    // NOTE(dzhwinter):
+    // two stage commit of inplaced op. If add such node generate a circle,
+    // then withdraw the changes. Otherwise, safely add the node.
+    if (!ir::HasCircle(*graph)) {
+      VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(),
+                                 out_var_name, in_var_name);
+      // VLOG(3) << "Out " << OpDebugString(op);
+      CommitModify(swap_nodes, graph);
+      InplaceModifyDesc(out_var_name, in_var_name, idx);
-    InplaceModifyVar(out_var_name, in_var_name, idx, graph);
+    } else {
+      VLOG(3) << string::Sprintf(
+          "Skiped pair %s => %s, inplace will generate a circle. withdraw %s",
+          out_var_name, in_var_name, op->Name());
+      WithDrawModify(swap_nodes, graph);
+    }
   }
 }
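The NOTE above is the heart of the new skip strategy: mutate tentatively, re-check the no-circle invariant on the whole graph, then CommitModify or WithDrawModify. A schematic of that two-stage commit, with hypothetical Txn/TwoStageCommit names standing in for the SSANodeVector bookkeeping:

#include <functional>
#include <iostream>

// Hypothetical transaction record; in the pass this role is played by the
// SSANodeVector returned from TryInplaceModifyVar.
struct Txn {
  std::function<void()> commit;
  std::function<void()> withdraw;
};

// Mirrors the pattern in TryInplaceOpInputOutput: mutate first, validate the
// invariant (no circle) afterwards, and roll back on failure.
void TwoStageCommit(const Txn& txn, const std::function<bool()>& graph_is_ok) {
  if (graph_is_ok()) {
    txn.commit();
  } else {
    txn.withdraw();
  }
}

int main() {
  bool has_circle = true;  // pretend the tentative rewire created a cycle
  Txn txn{[] { std::cout << "committed\n"; },
          [] { std::cout << "withdrawn\n"; }};
  TwoStageCommit(txn, [&] { return !has_circle; });  // prints: withdrawn
  return 0;
}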
@@ -406,7 +473,28 @@ std::vector<ir::Node*> GraphView::PendingOpsOnVar(ir::Node* node) {
   return pending_ops;
 }

-void GraphView::Build(ir::Graph* g) { ops_ = SortOpLikeDescOrder(*g); }
+void GraphView::Build(ir::Graph* g) {
+  // track the var nodes in correct order.
+  // Because we insert some new created node. Which may have data race between
+  // nodes.
+  // resolve data harzards depends on the var nodes in right order.
+  ops_ = SortOpLikeDescOrder(*g);
+
+  // track the nodes which reused previous node in Python memory optimize.
+  // these node can not be inplaced, otherwise may generate a circle in graph.
+  std::unordered_set<std::string> all_vars;
+  for (auto& node : g->Nodes()) {
+    if (node->IsVar()) continue;
+    for (auto& out : node->outputs) {
+      if (out->IsCtrlVar() || out->Var() == nullptr) continue;
+      if (all_vars.count(out->Name())) {
+        dup_nodes_.emplace(out->Name());
+      } else {
+        all_vars.emplace(out->Name());
+      }
+    }
+  }
+}

 const std::vector<ir::Node*> GraphView::AllOps() { return ops_; }
@@ -452,6 +540,10 @@ bool GraphView::OutConnectInputByCtrlVar(ir::Node* in_var, ir::Node* out_var) {
   return ConnectByCtrlVar(in_var_set, out_var_set);
 }

+bool GraphView::ReusedInPythonMemOpt(const std::string& var) const {
+  return dup_nodes_.count(var);
+}
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/details/inplace_op_pass.h

@@ -2,7 +2,7 @@
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
+// You may abtain a copy of the License at
 //
 //     http://www.apache.org/licenses/LICENSE-2.0
 //

@@ -15,6 +15,7 @@
 #pragma once
 #include <map>
 #include <string>
+#include <unordered_map>
 #include <unordered_set>
 #include <vector>
 #include "paddle/fluid/framework/details/memory_optimize_helper.h"

@@ -40,10 +41,20 @@ class GraphView {
   bool OutConnectInputByCtrlVar(ir::Node* in_var, ir::Node* out_var);

+  // Will Deperated in the future.
+  // NOTE(dzhwinter) : Python memory optimize will reuse
+  // memory based var name, so different op output may
+  // have the same variable name. enable inplace on such node
+  // will generate a circle in ssa graph.
+  bool ReusedInPythonMemOpt(const std::string& var) const;
+
  private:
   std::vector<ir::Node*> ops_;
+  std::unordered_set<std::string> dup_nodes_;  // mem opt affect nodes
   std::map<ir::Node*, std::unordered_set<ir::Node*>> adj_list_;
 };

+typedef std::unordered_map<ir::Node*, std::vector<ir::Node*>> SSANodeVector;
+
 class InplacePass : public ir::Pass {
  public:
   InplacePass();

@@ -58,6 +69,15 @@ class InplacePass : public ir::Pass {
   void InplaceModifyVar(const std::string& in_var, const std::string& out_var,
                         const size_t& idx, ir::Graph* graph) const;

+  const SSANodeVector TryInplaceModifyVar(const std::string& var,
+                                          const std::string& cache_var,
+                                          const size_t& idx,
+                                          ir::Graph* graph) const;
+
+  void CommitModify(const SSANodeVector&, ir::Graph* graph) const;
+
+  void WithDrawModify(const SSANodeVector& nodes, ir::Graph* graph) const;
+
   void InplaceModifyDesc(const std::string& in_var, const std::string& out_var,
                         const size_t& idx) const;
paddle/fluid/framework/ir/graph_helper.cc

@@ -52,16 +52,29 @@ bool HasCircleHelper(
     ir::Node* node,
     const std::map<ir::Node*, std::unordered_set<ir::Node*>>& adj_list,
     std::unordered_set<ir::Node*>* visited,
-    std::unordered_set<ir::Node*>* in_trace) {
+    std::unordered_set<ir::Node*>* in_trace,
+    std::vector<std::vector<ir::Node*>>* circles) {
   if (visited->find(node) == visited->end()) {
     visited->insert(node);
     in_trace->insert(node);

     for (ir::Node* in : adj_list.at(node)) {
       if (visited->find(in) == visited->end() &&
-          HasCircleHelper(in, adj_list, visited, in_trace)) {
+          HasCircleHelper(in, adj_list, visited, in_trace, circles)) {
         return true;
       } else if (in_trace->find(in) != in_trace->end()) {
+        if (circles != nullptr) {
+          std::vector<ir::Node*> circle;
+          circle.emplace_back(in);
+          ir::Node* p = in;
+          for (auto& adj : adj_list.at(p)) {
+            if (in_trace->count(adj)) {
+              circle.emplace_back(adj);
+              p = adj;
+            }
+          }
+          circles->emplace_back(circle);
+        }
         return true;
       }
     }

@@ -71,11 +84,12 @@ bool HasCircleHelper(
 }

 bool HasCircleInternal(
-    const std::map<ir::Node*, std::unordered_set<ir::Node*>>& adj_list) {
+    const std::map<ir::Node*, std::unordered_set<ir::Node*>>& adj_list,
+    std::vector<std::vector<ir::Node*>>* circles) {
   std::unordered_set<ir::Node*> visited;
   std::unordered_set<ir::Node*> in_trace;
   for (auto& adj : adj_list) {
-    if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace)) {
+    if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace, circles)) {
       return true;
     }

@@ -84,13 +98,18 @@ bool HasCircleInternal(
 }
 }  // namespace

 bool HasCircle(const Graph& graph) {
-  return HasCircleInternal(BuildOperationAdjList(graph));
+  return HasCircleInternal(BuildOperationAdjList(graph), nullptr);
 }

+bool FindCircleSubGraph(const Graph& graph,
+                        std::vector<std::vector<ir::Node*>>* circles) {
+  return HasCircleInternal(BuildOperationAdjList(graph), circles);
+}
+
 std::vector<ir::Node*> TopologySortOperations(const Graph& graph) {
   std::map<ir::Node*, std::unordered_set<ir::Node*>> adj_list =
       BuildOperationAdjList(graph);
-  PADDLE_ENFORCE(!HasCircleInternal(adj_list));
+  PADDLE_ENFORCE(!HasCircleInternal(adj_list, nullptr));
   std::unordered_set<ir::Node*> visited;
   std::vector<ir::Node*> ret;
   for (auto adj : adj_list) {
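The circle collection added to HasCircleHelper is ordinary DFS colouring: a neighbour found in the grey in_trace set closes a cycle, and the grey set's members are walked to recover the cycle. The same idea on a plain integer graph, self-contained (no ir::Node dependency):

#include <iostream>
#include <map>
#include <set>
#include <vector>

// DFS with a "visited" (black) set plus an "in_trace" (grey) set; a
// neighbour already in in_trace means we closed a cycle, as in
// HasCircleHelper.
bool FindCycle(int node, const std::map<int, std::set<int>>& adj,
               std::set<int>* visited, std::set<int>* in_trace,
               std::vector<int>* cycle) {
  if (!visited->count(node)) {
    visited->insert(node);
    in_trace->insert(node);
    for (int next : adj.at(node)) {
      if (!visited->count(next) &&
          FindCycle(next, adj, visited, in_trace, cycle)) {
        return true;
      } else if (in_trace->count(next)) {
        cycle->assign(in_trace->begin(), in_trace->end());
        return true;
      }
    }
  }
  in_trace->erase(node);
  return false;
}

int main() {
  // 0 -> 1 -> 2 -> 0 is a cycle; node 3 hangs off to the side.
  std::map<int, std::set<int>> adj = {{0, {1}}, {1, {2}}, {2, {0}}, {3, {}}};
  std::set<int> visited, in_trace;
  std::vector<int> cycle;
  for (const auto& kv : adj) {
    if (FindCycle(kv.first, adj, &visited, &in_trace, &cycle)) break;
  }
  std::cout << "cycle size: " << cycle.size() << "\n";  // prints: cycle size: 3
  return 0;
}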
paddle/fluid/framework/ir/graph_helper.h

@@ -28,6 +28,11 @@ namespace ir {
 // Test if the graph contains circle.
 bool HasCircle(const Graph& graph);

+// Find All Circles for debugging,
+// store all subgraph in circles.
+bool FindCircleSubGraph(const Graph& graph,
+                        std::vector<std::vector<ir::Node*>>* circles);
+
 size_t GraphNum(const Graph& graph);

 // Topology Sort the operations in the graph from inputs to outputs.
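A short usage sketch for the new declaration, assuming the graph.h header path used elsewhere in this diff; CountCircleNodes is an illustrative helper, not part of the commit:

#include <cstddef>
#include <vector>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

namespace ir = paddle::framework::ir;

// Count how many nodes participate in detected cycles; returns 0 for a DAG.
size_t CountCircleNodes(const ir::Graph& graph) {
  std::vector<std::vector<ir::Node*>> circles;
  if (!ir::FindCircleSubGraph(graph, &circles)) return 0;
  size_t n = 0;
  for (const auto& circle : circles) n += circle.size();
  return n;
}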
paddle/fluid/framework/ir/graph_helper_test.cc

@@ -195,6 +195,17 @@ void BuildTwoGraphs(Graph* g) {
   // v4->outputs.push_back(o5);
 }

+TEST(GraphHelperTest, Circles) {
+  ProgramDesc prog;
+
+  Graph g(prog);
+  BuildCircleGraph(&g);
+
+  std::vector<std::vector<ir::Node*>> circles;
+  ASSERT_TRUE(FindCircleSubGraph(g, &circles));
+  ASSERT_EQ(circles.size(), 1UL);
+}
+
 TEST(GraphHelperTest, GraphNum) {
   ProgramDesc prog;
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py

@@ -32,7 +32,7 @@ class TestParallelExecutorBase(unittest.TestCase):
     def check_network_convergence(self,
                                   method,
                                   use_cuda=True,
-                                  memory_opt=True,
+                                  memory_opt=False,
                                   iter=50,
                                   batch_size=None,
                                   allow_op_delay=False,

@@ -67,8 +67,6 @@ class TestParallelExecutorBase(unittest.TestCase):
         if memory_opt:
             fluid.memory_optimize(main)
-        with open("program_model.txt", "w") as f:
-            f.write(str(main))
         place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
         exe = fluid.Executor(place)
         exe.run(startup)

@@ -82,9 +80,10 @@ class TestParallelExecutorBase(unittest.TestCase):
         build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
         build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv
         build_strategy.memory_optimize = use_ir_memory_optimize
-        build_strategy.enable_inplace = enable_inplace
+        # python memory optimization is conflict with inplace pass.
+        # Use ir graph memory optimization after inplace pass is the correct way.
+        build_strategy.enable_inplace = False if memory_opt else enable_inplace
         build_strategy.enable_sequential_execution = enable_sequential_execution
         build_strategy.debug_graphviz_path = "debug_ir_graph_"

         if use_cuda and core.is_compiled_with_cuda():
             build_strategy.remove_unnecessary_lock = True
python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py

@@ -46,7 +46,10 @@ class TestIrInplace(TestParallelExecutorBase):
     def setUpClass(cls):
         os.environ['CPU_NUM'] = str(4)

-    def _fc_with_batchnorm(self, ir_memory_optimize, enable_inplace):
+    def _fc_with_batchnorm(self,
+                           ir_memory_optimize,
+                           enable_inplace,
+                           memory_opt=False):
         np.random.seed(5)
         img = np.random.random(size=[32, 784]).astype(np.float32)
         label = np.ones(shape=[32, 1], dtype='int64')

@@ -55,7 +58,7 @@ class TestIrInplace(TestParallelExecutorBase):
             feed_dict={"image": img,
                        "label": label},
             use_cuda=True,
-            memory_opt=False,  # inplace is conflict with memory opt
+            memory_opt=memory_opt,
             use_ir_memory_optimize=ir_memory_optimize,
             enable_inplace=enable_inplace)

@@ -67,3 +70,10 @@ class TestIrInplace(TestParallelExecutorBase):
         self.assertAlmostEqual(loss00, loss10, delta=delta)
         self.assertAlmostEqual(loss00, loss01, delta=delta)
         self.assertAlmostEqual(loss00, loss11, delta=delta)
+
+    def test_fc_with_batchnorm_memory_opt(self, delta=1e-3):
+        loss00 = self._fc_with_batchnorm(False, True, False)
+        loss10 = self._fc_with_batchnorm(False, True, True)
+        loss01 = self._fc_with_batchnorm(True, True, True)
+        self.assertAlmostEqual(loss00, loss10, delta=delta)
+        self.assertAlmostEqual(loss00, loss01, delta=delta)