Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
5ab96d35
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5ab96d35
编写于
4月 18, 2019
作者:
Z
Zeng Jinle
提交者:
GitHub
4月 18, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add warning and skip vars to mem opt passes (#16967)
test=release/1.4
上级
64c1427d
变更
14
隐藏空白更改
内联
并排
Showing
14 changed files
with
202 additions
and
35 deletions
+202
-35
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+3
-1
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+4
-0
paddle/fluid/framework/details/inplace_op_pass.cc
paddle/fluid/framework/details/inplace_op_pass.cc
+27
-19
paddle/fluid/framework/details/inplace_op_pass.h
paddle/fluid/framework/details/inplace_op_pass.h
+1
-1
paddle/fluid/framework/details/memory_optimize_helper.h
paddle/fluid/framework/details/memory_optimize_helper.h
+6
-0
paddle/fluid/framework/details/memory_optimize_pass.cc
paddle/fluid/framework/details/memory_optimize_pass.cc
+9
-4
paddle/fluid/framework/details/memory_optimize_pass.h
paddle/fluid/framework/details/memory_optimize_pass.h
+2
-1
paddle/fluid/framework/details/record_skip_memory_opt_vars_pass.cc
...uid/framework/details/record_skip_memory_opt_vars_pass.cc
+64
-0
paddle/fluid/framework/inplace_op_inference_test.cc
paddle/fluid/framework/inplace_op_inference_test.cc
+5
-0
paddle/fluid/pybind/const_value.cc
paddle/fluid/pybind/const_value.cc
+2
-0
paddle/fluid/pybind/ir.cc
paddle/fluid/pybind/ir.cc
+6
-0
python/paddle/fluid/compiler.py
python/paddle/fluid/compiler.py
+34
-9
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+36
-0
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
...ddle/fluid/tests/unittests/parallel_executor_test_base.py
+3
-0
未找到文件。
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
5ab96d35
...
@@ -15,6 +15,8 @@ cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_
...
@@ -15,6 +15,8 @@ cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_
cc_library
(
fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper
)
cc_library
(
fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper
)
cc_library
(
fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper
)
cc_library
(
fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper
)
cc_library
(
record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper
)
cc_library
(
variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows
)
cc_library
(
variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows
)
if
(
WITH_DISTRIBUTE
)
if
(
WITH_DISTRIBUTE
)
...
@@ -114,4 +116,4 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
...
@@ -114,4 +116,4 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
fuse_relu_depthwise_conv_pass
fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass
memory_optimize_pass lock_free_optimize_pass
alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass
alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass
fuse_adam_op_pass fuse_sgd_op_pass
)
fuse_adam_op_pass fuse_sgd_op_pass
record_skip_memory_opt_vars_pass
)
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
5ab96d35
...
@@ -53,6 +53,9 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -53,6 +53,9 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
viz_pass
->
Set
<
std
::
string
>
(
"graph_viz_path"
,
new
std
::
string
(
graph_path
));
viz_pass
->
Set
<
std
::
string
>
(
"graph_viz_path"
,
new
std
::
string
(
graph_path
));
}
}
// Note(zcd): record_skip_memory_opt_vars_pass should be the first pass.
AppendPass
(
"record_skip_memory_opt_vars_pass"
);
if
(
strategy_
.
enable_sequential_execution_
)
{
if
(
strategy_
.
enable_sequential_execution_
)
{
VLOG
(
10
)
<<
"Add sequential_execution_pass"
;
VLOG
(
10
)
<<
"Add sequential_execution_pass"
;
AppendPass
(
"sequential_execution_pass"
);
AppendPass
(
"sequential_execution_pass"
);
...
@@ -320,3 +323,4 @@ USE_PASS(graph_to_program_pass);
...
@@ -320,3 +323,4 @@ USE_PASS(graph_to_program_pass);
USE_PASS
(
fuse_adam_op_pass
);
USE_PASS
(
fuse_adam_op_pass
);
USE_PASS
(
fuse_sgd_op_pass
);
USE_PASS
(
fuse_sgd_op_pass
);
USE_PASS
(
fuse_all_reduce_op_pass
);
USE_PASS
(
fuse_all_reduce_op_pass
);
USE_PASS
(
record_skip_memory_opt_vars_pass
);
paddle/fluid/framework/details/inplace_op_pass.cc
浏览文件 @
5ab96d35
...
@@ -303,7 +303,16 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
...
@@ -303,7 +303,16 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
auto
*
in_node
=
view_
.
GetNodeByName
(
in_var_name
,
op
->
inputs
);
auto
*
in_node
=
view_
.
GetNodeByName
(
in_var_name
,
op
->
inputs
);
auto
*
out_node
=
view_
.
GetNodeByName
(
out_var_name
,
op
->
outputs
);
auto
*
out_node
=
view_
.
GetNodeByName
(
out_var_name
,
op
->
outputs
);
VLOG
(
4
)
<<
"Try to inplace "
<<
in_var_name
<<
" with "
<<
out_var_name
;
VLOG
(
4
)
<<
"Try to replace: "
<<
in_var_name
<<
" => "
<<
out_var_name
;
if
(
view_
.
InSkipSet
(
in_var_name
))
{
VLOG
(
4
)
<<
string
::
Sprintf
(
"SKIP: %s is in skip set"
,
in_var_name
);
continue
;
}
if
(
view_
.
InSkipSet
(
out_var_name
))
{
VLOG
(
4
)
<<
string
::
Sprintf
(
"SKIP: %s is in skip set"
,
out_var_name
);
continue
;
}
if
(
var_nodes_
[
in_var_name
].
back
()
!=
in_node
)
{
if
(
var_nodes_
[
in_var_name
].
back
()
!=
in_node
)
{
VLOG
(
4
)
<<
"SKIP since "
<<
in_var_name
VLOG
(
4
)
<<
"SKIP since "
<<
in_var_name
...
@@ -318,11 +327,15 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
...
@@ -318,11 +327,15 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
<<
out_var_name
<<
" are the same"
;
<<
out_var_name
<<
" are the same"
;
}
else
if
(
!
NodeCanReused
(
in_node
))
{
}
else
if
(
!
NodeCanReused
(
in_node
))
{
can_replace
=
false
;
can_replace
=
false
;
VLOG
(
4
)
<<
"SKIP: Input varia
lb
e "
<<
in_var_name
<<
"cannot be reused"
;
VLOG
(
4
)
<<
"SKIP: Input varia
bl
e "
<<
in_var_name
<<
"cannot be reused"
;
}
else
if
(
!
NodeCanReused
(
out_node
))
{
}
else
if
(
!
NodeCanReused
(
out_node
))
{
can_replace
=
false
;
can_replace
=
false
;
VLOG
(
4
)
<<
"SKIP: Output variable "
<<
out_var_name
VLOG
(
4
)
<<
"SKIP: Output variable "
<<
out_var_name
<<
" cannot be reused"
;
<<
" cannot be reused"
;
}
else
if
(
in_node
->
Var
()
->
GetType
()
!=
out_node
->
Var
()
->
GetType
())
{
can_replace
=
false
;
VLOG
(
4
)
<<
"SKIP: Input type : "
<<
in_node
->
Var
()
->
GetType
()
<<
" does not match Output type : "
<<
out_node
->
Var
()
->
GetType
();
}
else
if
(
details
::
NodeSize
(
*
in_node
->
Var
())
!=
}
else
if
(
details
::
NodeSize
(
*
in_node
->
Var
())
!=
details
::
NodeSize
(
*
out_node
->
Var
()))
{
details
::
NodeSize
(
*
out_node
->
Var
()))
{
can_replace
=
false
;
can_replace
=
false
;
...
@@ -331,8 +344,8 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
...
@@ -331,8 +344,8 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
if
(
!
can_replace
)
continue
;
if
(
!
can_replace
)
continue
;
// 2.
there is no external pending op on the input nod
e
// 2.
If the variable is the input of multiple ops, we need to make sur
e
//
if (view_.PendingOpsOnVar(in_node).size() > 1) {
//
the current op has a dependency on the other ops that use the same variable
if
(
in_node
->
outputs
.
size
()
>
1
&&
!
view_
.
CheckDeps
(
in_node
,
op
))
{
if
(
in_node
->
outputs
.
size
()
>
1
&&
!
view_
.
CheckDeps
(
in_node
,
op
))
{
VLOG
(
4
)
<<
string
::
Sprintf
(
VLOG
(
4
)
<<
string
::
Sprintf
(
"Skiped pair %s => %s. %s input has external dependency."
"Skiped pair %s => %s. %s input has external dependency."
...
@@ -341,17 +354,6 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
...
@@ -341,17 +354,6 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
continue
;
continue
;
}
}
// 3. if output has been memory optimize by python(fluid.memory_optmize()).
// this candidate can not be inplaced. Will be deprecated in the future.
if
(
view_
.
InSkipSet
(
out_node
->
Name
()))
{
VLOG
(
4
)
<<
string
::
Sprintf
(
"Skiped %s => %s reused previous memory block in python memory "
"optmize,"
"it inplace may generate a circle"
,
out_var_name
,
in_var_name
,
op
->
Name
());
continue
;
}
// Debug Interface. Which would be skipped by the pass.
// Debug Interface. Which would be skipped by the pass.
if
(
out_node
->
Name
()
==
FLAGS_memory_optimize_debug
)
{
if
(
out_node
->
Name
()
==
FLAGS_memory_optimize_debug
)
{
VLOG
(
3
)
<<
"Skiped var by force. FLAGS_memory_optimize_debug="
VLOG
(
3
)
<<
"Skiped var by force. FLAGS_memory_optimize_debug="
...
@@ -519,16 +521,22 @@ void GraphView::Build(ir::Graph* g) {
...
@@ -519,16 +521,22 @@ void GraphView::Build(ir::Graph* g) {
// resolving data hazards depends on the var nodes being in the right order.
// resolving data hazards depends on the var nodes being in the right order.
TopoSort
(
g
);
TopoSort
(
g
);
// fill the skip_set_
PADDLE_ENFORCE
(
g
->
Has
(
details
::
kMemOptSkipVars
));
auto
&
mem_opt_whitelist
=
g
->
Get
<
MemOptSkipVars
>
(
kMemOptSkipVars
);
for
(
const
auto
&
var
:
mem_opt_whitelist
)
skip_set_
.
emplace
(
var
);
// 2. track the nodes which used by parameter server.
// 2. track the nodes which used by parameter server.
// these node can not be inplaced, otherwise trainer
// these node can not be inplaced, otherwise trainer
// pserver can not find each other name.
// pserver can not find each other name.
auto
update_skip_set
=
[
&
](
ir
::
Node
*
node
)
{
auto
update_skip_set
=
[
&
](
ir
::
Node
*
node
)
{
for
(
auto
&
in
:
node
->
inputs
)
{
for
(
auto
&
in
:
node
->
inputs
)
{
if
(
in
->
IsVar
()
&&
in
->
Var
()
!=
nullptr
)
dup_nodes_
.
emplace
(
in
->
Name
());
if
(
in
->
IsVar
()
&&
in
->
Var
()
!=
nullptr
)
{
skip_set_
.
emplace
(
in
->
Name
());
}
}
}
for
(
auto
&
out
:
node
->
outputs
)
{
for
(
auto
&
out
:
node
->
outputs
)
{
if
(
out
->
IsVar
()
&&
out
->
Var
()
!=
nullptr
)
if
(
out
->
IsVar
()
&&
out
->
Var
()
!=
nullptr
)
skip_set_
.
emplace
(
out
->
Name
());
dup_nodes_
.
emplace
(
out
->
Name
());
}
}
};
};
for
(
auto
&
node
:
g
->
Nodes
())
{
for
(
auto
&
node
:
g
->
Nodes
())
{
...
@@ -545,7 +553,7 @@ void GraphView::Build(ir::Graph* g) {
...
@@ -545,7 +553,7 @@ void GraphView::Build(ir::Graph* g) {
const
std
::
vector
<
ir
::
Node
*>&
GraphView
::
AllOps
()
{
return
ops_
;
}
const
std
::
vector
<
ir
::
Node
*>&
GraphView
::
AllOps
()
{
return
ops_
;
}
bool
GraphView
::
InSkipSet
(
const
std
::
string
&
var
)
const
{
bool
GraphView
::
InSkipSet
(
const
std
::
string
&
var
)
const
{
return
dup_nodes
_
.
count
(
var
);
return
skip_set
_
.
count
(
var
);
}
}
}
// namespace details
}
// namespace details
...
...
paddle/fluid/framework/details/inplace_op_pass.h
浏览文件 @
5ab96d35
...
@@ -57,7 +57,7 @@ class GraphView {
...
@@ -57,7 +57,7 @@ class GraphView {
private:
private:
std
::
vector
<
ir
::
Node
*>
ops_
;
std
::
vector
<
ir
::
Node
*>
ops_
;
std
::
unordered_set
<
std
::
string
>
dup_nodes
_
;
// mem opt affect nodes
std
::
unordered_set
<
std
::
string
>
skip_set
_
;
// mem opt affect nodes
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list_
;
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list_
;
std
::
unordered_map
<
ir
::
Node
*
,
uint32_t
>
op_level_
;
std
::
unordered_map
<
ir
::
Node
*
,
uint32_t
>
op_level_
;
};
};
...
...
paddle/fluid/framework/details/memory_optimize_helper.h
浏览文件 @
5ab96d35
...
@@ -21,6 +21,7 @@
...
@@ -21,6 +21,7 @@
#include <set>
#include <set>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type.h"
...
@@ -30,6 +31,11 @@ namespace paddle {
...
@@ -30,6 +31,11 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
/// Key of the graph attribute listing the variables that memory optimize
/// related passes must neither remove nor reuse.
constexpr char kMemOptSkipVars[] = "@MEM_OPT_SKIP_VARS@";
/// Container type stored under kMemOptSkipVars: a set of variable names.
using MemOptSkipVars = std::unordered_set<std::string>;
std
::
vector
<
ir
::
Node
*>
SortOpLikeDescOrder
(
const
ir
::
Graph
&
graph
);
std
::
vector
<
ir
::
Node
*>
SortOpLikeDescOrder
(
const
ir
::
Graph
&
graph
);
// NOTE(dzh): An ordered set for node reuse in memory optimize.
// NOTE(dzh): An ordered set for node reuse in memory optimize.
...
...
paddle/fluid/framework/details/memory_optimize_pass.cc
浏览文件 @
5ab96d35
...
@@ -45,8 +45,7 @@ namespace framework {
...
@@ -45,8 +45,7 @@ namespace framework {
namespace
details
{
namespace
details
{
void
MemoryOptimizePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
void
MemoryOptimizePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
auto
nodes
=
graph
->
Nodes
();
CollectSkipVarsSet
(
graph
);
CollectSkipVarsSet
(
nodes
);
cfg_
.
reset
(
new
details
::
ControlFlowGraph
(
*
graph
));
cfg_
.
reset
(
new
details
::
ControlFlowGraph
(
*
graph
));
cfg_
->
LiveVariableAnalysis
();
cfg_
->
LiveVariableAnalysis
();
...
@@ -204,14 +203,20 @@ void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
...
@@ -204,14 +203,20 @@ void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
}
}
}
}
void
MemoryOptimizePass
::
CollectSkipVarsSet
(
void
MemoryOptimizePass
::
CollectSkipVarsSet
(
ir
::
Graph
*
graph
)
const
{
const
std
::
unordered_set
<
ir
::
Node
*>&
nodes
)
const
{
// fill skip_set_
PADDLE_ENFORCE
(
graph
->
Has
(
details
::
kMemOptSkipVars
));
auto
&
mem_opt_whitelist
=
graph
->
Get
<
MemOptSkipVars
>
(
kMemOptSkipVars
);
for
(
const
auto
&
var
:
mem_opt_whitelist
)
skip_set_
.
emplace
(
var
);
auto
update_skip_set
=
[
&
](
OpDesc
*
op_desc
)
{
auto
update_skip_set
=
[
&
](
OpDesc
*
op_desc
)
{
auto
inputs
=
op_desc
->
InputArgumentNames
();
auto
inputs
=
op_desc
->
InputArgumentNames
();
auto
outputs
=
op_desc
->
OutputArgumentNames
();
auto
outputs
=
op_desc
->
OutputArgumentNames
();
skip_set_
.
insert
(
inputs
.
begin
(),
inputs
.
end
());
skip_set_
.
insert
(
inputs
.
begin
(),
inputs
.
end
());
skip_set_
.
insert
(
outputs
.
begin
(),
outputs
.
end
());
skip_set_
.
insert
(
outputs
.
begin
(),
outputs
.
end
());
};
};
auto
nodes
=
graph
->
Nodes
();
for
(
auto
&
op
:
nodes
)
{
for
(
auto
&
op
:
nodes
)
{
if
(
!
op
->
IsOp
()
||
op
->
Op
()
==
nullptr
)
continue
;
if
(
!
op
->
IsOp
()
||
op
->
Op
()
==
nullptr
)
continue
;
auto
*
op_desc
=
op
->
Op
();
auto
*
op_desc
=
op
->
Op
();
...
...
paddle/fluid/framework/details/memory_optimize_pass.h
浏览文件 @
5ab96d35
...
@@ -53,7 +53,8 @@ class MemoryOptimizePass : public ir::Pass {
...
@@ -53,7 +53,8 @@ class MemoryOptimizePass : public ir::Pass {
// 1. scan op with subblock and collect the output/input vars.
// 1. scan op with subblock and collect the output/input vars.
// while, while_grad, conditional_block
// while, while_grad, conditional_block
// 2. scan distributed ops and collect the output/input vars
// 2. scan distributed ops and collect the output/input vars
void
CollectSkipVarsSet
(
const
std
::
unordered_set
<
ir
::
Node
*>&
)
const
;
// 3. op_role_vars
void
CollectSkipVarsSet
(
ir
::
Graph
*
graph
)
const
;
private:
private:
// Reuse Node Pool, Owned.
// Reuse Node Pool, Owned.
...
...
paddle/fluid/framework/details/record_skip_memory_opt_vars_pass.cc
0 → 100644
浏览文件 @
5ab96d35
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace details {

/// Graph pass that creates the kMemOptSkipVars attribute on a graph and fills
/// it with variable names taken from the op-role-var attribute of each op.
class RecordSkipMemoryOptVarsPass : public ir::Pass {
 protected:
  /// Attaches a new, empty MemOptSkipVars set to `graph` under the
  /// kMemOptSkipVars key and then populates it.
  ///
  /// The attribute must not exist yet; the enforce below fails otherwise.
  void ApplyImpl(ir::Graph* graph) const override {
    PADDLE_ENFORCE(!graph->Has(kMemOptSkipVars));
    graph->Set(kMemOptSkipVars, new MemOptSkipVars);
    auto& skip_vars = graph->Get<MemOptSkipVars>(kMemOptSkipVars);

    // NOTE(zcd): Insert OpRoleVars to SkipVarSet to prevent the vars from
    // being renamed in the memory optimize pass.
    InsertOpRoleVarsToSkipVarSet(graph, &skip_vars);
  }

  /// Scans every op node of `graph` and inserts the odd-indexed entries of its
  /// op-role-var attribute into `skip_vars`.
  ///
  /// The attribute is required to have an even number of entries.
  /// NOTE(review): the entries look like (name, gradient name) pairs, given
  /// the `g_name` local — confirm against the code that writes the attribute.
  void InsertOpRoleVarsToSkipVarSet(const ir::Graph* graph,
                                    MemOptSkipVars* skip_vars) const {
    for (auto& node : graph->Nodes()) {
      PADDLE_ENFORCE_NOT_NULL(node, "The node should not be nullptr.");
      if (node->IsOp() && node->Op()) {
        try {
          auto op_role_vars = boost::get<std::vector<std::string>>(
              node->Op()->GetNullableAttr(
                  OpProtoAndCheckerMaker::OpRoleVarAttrName()));
          PADDLE_ENFORCE_EQ(op_role_vars.size() % 2, 0);
          for (size_t i = 0; i < op_role_vars.size(); i += 2) {
            auto& g_name = op_role_vars[i + 1];
            skip_vars->insert(g_name);
          }
        } catch (const boost::bad_get&) {
          // The attribute does not hold a vector of strings for this op, so
          // it contributes nothing to the skip set. Caught by const reference
          // to avoid copying the exception object.
        }
      }
    }
  }
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
// Registers RecordSkipMemoryOptVarsPass under the pass name
// record_skip_memory_opt_vars_pass.
REGISTER_PASS
(
record_skip_memory_opt_vars_pass
,
paddle
::
framework
::
details
::
RecordSkipMemoryOptVarsPass
);
paddle/fluid/framework/inplace_op_inference_test.cc
浏览文件 @
5ab96d35
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include <vector>
#include <vector>
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/inplace_op_pass.h"
#include "paddle/fluid/framework/details/inplace_op_pass.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
...
@@ -217,6 +218,7 @@ TEST(InferInplace, SingleOpInplaceInToOut) {
...
@@ -217,6 +218,7 @@ TEST(InferInplace, SingleOpInplaceInToOut) {
FakeSuccData
(
&
prog
);
FakeSuccData
(
&
prog
);
std
::
unique_ptr
<
ir
::
Graph
>
g
(
new
ir
::
Graph
(
prog
));
std
::
unique_ptr
<
ir
::
Graph
>
g
(
new
ir
::
Graph
(
prog
));
g
->
Set
(
details
::
kMemOptSkipVars
,
new
std
::
unordered_set
<
std
::
string
>
());
g
=
test_SingleOpInplaceInToOut
(
std
::
move
(
g
));
g
=
test_SingleOpInplaceInToOut
(
std
::
move
(
g
));
auto
op_node
=
GetNodeFromGraph
(
g
.
get
(),
"single_op"
);
auto
op_node
=
GetNodeFromGraph
(
g
.
get
(),
"single_op"
);
...
@@ -232,6 +234,7 @@ TEST(InferInplace, SingleOpInplaceInToOutNoInplace) {
...
@@ -232,6 +234,7 @@ TEST(InferInplace, SingleOpInplaceInToOutNoInplace) {
FakeNoInplaceData
(
&
prog
);
FakeNoInplaceData
(
&
prog
);
std
::
unique_ptr
<
ir
::
Graph
>
g
(
new
ir
::
Graph
(
prog
));
std
::
unique_ptr
<
ir
::
Graph
>
g
(
new
ir
::
Graph
(
prog
));
g
->
Set
(
details
::
kMemOptSkipVars
,
new
std
::
unordered_set
<
std
::
string
>
());
g
=
test_SingleOpInplaceInToOut
(
std
::
move
(
g
));
g
=
test_SingleOpInplaceInToOut
(
std
::
move
(
g
));
auto
op_node
=
GetNodeFromGraph
(
g
.
get
(),
"single_op"
);
auto
op_node
=
GetNodeFromGraph
(
g
.
get
(),
"single_op"
);
...
@@ -264,6 +267,7 @@ TEST(InferInplace, MultiOutInplaceInToOut) {
...
@@ -264,6 +267,7 @@ TEST(InferInplace, MultiOutInplaceInToOut) {
prog
.
MutableBlock
(
0
)
->
Var
(
"z0"
)
->
SetShape
({
32
,
16
,
1024
,
1024
});
prog
.
MutableBlock
(
0
)
->
Var
(
"z0"
)
->
SetShape
({
32
,
16
,
1024
,
1024
});
std
::
unique_ptr
<
ir
::
Graph
>
g
(
new
ir
::
Graph
(
prog
));
std
::
unique_ptr
<
ir
::
Graph
>
g
(
new
ir
::
Graph
(
prog
));
g
->
Set
(
details
::
kMemOptSkipVars
,
new
std
::
unordered_set
<
std
::
string
>
());
std
::
unique_ptr
<
details
::
InplacePass
>
pass
(
new
details
::
InplacePass
());
std
::
unique_ptr
<
details
::
InplacePass
>
pass
(
new
details
::
InplacePass
());
pass
->
Apply
(
g
.
get
());
pass
->
Apply
(
g
.
get
());
auto
op_node
=
GetNodeFromGraph
(
g
.
get
(),
"multi_out_op"
);
auto
op_node
=
GetNodeFromGraph
(
g
.
get
(),
"multi_out_op"
);
...
@@ -299,6 +303,7 @@ TEST(InferInplace, MultiGradInplaceInToOut) {
...
@@ -299,6 +303,7 @@ TEST(InferInplace, MultiGradInplaceInToOut) {
prog
.
MutableBlock
(
0
)
->
Var
(
"z0"
)
->
SetShape
({
32
,
15
,
1024
,
1024
});
prog
.
MutableBlock
(
0
)
->
Var
(
"z0"
)
->
SetShape
({
32
,
15
,
1024
,
1024
});
std
::
unique_ptr
<
ir
::
Graph
>
g
(
new
ir
::
Graph
(
prog
));
std
::
unique_ptr
<
ir
::
Graph
>
g
(
new
ir
::
Graph
(
prog
));
g
->
Set
(
details
::
kMemOptSkipVars
,
new
std
::
unordered_set
<
std
::
string
>
());
std
::
unique_ptr
<
details
::
InplacePass
>
pass
(
new
details
::
InplacePass
());
std
::
unique_ptr
<
details
::
InplacePass
>
pass
(
new
details
::
InplacePass
());
pass
->
Apply
(
g
.
get
());
pass
->
Apply
(
g
.
get
());
auto
op_node
=
GetNodeFromGraph
(
g
.
get
(),
"multi_out_grad"
);
auto
op_node
=
GetNodeFromGraph
(
g
.
get
(),
"multi_out_grad"
);
...
...
paddle/fluid/pybind/const_value.cc
浏览文件 @
5ab96d35
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
...
@@ -28,6 +29,7 @@ void BindConstValue(pybind11::module* m) {
...
@@ -28,6 +29,7 @@ void BindConstValue(pybind11::module* m) {
m
->
def
(
"kControlDepVarName"
,
m
->
def
(
"kControlDepVarName"
,
[]
{
return
framework
::
ir
::
Node
::
kControlDepVarName
;
});
[]
{
return
framework
::
ir
::
Node
::
kControlDepVarName
;
});
m
->
def
(
"kNewGradSuffix"
,
[]
{
return
framework
::
kNewGradSuffix
;
});
m
->
def
(
"kNewGradSuffix"
,
[]
{
return
framework
::
kNewGradSuffix
;
});
m
->
def
(
"kMemOptSkipVars"
,
[]
{
return
framework
::
details
::
kMemOptSkipVars
;
});
auto
op_proto_and_checker_maker
=
auto
op_proto_and_checker_maker
=
m
->
def_submodule
(
"op_proto_and_checker_maker"
);
m
->
def_submodule
(
"op_proto_and_checker_maker"
);
...
...
paddle/fluid/pybind/ir.cc
浏览文件 @
5ab96d35
...
@@ -84,6 +84,12 @@ void BindGraph(py::module *m) {
...
@@ -84,6 +84,12 @@ void BindGraph(py::module *m) {
return
self
.
Set
(
attr_name
,
return
self
.
Set
(
attr_name
,
new
std
::
unordered_set
<
const
Node
*>
(
attr
));
new
std
::
unordered_set
<
const
Node
*>
(
attr
));
})
})
.
def
(
"set"
,
[](
Graph
&
self
,
const
std
::
string
&
attr_name
,
const
std
::
unordered_set
<
std
::
string
>
&
attr
)
{
return
self
.
Set
(
attr_name
,
new
std
::
unordered_set
<
std
::
string
>
(
attr
));
})
.
def
(
"erase"
,
&
Graph
::
Erase
)
.
def
(
"erase"
,
&
Graph
::
Erase
)
.
def
(
"nodes"
,
&
Graph
::
Nodes
,
return_value_policy
::
reference
)
.
def
(
"nodes"
,
&
Graph
::
Nodes
,
return_value_policy
::
reference
)
.
def
(
"create_var_node"
,
.
def
(
"create_var_node"
,
...
...
python/paddle/fluid/compiler.py
浏览文件 @
5ab96d35
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
logging
import
multiprocessing
import
multiprocessing
import
os
import
os
import
six
import
six
...
@@ -152,6 +153,39 @@ class CompiledProgram(object):
...
@@ -152,6 +153,39 @@ class CompiledProgram(object):
else
:
else
:
self
.
_places
=
None
self
.
_places
=
None
self
.
_build_strategy
.
is_distribution
=
_is_pserver_mode
(
self
.
_program
)
self
.
_build_strategy
.
is_distribution
=
_is_pserver_mode
(
self
.
_program
)
# FIXME(dzhwinter): enable_inplace should be after memory_optimize
# if turn on python memory optimize, turn off the inplace_pass.
# memory_optimize and enable_inplace default are True, but we can disable them on purpose
if
self
.
_program
:
if
self
.
_program
.
_is_mem_optimized
:
self
.
_build_strategy
.
memory_optimize
=
False
self
.
_build_strategy
.
enable_inplace
=
False
elif
not
self
.
_build_strategy
.
memory_optimize
or
not
self
.
_build_strategy
.
enable_inplace
:
# remind the user to try our memory optimize strategy
logging
.
warn
(
"""
You can try our memory optimize feature to save your memory usage:
# create a build_strategy variable to set memory optimize option
build_strategy = compiler.BuildStrategy()
build_strategy.enable_inplace = True
build_strategy.memory_optimize = True
# pass the build_strategy to with_data_parallel API
compiled_prog = compiler.CompiledProgram(main).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
!!! Memory optimize is our experimental feature !!!
some variables may be removed/reused internal to save memory usage,
in order to fetch the right value of the fetch_list, please set the
persistable property to true for each variable in fetch_list
# Sample
conv1 = fluid.layers.conv2d(data, 4, 5, 1, act=None)
# if you need to fetch conv1, then:
conv1.persistable = True
"""
)
return
self
return
self
def
with_inference_optimize
(
self
,
config
):
def
with_inference_optimize
(
self
,
config
):
...
@@ -211,15 +245,6 @@ class CompiledProgram(object):
...
@@ -211,15 +245,6 @@ class CompiledProgram(object):
else
:
else
:
self
.
_exec_strategy
.
num_threads
=
len
(
self
.
_places
)
*
2
self
.
_exec_strategy
.
num_threads
=
len
(
self
.
_places
)
*
2
# FIXME(dzhwinter): enable_inplace should be after memory_optimize
# if turn on python memory optimize, turn off the inplace_pass.
# memory_optimize and enable_inplace default are True, but we can disable them on purpose
if
self
.
_program
and
self
.
_program
.
_is_mem_optimized
:
self
.
_build_strategy
.
memory_optimize
=
False
if
self
.
_program
and
self
.
_program
.
_is_mem_optimized
:
self
.
_build_strategy
.
enable_inplace
=
False
# TODO(wuyi): trainer endpoints should be passed in through
# TODO(wuyi): trainer endpoints should be passed in through
# build_strategy, not program.xxx.
# build_strategy, not program.xxx.
if
self
.
_program
and
self
.
_build_strategy
.
num_trainers
>
1
and
\
if
self
.
_program
and
self
.
_build_strategy
.
num_trainers
>
1
and
\
...
...
python/paddle/fluid/executor.py
浏览文件 @
5ab96d35
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
from
__future__
import
print_function
from
__future__
import
print_function
import
logging
import
os
import
os
import
multiprocessing
import
multiprocessing
import
numpy
as
np
import
numpy
as
np
...
@@ -449,6 +450,36 @@ class Executor(object):
...
@@ -449,6 +450,36 @@ class Executor(object):
return
as_numpy
(
arr
)
return
as_numpy
(
arr
)
return
[
arr
[
i
]
for
i
in
range
(
len
(
arr
))]
return
[
arr
[
i
]
for
i
in
range
(
len
(
arr
))]
def _check_fetch_vars_persistable(self, program, fetch_list):
    """Log a warning for every fetched variable that is not persistable.

    Args:
        program: object whose ``desc`` exposes ``num_blocks()`` and
            ``block(i).find_var(name)``; only consulted for entries of
            ``fetch_list`` that are not ``Variable`` instances.
        fetch_list: iterable of ``Variable`` objects or variable names.

    Raises:
        AssertionError: if a name in ``fetch_list`` is not found in any
            block of ``program``.
    """
    for var in fetch_list:
        if isinstance(var, Variable):
            persistable = var.persistable
        else:
            # Entry is given by name: search the blocks in order and use the
            # first descriptor that matches.
            block_num = program.desc.num_blocks()
            persistable = None
            var_name = cpt.to_bytes(var)
            for i in six.moves.range(block_num):
                var_desc = program.desc.block(i).find_var(var_name)
                if var_desc:
                    persistable = var_desc.persistable()
                    break
            assert persistable is not None, "Variable {} is not found".format(
                var)

        if not persistable:
            # logging.warn is a deprecated alias; logging.warning is the
            # supported spelling and emits the same record.
            logging.warning("""
Detect that memory optimize or inplace is enabled, but the some variables in the fetch
list is not persistable, you may get wrong fetched value, or an exeception may be thrown
about cannot find variable of the fetch list.
TO FIX this:
# Sample
conv1 = fluid.layers.conv2d(data, 4, 5, 1, act=None)
# if you need to fetch conv1, then:
conv1.persistable = True
""")
def
run
(
self
,
def
run
(
self
,
program
=
None
,
program
=
None
,
feed
=
None
,
feed
=
None
,
...
@@ -532,6 +563,11 @@ class Executor(object):
...
@@ -532,6 +563,11 @@ class Executor(object):
scope
=
scope
,
scope
=
scope
,
return_numpy
=
return_numpy
,
return_numpy
=
return_numpy
,
use_program_cache
=
use_program_cache
)
use_program_cache
=
use_program_cache
)
else
:
if
fetch_list
and
program
.
_is_data_parallel
and
program
.
_program
and
(
program
.
_build_strategy
.
memory_optimize
or
program
.
_build_strategy
.
enable_inplace
):
self
.
_check_fetch_vars_persistable
(
program
.
_program
,
fetch_list
)
program
.
_compile
(
scope
,
self
.
place
)
program
.
_compile
(
scope
,
self
.
place
)
if
program
.
_is_data_parallel
:
if
program
.
_is_data_parallel
:
...
...
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
浏览文件 @
5ab96d35
...
@@ -57,12 +57,15 @@ class TestParallelExecutorBase(unittest.TestCase):
...
@@ -57,12 +57,15 @@ class TestParallelExecutorBase(unittest.TestCase):
startup
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
startup
.
random_seed
=
1
# Fix random seed
startup
.
random_seed
=
1
# Fix random seed
main
.
random_seed
=
1
main
.
random_seed
=
1
with
fluid
.
program_guard
(
main
,
startup
):
with
fluid
.
program_guard
(
main
,
startup
):
if
seed
is
not
None
:
if
seed
is
not
None
:
startup
.
random_seed
=
seed
startup
.
random_seed
=
seed
main
.
random_seed
=
seed
main
.
random_seed
=
seed
loss
=
method
(
use_feed
=
feed_dict
is
not
None
)
loss
=
method
(
use_feed
=
feed_dict
is
not
None
)
loss
.
persistable
=
True
if
optimizer
:
if
optimizer
:
optimizer
().
minimize
(
loss
)
optimizer
().
minimize
(
loss
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录