Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
8156fedf
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8156fedf
编写于
1月 29, 2019
作者:
D
dzhwinter
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
merge develop branch. test=develop
上级
ee3aae56
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
46 addition
and
114 deletion
+46
-114
paddle/fluid/framework/details/inplace_op_pass.cc
paddle/fluid/framework/details/inplace_op_pass.cc
+37
-96
paddle/fluid/framework/details/inplace_op_pass.h
paddle/fluid/framework/details/inplace_op_pass.h
+8
-10
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
...ddle/fluid/tests/unittests/parallel_executor_test_base.py
+1
-1
python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py
python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py
+0
-7
未找到文件。
paddle/fluid/framework/details/inplace_op_pass.cc
浏览文件 @
8156fedf
...
@@ -199,15 +199,17 @@ void InplacePass::InplaceModifyDesc(const std::string& var,
...
@@ -199,15 +199,17 @@ void InplacePass::InplaceModifyDesc(const std::string& var,
}
}
}
}
const
SSANodeVector
InplacePass
::
TryInplaceModifyVar
(
const
SSANodePair
InplacePass
::
TryInplaceModifyVar
(
const
std
::
string
&
var
,
const
std
::
string
&
var
,
const
std
::
string
&
cache_var
,
const
size_t
&
idx
,
const
std
::
string
&
cache_var
,
ir
::
Graph
*
graph
)
const
{
const
size_t
&
idx
,
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE
(
var_nodes_
[
var
].
size
()
>=
1
&&
PADDLE_ENFORCE
(
var_nodes_
[
var
].
size
()
>=
1
&&
var_nodes_
[
var
].
at
(
0
)
->
Var
()
!=
nullptr
);
var_nodes_
[
var
].
at
(
0
)
->
Var
()
!=
nullptr
);
std
::
unique_ptr
<
VarDesc
>
var_desc
(
new
VarDesc
(
*
var_nodes_
[
var
].
at
(
0
)
->
Var
()));
std
::
unique_ptr
<
VarDesc
>
var_desc
(
new
VarDesc
(
*
var_nodes_
[
var
].
at
(
0
)
->
Var
()));
var_desc
->
SetName
(
cache_var
);
var_desc
->
SetName
(
cache_var
);
SSANodeVector
swap_nodes
;
SSANodePair
swap_nodes
;
for
(
size_t
i
=
idx
;
i
<
view_
.
AllOps
().
size
();
++
i
)
{
for
(
size_t
i
=
idx
;
i
<
view_
.
AllOps
().
size
();
++
i
)
{
auto
*
op
=
view_
.
AllOps
()[
i
];
auto
*
op
=
view_
.
AllOps
()[
i
];
...
@@ -215,6 +217,7 @@ const SSANodeVector InplacePass::TryInplaceModifyVar(
...
@@ -215,6 +217,7 @@ const SSANodeVector InplacePass::TryInplaceModifyVar(
for
(
auto
*
node
:
op
->
inputs
)
{
for
(
auto
*
node
:
op
->
inputs
)
{
if
(
node
->
Name
()
==
var
)
{
if
(
node
->
Name
()
==
var
)
{
ir
::
Node
*
cache_node
=
graph
->
CreateVarNode
(
var_desc
.
get
());
ir
::
Node
*
cache_node
=
graph
->
CreateVarNode
(
var_desc
.
get
());
// swap node to cache_node
// swap node to cache_node
cache_node
->
outputs
.
insert
(
cache_node
->
outputs
.
end
(),
cache_node
->
outputs
.
insert
(
cache_node
->
outputs
.
end
(),
node
->
outputs
.
begin
(),
node
->
outputs
.
end
());
node
->
outputs
.
begin
(),
node
->
outputs
.
end
());
...
@@ -228,13 +231,15 @@ const SSANodeVector InplacePass::TryInplaceModifyVar(
...
@@ -228,13 +231,15 @@ const SSANodeVector InplacePass::TryInplaceModifyVar(
cache_node
);
cache_node
);
}
}
swap_nodes
[
node
].
emplace_back
(
cache_node
);
swap_nodes
.
emplace_back
(
std
::
make_pair
(
node
,
cache_node
)
);
}
}
}
}
// if we need to rename the output,
// always create a newer version of cache_var
for
(
auto
*
node
:
op
->
outputs
)
{
for
(
auto
*
node
:
op
->
outputs
)
{
if
(
node
->
Name
()
==
var
)
{
if
(
node
->
Name
()
==
var
)
{
ir
::
Node
*
cache_node
=
graph
->
CreateVarNode
(
var_desc
.
get
());
ir
::
Node
*
cache_node
=
graph
->
CreateVarNode
(
var_desc
.
get
());
var_nodes_
[
cache_var
].
emplace_back
(
cache_node
);
// swap node to cache node
// swap node to cache node
cache_node
->
outputs
.
insert
(
cache_node
->
outputs
.
end
(),
cache_node
->
outputs
.
insert
(
cache_node
->
outputs
.
end
(),
node
->
outputs
.
begin
(),
node
->
outputs
.
end
());
node
->
outputs
.
begin
(),
node
->
outputs
.
end
());
...
@@ -244,108 +249,43 @@ const SSANodeVector InplacePass::TryInplaceModifyVar(
...
@@ -244,108 +249,43 @@ const SSANodeVector InplacePass::TryInplaceModifyVar(
std
::
replace
(
next_op
->
inputs
.
begin
(),
next_op
->
inputs
.
end
(),
node
,
std
::
replace
(
next_op
->
inputs
.
begin
(),
next_op
->
inputs
.
end
(),
node
,
cache_node
);
cache_node
);
}
}
swap_nodes
[
node
].
emplace_back
(
cache_node
);
swap_nodes
.
emplace_back
(
std
::
make_pair
(
node
,
cache_node
));
}
}
}
}
}
}
return
swap_nodes
;
return
swap_nodes
;
}
}
void
InplacePass
::
CommitModify
(
const
SSANode
Vecto
r
&
swap_nodes
,
void
InplacePass
::
CommitModify
(
const
SSANode
Pai
r
&
swap_nodes
,
ir
::
Graph
*
graph
)
const
{
ir
::
Graph
*
graph
)
const
{
for
(
auto
&
pair
:
swap_nodes
)
{
for
(
auto
&
pair
:
swap_nodes
)
{
auto
*
node
=
pair
.
first
;
auto
*
node
=
pair
.
first
,
*
cache_node
=
pair
.
second
;
const
std
::
string
var
=
node
->
Name
();
const
std
::
string
var
=
node
->
Name
(),
cache_var
=
cache_node
->
Name
();
for
(
auto
*
cache_node
:
pair
.
second
)
{
var_nodes_
[
cache_var
].
emplace_back
(
cache_node
);
const
std
::
string
cache_var
=
cache_node
->
Name
();
graph
->
RemoveNode
(
node
);
var_nodes_
[
cache_var
].
emplace_back
(
cache_node
);
}
auto
&
nodes
=
var_nodes_
.
at
(
var
);
auto
&
nodes
=
var_nodes_
.
at
(
var
);
// release unused var in graph. Because python side memory optimize
// may reused the var in same name, so we only clear the var node
// after current inplaced index.
nodes
.
erase
(
std
::
remove
(
nodes
.
begin
(),
nodes
.
end
(),
node
),
nodes
.
end
());
nodes
.
erase
(
std
::
remove
(
nodes
.
begin
(),
nodes
.
end
(),
node
),
nodes
.
end
());
graph
->
RemoveNode
(
node
);
}
}
}
}
void
InplacePass
::
With
DrawModify
(
const
SSANodeVecto
r
&
nodes
,
void
InplacePass
::
With
drawModify
(
const
SSANodePai
r
&
nodes
,
ir
::
Graph
*
graph
)
const
{
ir
::
Graph
*
graph
)
const
{
for
(
auto
&
pair
:
nodes
)
{
for
(
auto
&
pair
:
nodes
)
{
auto
*
node
=
pair
.
first
;
auto
*
node
=
pair
.
first
,
*
cache_node
=
pair
.
second
;
const
std
::
string
var
=
node
->
Name
();
const
std
::
string
var
=
node
->
Name
(),
cache_var
=
cache_node
->
Name
();
for
(
auto
*
cache_node
:
pair
.
second
)
{
auto
*
prev_op
=
node
->
inputs
[
0
];
const
std
::
string
cache_var
=
cache_node
->
Name
();
std
::
replace
(
prev_op
->
outputs
.
begin
(),
prev_op
->
outputs
.
end
(),
cache_node
,
auto
*
prev_op
=
node
->
inputs
[
0
];
node
);
std
::
replace
(
prev_op
->
outputs
.
begin
(),
prev_op
->
outputs
.
end
(),
cache_node
,
for
(
auto
*
next_op
:
node
->
outputs
)
{
std
::
replace
(
next_op
->
inputs
.
begin
(),
next_op
->
inputs
.
end
(),
cache_node
,
node
);
node
);
for
(
auto
*
next_op
:
node
->
outputs
)
{
std
::
replace
(
next_op
->
inputs
.
begin
(),
next_op
->
inputs
.
end
(),
cache_node
,
node
);
}
graph
->
RemoveNode
(
cache_node
);
}
}
}
void
InplacePass
::
InplaceModifyVar
(
const
std
::
string
&
var
,
const
std
::
string
&
cache_var
,
const
size_t
&
idx
,
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE
(
var_nodes_
[
var
].
size
()
>=
1
&&
var_nodes_
[
var
].
at
(
0
)
->
Var
()
!=
nullptr
);
std
::
unique_ptr
<
VarDesc
>
var_desc
(
new
VarDesc
(
*
var_nodes_
[
var
].
at
(
0
)
->
Var
()));
var_desc
->
SetName
(
cache_var
);
for
(
size_t
i
=
idx
;
i
<
view_
.
AllOps
().
size
();
++
i
)
{
auto
*
op
=
view_
.
AllOps
()[
i
];
// redirect the input to the latest version of cache_var
for
(
auto
*
node
:
op
->
inputs
)
{
if
(
node
->
Name
()
==
var
)
{
ir
::
Node
*
cache_node
=
graph
->
CreateVarNode
(
var_desc
.
get
());
var_nodes_
[
cache_var
].
emplace_back
(
cache_node
);
// swap node to cache_node
cache_node
->
outputs
.
insert
(
cache_node
->
outputs
.
end
(),
node
->
outputs
.
begin
(),
node
->
outputs
.
end
());
PADDLE_ENFORCE
(
node
->
inputs
.
size
()
==
1
&&
node
->
inputs
[
0
]
->
IsOp
());
auto
*
prev_op
=
node
->
inputs
[
0
];
std
::
replace
(
prev_op
->
outputs
.
begin
(),
prev_op
->
outputs
.
end
(),
node
,
cache_node
);
cache_node
->
inputs
.
emplace_back
(
prev_op
);
for
(
auto
*
next_op
:
node
->
outputs
)
{
std
::
replace
(
next_op
->
inputs
.
begin
(),
next_op
->
inputs
.
end
(),
node
,
cache_node
);
}
// release unused var in graph. Because python side memory optimize
// may reused the var in same name, so we only clear the var node
// after current inplaced index.
graph
->
RemoveNode
(
node
);
auto
&
nodes
=
var_nodes_
.
at
(
var
);
nodes
.
erase
(
std
::
remove
(
nodes
.
begin
(),
nodes
.
end
(),
node
),
nodes
.
end
());
}
}
// if we need to rename the output,
// always create a newer version of cache_var
for
(
auto
*
node
:
op
->
outputs
)
{
if
(
node
->
Name
()
==
var
)
{
ir
::
Node
*
cache_node
=
graph
->
CreateVarNode
(
var_desc
.
get
());
var_nodes_
[
cache_var
].
emplace_back
(
cache_node
);
// swap node to cache node
cache_node
->
outputs
.
insert
(
cache_node
->
outputs
.
end
(),
node
->
outputs
.
begin
(),
node
->
outputs
.
end
());
cache_node
->
inputs
.
emplace_back
(
op
);
std
::
replace
(
op
->
outputs
.
begin
(),
op
->
outputs
.
end
(),
node
,
cache_node
);
for
(
auto
*
next_op
:
node
->
outputs
)
{
std
::
replace
(
next_op
->
inputs
.
begin
(),
next_op
->
inputs
.
end
(),
node
,
cache_node
);
}
// release unsed var in graph
graph
->
RemoveNode
(
node
);
auto
&
nodes
=
var_nodes_
.
at
(
var
);
nodes
.
erase
(
std
::
remove
(
nodes
.
begin
(),
nodes
.
end
(),
node
),
nodes
.
end
());
}
}
}
graph
->
RemoveNode
(
cache_node
);
}
}
}
}
...
@@ -413,22 +353,23 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
...
@@ -413,22 +353,23 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
continue
;
continue
;
}
}
// NOTE(dzhwinter):
// two stage commit of inplaced process. if after inplace happens generate a
// circle,
// then withdraw the changes. Otherwise, safely add the node.
auto
swap_nodes
=
auto
swap_nodes
=
TryInplaceModifyVar
(
out_var_name
,
in_var_name
,
idx
,
graph
);
TryInplaceModifyVar
(
out_var_name
,
in_var_name
,
idx
,
graph
);
// NOTE(dzhwinter):
// two stage commit of inplaced op. If add such node generate a circle,
// then withdraw the changes. Otherwise, safely add the node.
if
(
!
ir
::
HasCircle
(
*
graph
))
{
if
(
!
ir
::
HasCircle
(
*
graph
))
{
VLOG
(
3
)
<<
string
::
Sprintf
(
"!!! %s, %s => %s inplaced"
,
op
->
Name
(),
VLOG
(
3
)
<<
string
::
Sprintf
(
"!!! %s, %s => %s inplaced"
,
op
->
Name
(),
out_var_name
,
in_var_name
);
out_var_name
,
in_var_name
);
CommitModify
(
swap_nodes
,
graph
);
InplaceModifyDesc
(
out_var_name
,
in_var_name
,
idx
);
InplaceModifyDesc
(
out_var_name
,
in_var_name
,
idx
);
CommitModify
(
swap_nodes
,
graph
);
}
else
{
}
else
{
VLOG
(
3
)
<<
string
::
Sprintf
(
VLOG
(
3
)
<<
string
::
Sprintf
(
"Skiped pair %s => %s, inplace will generate a circle. withdraw %s"
,
"Skiped pair %s => %s, inplace will generate a circle. withdraw %s"
,
out_var_name
,
in_var_name
,
op
->
Name
());
out_var_name
,
in_var_name
,
op
->
Name
());
With
D
rawModify
(
swap_nodes
,
graph
);
With
d
rawModify
(
swap_nodes
,
graph
);
}
}
}
}
}
}
...
...
paddle/fluid/framework/details/inplace_op_pass.h
浏览文件 @
8156fedf
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include <utility>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
...
@@ -54,7 +55,7 @@ class GraphView {
...
@@ -54,7 +55,7 @@ class GraphView {
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list_
;
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list_
;
};
};
typedef
std
::
unordered_map
<
ir
::
Node
*
,
std
::
vector
<
ir
::
Node
*>>
SSANodeVecto
r
;
typedef
std
::
vector
<
std
::
pair
<
ir
::
Node
*
,
ir
::
Node
*>>
SSANodePai
r
;
class
InplacePass
:
public
ir
::
Pass
{
class
InplacePass
:
public
ir
::
Pass
{
public:
public:
InplacePass
();
InplacePass
();
...
@@ -66,17 +67,14 @@ class InplacePass : public ir::Pass {
...
@@ -66,17 +67,14 @@ class InplacePass : public ir::Pass {
void
InitSSAGraphNodes
()
const
;
void
InitSSAGraphNodes
()
const
;
private:
private:
void
InplaceModifyVar
(
const
std
::
string
&
in_var
,
const
std
::
string
&
out_var
,
const
SSANodePair
TryInplaceModifyVar
(
const
std
::
string
&
var
,
const
size_t
&
idx
,
ir
::
Graph
*
graph
)
const
;
const
std
::
string
&
cache_var
,
const
size_t
&
idx
,
ir
::
Graph
*
graph
)
const
;
const
SSANodeVector
TryInplaceModifyVar
(
const
std
::
string
&
var
,
void
CommitModify
(
const
SSANodePair
&
,
ir
::
Graph
*
graph
)
const
;
const
std
::
string
&
cache_var
,
const
size_t
&
idx
,
ir
::
Graph
*
graph
)
const
;
void
CommitModify
(
const
SSANodeVector
&
,
ir
::
Graph
*
graph
)
const
;
void
WithdrawModify
(
const
SSANodePair
&
nodes
,
ir
::
Graph
*
graph
)
const
;
void
WithDrawModify
(
const
SSANodeVector
&
nodes
,
ir
::
Graph
*
graph
)
const
;
void
InplaceModifyDesc
(
const
std
::
string
&
in_var
,
const
std
::
string
&
out_var
,
void
InplaceModifyDesc
(
const
std
::
string
&
in_var
,
const
std
::
string
&
out_var
,
const
size_t
&
idx
)
const
;
const
size_t
&
idx
)
const
;
...
...
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
浏览文件 @
8156fedf
...
@@ -32,7 +32,7 @@ class TestParallelExecutorBase(unittest.TestCase):
...
@@ -32,7 +32,7 @@ class TestParallelExecutorBase(unittest.TestCase):
def
check_network_convergence
(
self
,
def
check_network_convergence
(
self
,
method
,
method
,
use_cuda
=
True
,
use_cuda
=
True
,
memory_opt
=
Fals
e
,
memory_opt
=
Tru
e
,
iter
=
50
,
iter
=
50
,
batch_size
=
None
,
batch_size
=
None
,
allow_op_delay
=
False
,
allow_op_delay
=
False
,
...
...
python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py
浏览文件 @
8156fedf
...
@@ -70,10 +70,3 @@ class TestIrInplace(TestParallelExecutorBase):
...
@@ -70,10 +70,3 @@ class TestIrInplace(TestParallelExecutorBase):
self
.
assertAlmostEqual
(
loss00
,
loss10
,
delta
=
delta
)
self
.
assertAlmostEqual
(
loss00
,
loss10
,
delta
=
delta
)
self
.
assertAlmostEqual
(
loss00
,
loss01
,
delta
=
delta
)
self
.
assertAlmostEqual
(
loss00
,
loss01
,
delta
=
delta
)
self
.
assertAlmostEqual
(
loss00
,
loss11
,
delta
=
delta
)
self
.
assertAlmostEqual
(
loss00
,
loss11
,
delta
=
delta
)
def
test_fc_with_batchnorm_memory_opt
(
self
,
delta
=
1e-3
):
loss00
=
self
.
_fc_with_batchnorm
(
False
,
True
,
False
)
loss10
=
self
.
_fc_with_batchnorm
(
False
,
True
,
True
)
loss10
=
self
.
_fc_with_batchnorm
(
True
,
True
,
True
)
self
.
assertAlmostEqual
(
loss00
,
loss10
,
delta
=
delta
)
self
.
assertAlmostEqual
(
loss00
,
loss01
,
delta
=
delta
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录