Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
381f2015
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
381f2015
编写于
2月 12, 2019
作者:
D
dzhwinter
提交者:
GitHub
2月 12, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #15665 from dzhwinter/experiment/refactor_memory
refactor optimize pass.
上级
6492ea9c
04e9776a
变更
23
显示空白变更内容
内联
并排
Showing
23 changed file
with
842 addition
and
1027 deletion
+842
-1027
cmake/flags.cmake
cmake/flags.cmake
+2
-1
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+2
-6
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+0
-2
paddle/fluid/framework/details/build_strategy.h
paddle/fluid/framework/details/build_strategy.h
+0
-3
paddle/fluid/framework/details/inplace_op_pass.cc
paddle/fluid/framework/details/inplace_op_pass.cc
+6
-7
paddle/fluid/framework/details/inplace_op_pass.h
paddle/fluid/framework/details/inplace_op_pass.h
+8
-7
paddle/fluid/framework/details/memory_early_delete_pass.cc
paddle/fluid/framework/details/memory_early_delete_pass.cc
+0
-117
paddle/fluid/framework/details/memory_early_delete_pass.h
paddle/fluid/framework/details/memory_early_delete_pass.h
+0
-32
paddle/fluid/framework/details/memory_optimize_helper.cc
paddle/fluid/framework/details/memory_optimize_helper.cc
+306
-30
paddle/fluid/framework/details/memory_optimize_helper.h
paddle/fluid/framework/details/memory_optimize_helper.h
+83
-36
paddle/fluid/framework/details/memory_optimize_helper_test.cc
...le/fluid/framework/details/memory_optimize_helper_test.cc
+408
-9
paddle/fluid/framework/details/memory_optimize_pass.cc
paddle/fluid/framework/details/memory_optimize_pass.cc
+8
-289
paddle/fluid/framework/details/memory_optimize_pass.h
paddle/fluid/framework/details/memory_optimize_pass.h
+3
-47
paddle/fluid/framework/details/memory_optimize_pass_test.cc
paddle/fluid/framework/details/memory_optimize_pass_test.cc
+0
-417
paddle/fluid/framework/details/sequential_execution_pass.cc
paddle/fluid/framework/details/sequential_execution_pass.cc
+1
-0
paddle/fluid/framework/details/sequential_execution_pass.h
paddle/fluid/framework/details/sequential_execution_pass.h
+0
-2
paddle/fluid/framework/inplace_op_inference.h
paddle/fluid/framework/inplace_op_inference.h
+1
-1
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+2
-9
paddle/fluid/framework/scope.cc
paddle/fluid/framework/scope.cc
+1
-5
paddle/fluid/memory/allocation/legacy_allocator.cc
paddle/fluid/memory/allocation/legacy_allocator.cc
+3
-2
paddle/fluid/platform/place.cc
paddle/fluid/platform/place.cc
+6
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+0
-4
python/paddle/fluid/parallel_executor.py
python/paddle/fluid/parallel_executor.py
+2
-1
未找到文件。
cmake/flags.cmake
浏览文件 @
381f2015
...
@@ -27,6 +27,7 @@ endfunction()
...
@@ -27,6 +27,7 @@ endfunction()
CheckCompilerCXX11Flag
()
CheckCompilerCXX11Flag
()
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-std=c++11"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-std=c++11"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-m64"
)
# safe_set_flag
# safe_set_flag
#
#
# Set a compile flag only if compiler is support
# Set a compile flag only if compiler is support
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
381f2015
...
@@ -54,8 +54,6 @@ cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph grap
...
@@ -54,8 +54,6 @@ cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph grap
cc_library
(
memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass
)
cc_library
(
memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass
)
cc_library
(
inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info
)
cc_library
(
inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info
)
cc_library
(
modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper
)
cc_library
(
modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper
)
cc_library
(
memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass
)
cc_library
(
reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle
)
cc_library
(
reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle
)
cc_library
(
eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper
)
cc_library
(
eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper
)
cc_library
(
eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass
)
cc_library
(
eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass
)
...
@@ -67,13 +65,11 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
...
@@ -67,13 +65,11 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
cc_library
(
multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
cc_library
(
multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle
)
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle
)
set
(
SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass
memory_early_delete_pass
inplace_op_pass
)
set
(
SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass
)
if
(
WITH_GPU
)
if
(
WITH_GPU
)
list
(
APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass
)
list
(
APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass
)
endif
()
endif
()
cc_test
(
memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph
)
cc_test
(
memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry
)
cc_test
(
memory_optimize_pass_test SRCS memory_optimize_pass_test.cc memory_optimize_pass.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry pass
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS
${
SSA_GRAPH_EXECUTOR_DEPS
}
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS
${
SSA_GRAPH_EXECUTOR_DEPS
}
)
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
...
...
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
381f2015
...
@@ -206,8 +206,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
...
@@ -206,8 +206,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
new
std
::
vector
<
OpDesc
*>
(
main_program
.
Block
(
0
).
AllOps
());
new
std
::
vector
<
OpDesc
*>
(
main_program
.
Block
(
0
).
AllOps
());
graph
->
Set
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
,
graph
->
Set
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
,
all_op_descs
);
// take ownership
all_op_descs
);
// take ownership
graph
->
Set
<
GraphNodePool
>
(
kGraphNodePool
,
new
GraphNodePool
);
// take ownership
pass
->
Erase
(
kAllOpDescs
);
pass
->
Erase
(
kAllOpDescs
);
pass
->
SetNotOwned
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
,
all_op_descs
);
pass
->
SetNotOwned
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
,
all_op_descs
);
...
...
paddle/fluid/framework/details/build_strategy.h
浏览文件 @
381f2015
...
@@ -77,9 +77,6 @@ struct BuildStrategy {
...
@@ -77,9 +77,6 @@ struct BuildStrategy {
bool
fuse_relu_depthwise_conv_
{
false
};
bool
fuse_relu_depthwise_conv_
{
false
};
bool
memory_optimize_
{
false
};
bool
memory_optimize_
{
false
};
bool
memory_early_delete_
{
false
};
// TODO(dzhwinter):
// TODO(dzhwinter):
// make enable_inplace, memory_optimize_
// make enable_inplace, memory_optimize_
// memory_early_delete_ true by default
// memory_early_delete_ true by default
...
...
paddle/fluid/framework/details/inplace_op_pass.cc
浏览文件 @
381f2015
...
@@ -171,16 +171,15 @@ void InplacePass::InplaceModifyDesc(const std::string& var,
...
@@ -171,16 +171,15 @@ void InplacePass::InplaceModifyDesc(const std::string& var,
}
}
}
}
const
SSANodePair
InplacePass
::
TryInplaceModifyVar
(
const
std
::
string
&
var
,
const
NodeSwapQueue
InplacePass
::
TryInplaceModifyVar
(
const
std
::
string
&
cache_var
,
const
std
::
string
&
var
,
const
std
::
string
&
cache_var
,
const
size_t
&
idx
,
const
size_t
&
idx
,
ir
::
Graph
*
graph
)
const
{
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE
(
var_nodes_
[
var
].
size
()
>=
1
&&
PADDLE_ENFORCE
(
var_nodes_
[
var
].
size
()
>=
1
&&
var_nodes_
[
var
].
at
(
0
)
->
Var
()
!=
nullptr
);
var_nodes_
[
var
].
at
(
0
)
->
Var
()
!=
nullptr
);
std
::
unique_ptr
<
VarDesc
>
var_desc
(
new
VarDesc
(
*
var_nodes_
[
var
].
at
(
0
)
->
Var
()));
std
::
unique_ptr
<
VarDesc
>
var_desc
(
new
VarDesc
(
*
var_nodes_
[
var
].
at
(
0
)
->
Var
()));
var_desc
->
SetName
(
cache_var
);
var_desc
->
SetName
(
cache_var
);
SSANodePair
swap_nodes
;
NodeSwapQueue
swap_nodes
;
for
(
size_t
i
=
idx
;
i
<
view_
.
AllOps
().
size
();
++
i
)
{
for
(
size_t
i
=
idx
;
i
<
view_
.
AllOps
().
size
();
++
i
)
{
auto
*
op
=
view_
.
AllOps
()[
i
];
auto
*
op
=
view_
.
AllOps
()[
i
];
...
@@ -230,7 +229,7 @@ const SSANodePair InplacePass::TryInplaceModifyVar(const std::string& var,
...
@@ -230,7 +229,7 @@ const SSANodePair InplacePass::TryInplaceModifyVar(const std::string& var,
return
swap_nodes
;
return
swap_nodes
;
}
}
void
InplacePass
::
CommitModify
(
const
SSANodePair
&
swap_nodes
,
void
InplacePass
::
CommitModify
(
const
NodeSwapQueue
&
swap_nodes
,
ir
::
Graph
*
graph
)
const
{
ir
::
Graph
*
graph
)
const
{
for
(
auto
&
pair
:
swap_nodes
)
{
for
(
auto
&
pair
:
swap_nodes
)
{
auto
*
node
=
pair
.
first
,
*
cache_node
=
pair
.
second
;
auto
*
node
=
pair
.
first
,
*
cache_node
=
pair
.
second
;
...
@@ -245,7 +244,7 @@ void InplacePass::CommitModify(const SSANodePair& swap_nodes,
...
@@ -245,7 +244,7 @@ void InplacePass::CommitModify(const SSANodePair& swap_nodes,
}
}
}
}
void
InplacePass
::
WithdrawModify
(
const
SSANodePair
&
nodes
,
void
InplacePass
::
WithdrawModify
(
const
NodeSwapQueue
&
nodes
,
ir
::
Graph
*
graph
)
const
{
ir
::
Graph
*
graph
)
const
{
for
(
auto
&
pair
:
nodes
)
{
for
(
auto
&
pair
:
nodes
)
{
auto
*
node
=
pair
.
first
,
*
cache_node
=
pair
.
second
;
auto
*
node
=
pair
.
first
,
*
cache_node
=
pair
.
second
;
...
...
paddle/fluid/framework/details/inplace_op_pass.h
浏览文件 @
381f2015
...
@@ -56,7 +56,8 @@ class GraphView {
...
@@ -56,7 +56,8 @@ class GraphView {
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list_
;
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list_
;
};
};
typedef
std
::
vector
<
std
::
pair
<
ir
::
Node
*
,
ir
::
Node
*>>
SSANodePair
;
// swap pairs in sequence
typedef
std
::
vector
<
std
::
pair
<
ir
::
Node
*
,
ir
::
Node
*>>
NodeSwapQueue
;
class
InplacePass
:
public
ir
::
Pass
{
class
InplacePass
:
public
ir
::
Pass
{
public:
public:
InplacePass
();
InplacePass
();
...
@@ -68,14 +69,14 @@ class InplacePass : public ir::Pass {
...
@@ -68,14 +69,14 @@ class InplacePass : public ir::Pass {
void
InitSSAGraphNodes
()
const
;
void
InitSSAGraphNodes
()
const
;
private:
private:
const
SSANodePair
TryInplaceModifyVar
(
const
std
::
string
&
var
,
const
NodeSwapQueue
TryInplaceModifyVar
(
const
std
::
string
&
var
,
const
std
::
string
&
cache_var
,
const
std
::
string
&
cache_var
,
const
size_t
&
idx
,
const
size_t
&
idx
,
ir
::
Graph
*
graph
)
const
;
ir
::
Graph
*
graph
)
const
;
void
CommitModify
(
const
SSANodePair
&
,
ir
::
Graph
*
graph
)
const
;
void
CommitModify
(
const
NodeSwapQueue
&
,
ir
::
Graph
*
graph
)
const
;
void
WithdrawModify
(
const
SSANodePair
&
nodes
,
ir
::
Graph
*
graph
)
const
;
void
WithdrawModify
(
const
NodeSwapQueue
&
nodes
,
ir
::
Graph
*
graph
)
const
;
void
InplaceModifyDesc
(
const
std
::
string
&
in_var
,
const
std
::
string
&
out_var
,
void
InplaceModifyDesc
(
const
std
::
string
&
in_var
,
const
std
::
string
&
out_var
,
const
size_t
&
idx
)
const
;
const
size_t
&
idx
)
const
;
...
...
paddle/fluid/framework/details/memory_early_delete_pass.cc
已删除
100644 → 0
浏览文件 @
6492ea9c
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_early_delete_pass.h"
#include <queue>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
static
ComputationOpHandle
*
FindNextComputationOpHandle
(
VarHandle
*
var_in
)
{
std
::
queue
<
VarHandleBase
*>
queue
;
queue
.
push
(
var_in
);
do
{
auto
*
var
=
queue
.
front
();
queue
.
pop
();
for
(
auto
*
op
:
var
->
PendingOps
())
{
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
);
if
(
compute_op
!=
nullptr
&&
compute_op
->
GetPlace
()
==
var_in
->
place
())
{
return
compute_op
;
}
for
(
auto
*
out_var
:
op
->
Outputs
())
{
queue
.
push
(
out_var
);
}
}
}
while
(
!
queue
.
empty
());
return
nullptr
;
}
std
::
unique_ptr
<
ir
::
Graph
>
MemoryEarlyDeletePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
auto
&
graph_pool
=
Get
<
GraphNodePool
>
(
kGraphNodePool
);
auto
&
gcs
=
Get
<
GarbageCollectorMap
>
(
kGarbageCollector
);
std
::
unordered_map
<
std
::
string
,
std
::
unordered_set
<
OpDesc
*>>
unlived_vars
;
unlived_vars
.
reserve
(
graph_pool
.
size
());
for
(
auto
&
pair
:
graph_pool
)
{
unlived_vars
.
insert
(
std
::
make_pair
(
pair
.
first
,
pair
.
second
));
}
auto
compare_and_insert_early_delete_op
=
[
&
](
OpHandleBase
*
op
,
const
std
::
vector
<
VarHandleBase
*>&
vars
)
{
if
(
unlived_vars
.
empty
())
return
;
// unlived vars can be deleted after the last used op has finished.
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
);
const
auto
&
places
=
Get
<
std
::
vector
<
platform
::
Place
>>
(
kAllPlaces
);
for
(
auto
&
var
:
vars
)
{
auto
*
var_handle
=
dynamic_cast
<
VarHandle
*>
(
var
);
auto
var_name
=
var
->
Node
()
->
Name
();
auto
&
var_place
=
var_handle
->
place
();
if
(
unlived_vars
.
count
(
var_name
)
==
0
)
continue
;
if
(
!
unlived_vars
[
var_name
].
empty
())
{
if
(
compute_op
!=
nullptr
&&
unlived_vars
[
var_name
].
count
(
compute_op
->
Node
()
->
Op
())
!=
0
)
{
unlived_vars
[
var_name
].
erase
(
compute_op
->
Node
()
->
Op
());
}
continue
;
}
if
(
var_handle
==
nullptr
||
!
var_handle
->
Node
()
->
IsVar
()
||
var_handle
->
Node
()
->
IsCtrlVar
())
continue
;
// shameless copyed from reference count pass.
if
(
compute_op
==
nullptr
)
{
// use next computation op scope
compute_op
=
FindNextComputationOpHandle
(
var_handle
);
}
auto
*
early_delete_node
=
graph
->
CreateEmptyNode
(
"early_delete"
,
ir
::
Node
::
Type
::
kOperation
);
GarbageCollector
*
gc
=
gcs
.
at
(
places
[
compute_op
->
GetScopeIdx
()]).
get
();
auto
*
early_delete_handle
=
new
EarlyDeleteOpHandle
(
early_delete_node
,
compute_op
->
GetScope
(),
var_place
,
{
var_name
},
gc
);
if
(
compute_op
->
Outputs
().
empty
())
{
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
compute_op
->
AddOutput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
}
early_delete_handle
->
AddInput
(
compute_op
->
Outputs
().
front
());
VLOG
(
5
)
<<
"Add early delete op "
<<
var_name
<<
" to Operator"
<<
compute_op
->
Name
();
}
};
auto
all_ops
=
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
*
graph
);
for
(
auto
&
op
:
all_ops
)
{
compare_and_insert_early_delete_op
(
op
,
op
->
Inputs
());
compare_and_insert_early_delete_op
(
op
,
op
->
Outputs
());
}
return
graph
;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
memory_early_delete_pass
,
paddle
::
framework
::
details
::
MemoryEarlyDeletePass
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphNodePool
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGarbageCollector
);
paddle/fluid/framework/details/memory_early_delete_pass.h
已删除
100644 → 0
浏览文件 @
6492ea9c
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/early_delete_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
class
MemoryEarlyDeletePass
:
public
ir
::
Pass
{
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/memory_optimize_helper.cc
浏览文件 @
381f2015
...
@@ -13,17 +13,108 @@
...
@@ -13,17 +13,108 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include <deque>
#include <functional>
#include <functional>
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
#include <sstream>
#include <sstream>
#include <string>
#include <string>
#include "paddle/fluid/framework/var_desc.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
using
paddle
::
framework
::
VarDesc
;
size_t
NodeSizeInBytes
(
const
VarDesc
&
node
)
{
std
::
vector
<
ir
::
Node
*>
SortOpLikeDescOrder
(
const
ir
::
Graph
&
graph
)
{
PADDLE_ENFORCE
(
graph
.
Has
(
kAllOpDescs
),
"Graph has no attribute of kAllOpDescs."
);
// 1. get op desc order
auto
&
op_descs
=
graph
.
Get
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
);
// 2. topology sort order
auto
nodes
=
graph
.
Nodes
();
std
::
deque
<
ir
::
Node
*>
ops
;
FilterVariables
(
nodes
,
[
&
](
ir
::
Node
*
op
)
{
if
(
op
->
IsOp
()
&&
op
->
Op
()
!=
nullptr
)
{
ops
.
emplace_back
(
op
);
}
});
std
::
unordered_map
<
ir
::
Node
*
,
size_t
>
op_deps
;
std
::
list
<
ir
::
Node
*>
ready_ops
;
std
::
unordered_map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
pending_ops
;
for
(
auto
*
op
:
ops
)
{
std
::
unordered_set
<
ir
::
Node
*>
preceding_op
;
for
(
auto
*
in
:
op
->
inputs
)
{
if
(
in
->
inputs
.
empty
())
continue
;
PADDLE_ENFORCE
(
in
->
inputs
.
size
()
==
1
&&
in
->
inputs
[
0
]
->
IsOp
());
preceding_op
.
emplace
(
in
->
inputs
[
0
]);
pending_ops
[
in
->
inputs
[
0
]].
emplace
(
op
);
}
op_deps
[
op
]
=
preceding_op
.
size
();
if
(
preceding_op
.
empty
())
{
ready_ops
.
emplace_back
(
op
);
}
}
// 3. generated op list based desc order and the topology order
std
::
vector
<
ir
::
Node
*>
ret
;
std
::
list
<
OpDesc
*>
op_descs_list
(
op_descs
.
begin
(),
op_descs
.
end
());
auto
update_by_found_node
=
[
&
](
ir
::
Node
*
found_node
)
{
for
(
auto
*
pending_op
:
pending_ops
[
found_node
])
{
if
(
--
op_deps
[
pending_op
]
==
0
)
{
ready_ops
.
emplace_back
(
pending_op
);
}
}
ready_ops
.
remove
(
found_node
);
ret
.
emplace_back
(
found_node
);
};
while
(
!
ready_ops
.
empty
())
{
bool
all_of_ready_op_unmatched
=
true
;
for
(
auto
it
=
op_descs_list
.
begin
();
it
!=
op_descs_list
.
end
();)
{
auto
op_desc
=
*
it
;
ir
::
Node
*
found_node
=
nullptr
;
for
(
auto
*
op
:
ready_ops
)
{
if
(
IsSameDesc
(
op
->
Op
(),
op_desc
))
{
found_node
=
op
;
break
;
}
}
// 3.1 op desc deleted by other pass
if
(
found_node
==
nullptr
)
{
++
it
;
continue
;
}
else
{
all_of_ready_op_unmatched
=
false
;
it
=
op_descs_list
.
erase
(
it
);
}
update_by_found_node
(
found_node
);
}
// 3.2 op descs are added by other pass
// preceding op non empty means some new op descs are
// created, but not contained in return node list.
// these new op desc may depend on each other.
std
::
list
<
ir
::
Node
*>
prev_ready_ops
(
ready_ops
);
if
(
all_of_ready_op_unmatched
)
{
for
(
auto
op
:
prev_ready_ops
)
{
update_by_found_node
(
op
);
}
}
}
PADDLE_ENFORCE
(
std
::
all_of
(
op_deps
.
begin
(),
op_deps
.
end
(),
[
&
](
const
std
::
pair
<
ir
::
Node
*
,
size_t
>&
p
)
{
return
p
.
second
==
0
;
}));
return
ret
;
}
size_t
NodeSize
(
const
VarDesc
&
node
)
{
auto
shape
=
node
.
GetShape
();
auto
shape
=
node
.
GetShape
();
int
size
=
int
size
=
std
::
accumulate
(
shape
.
begin
(),
shape
.
end
(),
1
,
std
::
multiplies
<
int
>
());
std
::
accumulate
(
shape
.
begin
(),
shape
.
end
(),
1
,
std
::
multiplies
<
int
>
());
...
@@ -31,9 +122,9 @@ size_t NodeSizeInBytes(const VarDesc& node) {
...
@@ -31,9 +122,9 @@ size_t NodeSizeInBytes(const VarDesc& node) {
return
type_size
*
std
::
abs
(
size
);
return
type_size
*
std
::
abs
(
size
);
}
}
size_t
NodeSize
InBytes
(
ir
::
Node
*
n
)
{
size_t
NodeSize
(
ir
::
Node
*
n
)
{
auto
*
desc
=
FindVarDescInBlock
(
n
);
auto
*
desc
=
FindVarDescInBlock
(
n
);
return
NodeSize
InBytes
(
*
desc
);
return
NodeSize
(
*
desc
);
}
}
std
::
string
DebugStringImpl
(
VarDesc
*
var
)
{
std
::
string
DebugStringImpl
(
VarDesc
*
var
)
{
...
@@ -59,7 +150,6 @@ std::string DebugStringImpl(VarDesc* var) {
...
@@ -59,7 +150,6 @@ std::string DebugStringImpl(VarDesc* var) {
std
::
string
DebugString
(
ir
::
Node
*
var
)
{
std
::
string
DebugString
(
ir
::
Node
*
var
)
{
return
DebugStringImpl
(
FindVarDescInBlock
(
var
));
return
DebugStringImpl
(
FindVarDescInBlock
(
var
));
}
}
// return DebugString(var->Var()); }
// NOTE(dzh): based ir node, if a large node has been reused
// NOTE(dzh): based ir node, if a large node has been reused
// by a small size node, then next time it appear in pool, it will
// by a small size node, then next time it appear in pool, it will
...
@@ -80,18 +170,17 @@ struct NodeComparator {
...
@@ -80,18 +170,17 @@ struct NodeComparator {
auto
rhs_shape
=
rhs_desc
->
GetShape
();
auto
rhs_shape
=
rhs_desc
->
GetShape
();
if
((
lhs_shape
[
0
]
==
-
1
&&
rhs_shape
[
0
]
==
-
1
)
||
if
((
lhs_shape
[
0
]
==
-
1
&&
rhs_shape
[
0
]
==
-
1
)
||
(
lhs_shape
[
0
]
!=
-
1
&&
rhs_shape
[
0
]
!=
-
1
))
{
(
lhs_shape
[
0
]
!=
-
1
&&
rhs_shape
[
0
]
!=
-
1
))
{
return
NodeSize
InBytes
(
lhs
)
<=
NodeSizeInBytes
(
rhs
);
return
NodeSize
(
lhs
)
<=
NodeSize
(
rhs
);
}
else
{
}
else
{
return
false
;
return
false
;
}
}
}
}
};
};
void
Ordered
NodeList
::
Insert
(
ir
::
Node
*
var
,
ir
::
Node
*
op
)
{
void
Ordered
Set
::
Insert
(
ir
::
Node
*
var
)
{
PADDLE_ENFORCE
(
var
->
IsVar
()
&&
!
var
->
IsCtrlVar
());
PADDLE_ENFORCE
(
var
->
IsVar
()
&&
!
var
->
IsCtrlVar
());
PADDLE_ENFORCE
(
op
->
IsOp
());
if
(
mark_table_
.
count
(
var
->
Name
())
!=
0
)
{
if
(
mark_table_
.
count
(
var
->
Name
())
!=
0
)
{
mark_table_
[
var
->
Name
()]
->
second
.
insert
(
op
);
mark_table_
[
var
->
Name
()]
->
emplace_back
(
var
);
return
;
return
;
}
}
...
@@ -99,14 +188,15 @@ void OrderedNodeList::Insert(ir::Node* var, ir::Node* op) {
...
@@ -99,14 +188,15 @@ void OrderedNodeList::Insert(ir::Node* var, ir::Node* op) {
auto
var_shape
=
var_desc
->
GetShape
();
auto
var_shape
=
var_desc
->
GetShape
();
int
batch_size
=
static_cast
<
int
>
(
var_shape
[
0
]);
int
batch_size
=
static_cast
<
int
>
(
var_shape
[
0
]);
NodeComparator
compare_node
;
NodeComparator
functor
;
Iter
it
=
nodes_
.
begin
();
Iter
it
=
nodes_
.
begin
();
while
(
it
!=
nodes_
.
end
())
{
while
(
it
!=
nodes_
.
end
())
{
auto
*
cache_desc
=
FindVarDescInBlock
(
it
->
first
);
auto
&
prev
=
it
->
front
();
auto
*
cache_desc
=
FindVarDescInBlock
(
prev
);
int
cache_batch_size
=
cache_desc
->
GetShape
()[
0
];
int
cache_batch_size
=
cache_desc
->
GetShape
()[
0
];
if
((
cache_batch_size
==
-
1
&&
batch_size
==
-
1
)
||
if
((
cache_batch_size
==
-
1
&&
batch_size
==
-
1
)
||
(
cache_batch_size
!=
-
1
&&
batch_size
!=
-
1
))
{
(
cache_batch_size
!=
-
1
&&
batch_size
!=
-
1
))
{
if
(
compare_node
(
it
->
first
,
var
))
{
if
(
functor
(
prev
,
var
))
{
++
it
;
++
it
;
}
else
{
}
else
{
break
;
break
;
...
@@ -118,62 +208,80 @@ void OrderedNodeList::Insert(ir::Node* var, ir::Node* op) {
...
@@ -118,62 +208,80 @@ void OrderedNodeList::Insert(ir::Node* var, ir::Node* op) {
}
}
}
}
it
=
it
=
nodes_
.
insert
(
it
,
{
var
});
nodes_
.
insert
(
it
,
std
::
make_pair
(
var
,
std
::
unordered_set
<
ir
::
Node
*>
{
op
}));
mark_table_
[
var
->
Name
()]
=
it
;
mark_table_
[
var
->
Name
()]
=
it
;
}
}
int
Ordered
NodeList
::
GetIndex
(
ir
::
Node
*
var
)
{
int
Ordered
Set
::
GetNodeIndexInPool
(
ir
::
Node
*
var
)
{
return
std
::
distance
(
nodes_
.
begin
(),
mark_table_
[
var
->
Name
()]);
return
std
::
distance
(
nodes_
.
begin
(),
mark_table_
[
var
->
Name
()]);
}
}
ir
::
Node
*
Ordered
NodeList
::
NodeMatch
(
ir
::
Node
*
var
)
const
{
ir
::
Node
*
Ordered
Set
::
FindBestFitNode
(
ir
::
Node
*
var
)
const
{
ir
::
Node
*
found_node
=
nullptr
;
ir
::
Node
*
found_node
=
nullptr
;
NodeComparator
compare_node
;
NodeComparator
functor
;
for
(
auto
it
=
nodes_
.
begin
();
it
!=
nodes_
.
end
();
++
it
)
{
for
(
auto
it
=
nodes_
.
begin
();
it
!=
nodes_
.
end
();
++
it
)
{
if
(
compare_node
(
var
,
it
->
first
))
{
auto
&
candidate
=
it
->
front
();
found_node
=
it
->
first
;
if
(
functor
(
var
,
candidate
))
{
found_node
=
candidate
;
break
;
break
;
}
}
}
}
return
found_node
;
return
found_node
;
}
}
void
OrderedNodeList
::
Erase
(
ir
::
Node
*
var
)
{
Erase
(
var
->
Name
());
}
bool
OrderedSet
::
Has
(
ir
::
Node
*
var
)
const
{
if
(
mark_table_
.
count
(
var
->
Name
()))
{
auto
&
node_in_samename
=
mark_table_
.
at
(
var
->
Name
());
auto
iter
=
std
::
find_if
(
node_in_samename
->
begin
(),
node_in_samename
->
end
(),
[
&
](
ir
::
Node
*
n
)
{
return
n
->
Name
()
==
var
->
Name
();
});
return
iter
!=
node_in_samename
->
end
();
}
return
false
;
}
void
Ordered
NodeList
::
Erase
(
const
std
::
string
&
var
)
{
void
Ordered
Set
::
Erase
(
ir
::
Node
*
var
)
{
PADDLE_ENFORCE
(
mark_table_
.
count
(
var
));
PADDLE_ENFORCE
(
mark_table_
.
count
(
var
->
Name
()
));
nodes_
.
erase
(
mark_table_
[
var
]);
nodes_
.
erase
(
mark_table_
[
var
->
Name
()
]);
mark_table_
.
erase
(
var
);
mark_table_
.
erase
(
var
->
Name
()
);
}
}
std
::
string
Ordered
NodeLis
t
::
ToString
()
const
{
std
::
string
Ordered
Se
t
::
ToString
()
const
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
for
(
auto
it
=
nodes_
.
begin
();
it
!=
nodes_
.
end
();
++
it
)
{
for
(
auto
it
=
nodes_
.
begin
();
it
!=
nodes_
.
end
();
++
it
)
{
ss
<<
DebugString
(
it
->
first
)
<<
" "
;
for
(
auto
&
node
:
*
it
)
{
ss
<<
DebugString
(
node
)
<<
" "
;
}
}
}
return
ss
.
str
();
return
ss
.
str
();
}
}
bool
NodeCanReused
(
ir
::
Node
*
node
)
{
bool
NodeCanReused
(
ir
::
Node
*
node
)
{
// valid the node is a var node
if
(
node
==
nullptr
||
!
node
->
IsVar
()
||
node
->
IsCtrlVar
())
return
false
;
if
(
node
==
nullptr
||
!
node
->
IsVar
()
||
node
->
IsCtrlVar
())
return
false
;
// auto* desc = node->Var();
bool
flag
=
NodeCanReused
(
*
node
->
Var
());
bool
flag
=
true
;
// op output force generated in cpu, can not be reused.
for
(
auto
*
op
:
node
->
inputs
)
{
for
(
auto
*
op
:
node
->
inputs
)
{
if
(
op
->
Op
()
->
HasAttr
(
"force_cpu"
))
{
if
(
op
->
Op
()
->
HasAttr
(
"force_cpu"
))
{
// op output force generated in cpu, can not be reused.
flag
&=
framework
::
AttrReader
(
op
->
Op
()
->
GetAttrMap
())
flag
&=
framework
::
AttrReader
(
op
->
Op
()
->
GetAttrMap
())
.
Get
<
bool
>
(
"force_cpu"
)
==
0
;
.
Get
<
bool
>
(
"force_cpu"
)
==
0
;
}
}
}
}
// var desc validation.
flag
&=
NodeCanReused
(
*
node
->
Var
());
return
flag
;
return
flag
;
}
}
bool
NodeCanReused
(
const
VarDesc
&
node
)
{
bool
NodeCanReused
(
const
VarDesc
&
node
)
{
auto
type
=
node
.
GetType
();
auto
type
=
node
.
GetType
();
if
(
node
.
Persistable
()
||
type
!=
proto
::
VarType
::
LOD_TENSOR
||
if
(
!
(
type
==
proto
::
VarType
::
LOD_TENSOR
||
node
.
GetShape
().
empty
())
{
type
==
proto
::
VarType
::
SELECTED_ROWS
||
type
==
proto
::
VarType
::
LOD_TENSOR_ARRAY
))
{
return
false
;
}
if
(
node
.
Persistable
()
||
node
.
GetShape
().
empty
())
{
return
false
;
return
false
;
}
}
// vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
// vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
...
@@ -193,6 +301,174 @@ bool OpHasSubBlock(OpDesc* desc) {
...
@@ -193,6 +301,174 @@ bool OpHasSubBlock(OpDesc* desc) {
return
false
;
return
false
;
}
}
ControlFlowGraph
::
ControlFlowGraph
(
const
ir
::
Graph
&
graph
)
{
ops_
=
SortOpLikeDescOrder
(
graph
);
ConnectNodes
();
}
void
ControlFlowGraph
::
BuildCFGGraph
()
{
// FIXME(dzh): same effect with ConnectNodes, but use the control
// link to build dependency graph, it goes wrong in transformer.
for
(
ir
::
Node
*
op
:
ops_
)
{
for
(
auto
&
input_var
:
op
->
inputs
)
{
if
(
!
input_var
->
inputs
.
empty
())
{
PADDLE_ENFORCE
(
input_var
->
inputs
.
size
()
==
1
&&
input_var
->
inputs
[
0
]
->
IsOp
(),
"Preceding Op Node of Var Node must be unique"
);
auto
*
pred_op
=
input_var
->
inputs
[
0
];
if
(
pred_op
->
Op
()
!=
nullptr
)
{
predecessors_
[
op
].
insert
(
pred_op
);
successors_
[
pred_op
].
insert
(
op
);
}
}
if
(
input_var
->
IsVar
()
&&
!
input_var
->
IsCtrlVar
())
{
uses_
[
op
].
insert
(
input_var
->
Name
());
}
}
for
(
auto
&
output_var
:
op
->
outputs
)
{
// output var may be used by many op
for
(
auto
*
succ_op
:
output_var
->
outputs
)
{
if
(
succ_op
->
Op
()
!=
nullptr
)
{
successors_
[
op
].
insert
(
succ_op
);
predecessors_
[
succ_op
].
insert
(
op
);
}
}
if
(
output_var
->
IsVar
()
&&
!
output_var
->
IsCtrlVar
())
{
defs_
[
op
].
insert
(
output_var
->
Name
());
}
}
}
}
void
ControlFlowGraph
::
ConnectNodes
()
{
for
(
size_t
i
=
0
;
i
<
ops_
.
size
();
++
i
)
{
auto
&
op
=
ops_
[
i
];
try
{
auto
&
next_op
=
ops_
.
at
(
i
+
1
);
successors_
[
op
].
insert
(
next_op
);
predecessors_
[
next_op
].
insert
(
op
);
}
catch
(...)
{
// do nothing
}
FilterVariables
(
op
->
inputs
,
[
&
](
ir
::
Node
*
var
)
{
uses_
[
op
].
emplace
(
var
->
Name
());
});
FilterVariables
(
op
->
outputs
,
[
&
](
ir
::
Node
*
var
)
{
defs_
[
op
].
emplace
(
var
->
Name
());
});
}
}
void
ControlFlowGraph
::
LiveVariableAnalysis
()
{
// NOTE(dzh): variable liveless analysis (a.k.a reversed_ops algorithm)
// compute the liveness of for each variable though reversed_ops algorithm.
// It iterates the operators from end to begin, compute the live in/live out
// variable set for each op, then the diff between in/out will be used for
// the variable reuse. For detail refer to
// http://www.cs.cornell.edu/courses/cs4120/2013fa/lectures/lec26-fa13.pdf
std
::
list
<
ir
::
Node
*>
work_list
(
ops_
.
rbegin
(),
ops_
.
rend
());
while
(
!
work_list
.
empty
())
{
ir
::
Node
*
op
=
work_list
.
front
();
work_list
.
pop_front
();
// get the live_in calculated before. Empty if first.
auto
prev_live_in
=
std
::
move
(
live_in_
[
op
]);
for
(
auto
&
s
:
successors_
[
op
])
{
for
(
auto
&
var
:
live_in_
[
s
])
{
live_out_
[
op
].
insert
(
var
);
}
}
for
(
auto
&
var
:
uses_
[
op
])
{
live_in_
[
op
].
insert
(
var
);
}
for
(
auto
&
var
:
live_out_
[
op
])
{
live_in_
[
op
].
insert
(
var
);
}
for
(
auto
&
var
:
defs_
[
op
])
{
live_in_
[
op
].
erase
(
var
);
}
// If the live_in is not changed, then the liveness analysis of
// predecessors is completed.
//
// Otherwise, recalculate the predecessors liveness
if
(
live_in_
[
op
]
!=
prev_live_in
)
{
for
(
auto
&
pre
:
predecessors_
[
op
])
{
work_list
.
push_back
(
pre
);
}
}
}
}
void
ControlFlowGraph
::
RenameVarInCFGGraph
(
const
std
::
string
&
old_node
,
const
std
::
string
&
new_node
,
int
begin_idx
)
{
// update graph from begin idx to the end
for
(
size_t
i
=
begin_idx
;
i
!=
ops_
.
size
();
++
i
)
{
auto
*
op
=
ops_
[
i
];
if
(
uses_
[
op
].
find
(
old_node
)
!=
uses_
[
op
].
end
())
{
uses_
[
op
].
erase
(
old_node
);
uses_
[
op
].
insert
(
new_node
);
}
if
(
defs_
[
op
].
find
(
old_node
)
!=
defs_
[
op
].
end
())
{
defs_
[
op
].
erase
(
old_node
);
defs_
[
op
].
insert
(
new_node
);
}
if
(
live_in_
[
op
].
find
(
old_node
)
!=
live_in_
[
op
].
end
())
{
live_in_
[
op
].
erase
(
old_node
);
live_in_
[
op
].
insert
(
new_node
);
}
if
(
live_out_
[
op
].
find
(
old_node
)
!=
live_out_
[
op
].
end
())
{
live_out_
[
op
].
erase
(
old_node
);
live_out_
[
op
].
insert
(
new_node
);
}
}
}
const
std
::
set
<
std
::
string
>
ControlFlowGraph
::
LiveIn
(
ir
::
Node
*
op
)
const
{
auto
it
=
live_in_
.
find
(
op
);
PADDLE_ENFORCE
(
it
!=
live_in_
.
end
(),
string
::
Sprintf
(
"Expect %s in live_in, but Not Found."
,
op
->
Name
()));
return
it
->
second
;
}
const
std
::
set
<
std
::
string
>
ControlFlowGraph
::
LiveOut
(
ir
::
Node
*
op
)
const
{
auto
it
=
live_out_
.
find
(
op
);
PADDLE_ENFORCE
(
it
!=
live_out_
.
end
(),
string
::
Sprintf
(
"Expect %s in live_out, but Not Found."
,
op
->
Name
()));
return
it
->
second
;
}
const
std
::
set
<
std
::
string
>
ControlFlowGraph
::
Use
(
ir
::
Node
*
op
)
const
{
auto
it
=
uses_
.
find
(
op
);
PADDLE_ENFORCE
(
it
!=
uses_
.
end
(),
string
::
Sprintf
(
"Expect %s in live_out, but Not Found."
,
op
->
Name
()));
return
it
->
second
;
}
const
std
::
vector
<
ir
::
Node
*>
ControlFlowGraph
::
Ops
()
const
{
return
ops_
;
}
std
::
vector
<
ir
::
Node
*>&
ControlFlowGraph
::
Ops
()
{
return
ops_
;
}
ir
::
Node
*
ControlFlowGraph
::
GetNodeByName
(
const
std
::
string
&
name
,
ir
::
Node
*
op
)
const
{
// in ssa-graph, different version nodes have same name,
// this function get the latest version var before target op
// It may return nullptr, such as data node.
ir
::
Node
*
found_node
=
nullptr
;
for
(
auto
*
node
:
ops_
)
{
if
(
node
==
op
)
break
;
for
(
auto
&
output
:
node
->
outputs
)
{
if
(
output
->
Name
()
==
name
)
{
found_node
=
output
;
}
}
}
return
found_node
;
}
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/details/memory_optimize_helper.h
浏览文件 @
381f2015
...
@@ -17,6 +17,8 @@
...
@@ -17,6 +17,8 @@
#include <iostream>
#include <iostream>
#include <iterator>
#include <iterator>
#include <list>
#include <list>
#include <map>
#include <set>
#include <string>
#include <string>
#include <utility>
#include <utility>
#include <vector>
#include <vector>
...
@@ -27,41 +29,41 @@ namespace paddle {
...
@@ -27,41 +29,41 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
constexpr
char
kFetchedVars
[]
=
"fetched_vars"
;
constexpr
char
kAllOpDescs
[]
=
"all_op_descs"
;
constexpr
char
kGraphNodePool
[]
=
"graph_node_pool"
;
// NOTE(dzh): Variable and the operators use the var.
std
::
vector
<
ir
::
Node
*>
SortOpLikeDescOrder
(
const
ir
::
Graph
&
graph
);
// for early delete pass.
// Because analysis var pass build base on ir::Node, which maybe released
// or modified between passes, so we use OpDesc* to mark ops.
using
GraphNodePool
=
std
::
vector
<
std
::
pair
<
std
::
string
/*var node*/
,
std
::
unordered_set
<
OpDesc
*>
/* ops */
>>
;
// NOTE(dzh): by default, it sort node in ascend order(by node bytes size).
// NOTE(dzh): A ordered set for node reuse in memory optimize.
// in fluid, -1 means the batch_size is determined in runtime.
// the orderedset sort node in ascend order(by node bytes size).
// the node batch_size equal -1 always ranking in the front than the node not.
// in fluid, -1 means the batch_size, which is determined in runtime.
// So the reuse happens between nodes who's batch_size both are -1
// simultaneously or not.
//
// sort rule:
// rule 0 : smaller node ranking in front.
// rule 1 : batch_size equal -1 ranking in the front than the node not.
//
// For example,
// For example,
// node0[-1, 1] node1[-1, 1, 1], node2[1,1], node3[1,1024], ..
// node0[-1, 1] node1[-1, 1, 1], node2[1,1], node3[1,1024], ..
// O(1) insert, delete
class
OrderedNodeList
{
public:
using
NodePair
=
std
::
pair
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
;
using
Iter
=
typename
std
::
list
<
NodePair
>::
iterator
;
using
ConstIter
=
typename
std
::
list
<
NodePair
>::
const_iterator
;
void
Insert
(
ir
::
Node
*
var
,
ir
::
Node
*
op
);
class
OrderedSet
{
public:
// nodes with same name exists in pool.
using
NodeVector
=
std
::
vector
<
ir
::
Node
*>
;
using
Iter
=
typename
std
::
list
<
NodeVector
>::
iterator
;
using
ConstIter
=
typename
std
::
list
<
NodeVector
>::
const_iterator
;
void
Insert
(
ir
::
Node
*
var
);
void
Erase
(
ir
::
Node
*
var
);
void
Erase
(
ir
::
Node
*
var
);
bool
Has
(
ir
::
Node
*
var
)
const
;
void
Erase
(
const
std
::
string
&
var
);
void
Clear
()
{
mark_table_
.
clear
();
bool
Has
(
ir
::
Node
*
var
)
{
return
mark_table_
.
count
(
var
->
Name
());
}
nodes_
.
clear
();
}
bool
Has
(
const
std
::
string
&
var
)
{
return
mark_table_
.
count
(
var
);
}
// find the bestfit shape node block with var.
ir
::
Node
*
FindBestFitNode
(
ir
::
Node
*
var
)
const
;
ir
::
Node
*
NodeMatch
(
ir
::
Node
*
var
)
const
;
// map store non-const iterator, can not promise const
// map store non-const iterator, can not promise const
int
Get
Index
(
ir
::
Node
*
var
);
int
Get
NodeIndexInPool
(
ir
::
Node
*
var
);
// pool all node to string
// pool all node to string
std
::
string
ToString
()
const
;
std
::
string
ToString
()
const
;
...
@@ -69,18 +71,54 @@ class OrderedNodeList {
...
@@ -69,18 +71,54 @@ class OrderedNodeList {
Iter
end
()
{
return
nodes_
.
end
();
}
Iter
end
()
{
return
nodes_
.
end
();
}
ConstIter
begin
()
const
{
return
nodes_
.
begin
();
}
ConstIter
begin
()
const
{
return
nodes_
.
begin
();
}
ConstIter
end
()
const
{
return
nodes_
.
end
();
}
ConstIter
end
()
const
{
return
nodes_
.
end
();
}
size_t
size
()
const
{
return
nodes_
.
size
();
}
void
Clear
()
{
size_t
size
()
const
{
return
nodes_
.
size
();
}
mark_table_
.
clear
();
nodes_
.
clear
();
}
private:
private:
// for searching.
// for searching.
std
::
unordered_map
<
std
::
string
,
Iter
>
mark_table_
;
std
::
unordered_map
<
std
::
string
,
Iter
>
mark_table_
;
// node swap pairs. var -> ops dep var
// node pool
std
::
list
<
NodePair
>
nodes_
;
std
::
list
<
NodeVector
>
nodes_
;
};
class
ControlFlowGraph
{
public:
ControlFlowGraph
()
=
default
;
// IR Graph
explicit
ControlFlowGraph
(
const
ir
::
Graph
&
graph
);
void
LiveVariableAnalysis
();
void
RenameVarInCFGGraph
(
const
std
::
string
&
old_node
,
const
std
::
string
&
new_node
,
int
begin_idx
);
const
std
::
set
<
std
::
string
>
LiveIn
(
ir
::
Node
*
op
)
const
;
const
std
::
set
<
std
::
string
>
LiveOut
(
ir
::
Node
*
op
)
const
;
const
std
::
set
<
std
::
string
>
Use
(
ir
::
Node
*
op
)
const
;
const
std
::
vector
<
ir
::
Node
*>
Ops
()
const
;
std
::
vector
<
ir
::
Node
*>&
Ops
();
// for ssa-graph nodes
ir
::
Node
*
GetNodeByName
(
const
std
::
string
&
name
,
ir
::
Node
*
op
)
const
;
private:
void
BuildCFGGraph
();
void
ConnectNodes
();
using
NodeListMap
=
std
::
unordered_map
<
ir
::
Node
*
,
std
::
set
<
ir
::
Node
*>>
;
using
VarSetMap
=
std
::
map
<
ir
::
Node
*
,
std
::
set
<
std
::
string
>>
;
// successors ops use the output variables.
NodeListMap
successors_
;
// predecessors ops generated input variables.
NodeListMap
predecessors_
;
// variables lived before run current op.
VarSetMap
live_in_
;
// variables lived after run current op.
VarSetMap
live_out_
;
VarSetMap
uses_
;
// op inputs
VarSetMap
defs_
;
// op outputs
std
::
vector
<
ir
::
Node
*>
ops_
;
// op sequence by topology sort
};
};
// valid a tensor can be reuse or not
// valid a tensor can be reuse or not
...
@@ -93,15 +131,24 @@ bool NodeCanReused(const VarDesc& node);
...
@@ -93,15 +131,24 @@ bool NodeCanReused(const VarDesc& node);
bool
OpHasSubBlock
(
OpDesc
*
desc
);
bool
OpHasSubBlock
(
OpDesc
*
desc
);
// node memory size in bytes
// node memory size in bytes
size_t
NodeSize
InBytes
(
ir
::
Node
*
n
);
size_t
NodeSize
(
ir
::
Node
*
n
);
// node memory size in bytes
// node memory size in bytes
size_t
NodeSize
InBytes
(
const
VarDesc
&
);
size_t
NodeSize
(
const
VarDesc
&
);
std
::
string
DebugString
(
ir
::
Node
*
var
);
std
::
string
DebugString
(
ir
::
Node
*
var
);
// NOTE(dzhwinter)
// after node reuse, the replaced node shape is
// different with its VarDesc. So need to find the
// correct VarDesc in Block.
VarDesc
*
FindVarDescInBlock
(
ir
::
Node
*
n
);
VarDesc
*
FindVarDescInBlock
(
ir
::
Node
*
n
);
static
inline
bool
IsSameDesc
(
OpDesc
*
op1
,
OpDesc
*
op2
)
{
return
op1
->
Type
()
==
op2
->
Type
()
&&
op1
->
Inputs
()
==
op2
->
Inputs
()
&&
op1
->
Outputs
()
==
op2
->
Outputs
();
}
template
<
typename
Container
,
typename
Callback
>
template
<
typename
Container
,
typename
Callback
>
class
FilterVariableImpl
{
class
FilterVariableImpl
{
public:
public:
...
...
paddle/fluid/framework/details/memory_optimize_helper_test.cc
浏览文件 @
381f2015
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include <algorithm>
#include <algorithm>
#include <iostream>
#include <iostream>
#include <iterator>
#include <memory>
#include <memory>
#include <sstream>
#include <sstream>
#include <string>
#include <string>
...
@@ -22,13 +23,19 @@
...
@@ -22,13 +23,19 @@
#include <vector>
#include <vector>
#include "glog/logging.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/graph_test_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
TEST
(
Ordered
NodeLis
t
,
Normal
)
{
TEST
(
Ordered
Se
t
,
Normal
)
{
Ordered
NodeLis
t
pool
;
Ordered
Se
t
pool
;
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
nodes
;
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
nodes
;
// clang-format off
// clang-format off
...
@@ -56,8 +63,15 @@ TEST(OrderedNodeList, Normal) {
...
@@ -56,8 +63,15 @@ TEST(OrderedNodeList, Normal) {
nodes
.
emplace_back
(
std
::
move
(
node
));
nodes
.
emplace_back
(
std
::
move
(
node
));
}
}
// Insert
for
(
auto
&
node
:
nodes
)
{
for
(
auto
&
node
:
nodes
)
{
pool
.
Insert
(
node
.
get
(),
op
.
get
());
pool
.
Insert
(
node
.
get
());
}
// Has/size
ASSERT_EQ
(
pool
.
size
(),
shapes
.
size
());
for
(
auto
&
node
:
nodes
)
{
ASSERT_TRUE
(
pool
.
Has
(
node
.
get
()));
}
}
// assert its order and interface.
// assert its order and interface.
...
@@ -66,14 +80,14 @@ TEST(OrderedNodeList, Normal) {
...
@@ -66,14 +80,14 @@ TEST(OrderedNodeList, Normal) {
std
::
cout
<<
pool
.
ToString
()
<<
std
::
endl
;
std
::
cout
<<
pool
.
ToString
()
<<
std
::
endl
;
ASSERT_EQ
(
pool
.
size
(),
static_cast
<
size_t
>
(
COUNT
-
1
));
ASSERT_EQ
(
pool
.
size
(),
static_cast
<
size_t
>
(
COUNT
-
1
));
ASSERT_EQ
(
pool
.
Get
Index
(
nodes
.
back
().
get
()),
0
);
ASSERT_EQ
(
pool
.
Get
NodeIndexInPool
(
nodes
.
back
().
get
()),
0
);
{
{
auto
v1
=
block_desc
->
Var
(
"11"
);
auto
v1
=
block_desc
->
Var
(
"11"
);
v1
->
SetShape
({
-
1
,
256
,
56
,
56
});
v1
->
SetShape
({
-
1
,
256
,
56
,
56
});
std
::
unique_ptr
<
ir
::
Node
>
node1
=
ir
::
CreateNodeForTest
(
v1
);
std
::
unique_ptr
<
ir
::
Node
>
node1
=
ir
::
CreateNodeForTest
(
v1
);
node1
->
inputs
.
emplace_back
(
op
.
get
());
node1
->
inputs
.
emplace_back
(
op
.
get
());
auto
*
cache
=
pool
.
NodeMatch
(
node1
.
get
());
auto
*
cache
=
pool
.
FindBestFitNode
(
node1
.
get
());
ASSERT_EQ
(
cache
,
nullptr
);
ASSERT_EQ
(
cache
,
nullptr
);
}
}
{
{
...
@@ -81,16 +95,401 @@ TEST(OrderedNodeList, Normal) {
...
@@ -81,16 +95,401 @@ TEST(OrderedNodeList, Normal) {
v2
->
SetShape
({
-
1
,
2
,
5
});
v2
->
SetShape
({
-
1
,
2
,
5
});
std
::
unique_ptr
<
ir
::
Node
>
node1
=
ir
::
CreateNodeForTest
(
v2
);
std
::
unique_ptr
<
ir
::
Node
>
node1
=
ir
::
CreateNodeForTest
(
v2
);
node1
->
inputs
.
emplace_back
(
op
.
get
());
node1
->
inputs
.
emplace_back
(
op
.
get
());
auto
*
cache
=
pool
.
NodeMatch
(
node1
.
get
());
auto
*
cache
=
pool
.
FindBestFitNode
(
node1
.
get
());
ASSERT_EQ
(
pool
.
Get
Index
(
cache
),
2
);
// match 6:[-1,2,5]
ASSERT_EQ
(
pool
.
Get
NodeIndexInPool
(
cache
),
2
);
// match 6:[-1,2,5]
}
}
{
{
auto
v3
=
block_desc
->
Var
(
"13"
);
auto
v3
=
block_desc
->
Var
(
"13"
);
v3
->
SetShape
({
2
,
5
});
v3
->
SetShape
({
2
,
5
});
std
::
unique_ptr
<
ir
::
Node
>
node1
=
ir
::
CreateNodeForTest
(
v3
);
std
::
unique_ptr
<
ir
::
Node
>
node1
=
ir
::
CreateNodeForTest
(
v3
);
node1
->
inputs
.
emplace_back
(
op
.
get
());
node1
->
inputs
.
emplace_back
(
op
.
get
());
auto
*
cache
=
pool
.
NodeMatch
(
node1
.
get
());
auto
*
cache
=
pool
.
FindBestFitNode
(
node1
.
get
());
ASSERT_EQ
(
pool
.
GetIndex
(
cache
),
5
);
// match 4:[5,2]
ASSERT_EQ
(
pool
.
GetNodeIndexInPool
(
cache
),
5
);
// match 4:[5,2]
}
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
REGISTER_OPERATOR
(
sum
,
paddle
::
framework
::
DummyOp
,
paddle
::
framework
::
SumOpMaker
,
paddle
::
framework
::
DummyVarTypeInference
);
REGISTER_OPERATOR
(
assign
,
paddle
::
framework
::
DummyOp
,
paddle
::
framework
::
AssignOpMaker
,
paddle
::
framework
::
DummyVarTypeInference
);
REGISTER_OPERATOR
(
dummy
,
paddle
::
framework
::
DummyOp
,
paddle
::
framework
::
SumOpMaker
,
paddle
::
framework
::
DummyVarTypeInference
);
/*
https://en.wikipedia.org/wiki/Live_variable_analysis
Create a customed classical dependency graph, left row is the instruction
number.
1. a = 1
2. b = a
3. c = a
4. d = b + c
5. e = d
a--------+
| |
b c
| |
d--------+
|
e
Then analysis these variable's liveness range
*/
namespace
paddle
{
namespace
framework
{
namespace
details
{
inline
static
ProgramDesc
FillProgramDesc
()
{
ProgramDesc
prog
;
prog
.
MutableBlock
(
0
)
->
Var
(
"a"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"b"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"c"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"d"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"e"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
{
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"assign"
);
op
->
SetInput
(
"X"
,
{
"a"
});
op
->
SetOutput
(
"Out"
,
{
"b"
});
}
{
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"assign"
);
op
->
SetInput
(
"X"
,
{
"a"
});
op
->
SetOutput
(
"Out"
,
{
"c"
});
}
{
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"b"
,
"c"
});
op
->
SetOutput
(
"Out"
,
{
"d"
});
}
{
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"assign"
);
op
->
SetInput
(
"X"
,
{
"d"
});
op
->
SetOutput
(
"Out"
,
{
"e"
});
}
return
prog
;
}
TEST
(
CFGGraph
,
IRGraph
)
{
// prepare ir graph
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
ControlFlowGraph
cfg
(
graph
);
cfg
.
LiveVariableAnalysis
();
// test assign op
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"a"
}
==
cfg
.
LiveIn
(
cfg
.
Ops
()[
0
])));
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"a"
,
"b"
}
==
cfg
.
LiveOut
(
cfg
.
Ops
()[
0
])));
// test assign op
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"a"
,
"b"
}
==
cfg
.
LiveIn
(
cfg
.
Ops
()[
1
])));
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"b"
,
"c"
}
==
cfg
.
LiveOut
(
cfg
.
Ops
()[
1
])));
// test sum op
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"b"
,
"c"
}
==
cfg
.
LiveIn
(
cfg
.
Ops
()[
2
])));
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"d"
}
==
cfg
.
LiveOut
(
cfg
.
Ops
()[
2
])));
// test assign op
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"d"
}
==
cfg
.
LiveIn
(
cfg
.
Ops
()[
3
])));
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{}
==
cfg
.
LiveOut
(
cfg
.
Ops
()[
3
])));
}
// 1. normal test
TEST
(
SortOpLikeDescOrder
,
NormalTest
)
{
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
nodes
=
SortOpLikeDescOrder
(
graph
);
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
for
(
size_t
i
=
0
;
i
<
nodes
.
size
();
++
i
)
{
auto
node
=
nodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
// 2. remove some op_desc
TEST
(
SortOpLikeDescOrder
,
RemoveOpDesc
)
{
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
nodes
=
graph
.
Nodes
();
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
ir
::
Node
*
found_node
=
nullptr
;
for
(
auto
node
:
nodes
)
{
if
(
node
->
IsOp
()
&&
node
->
outputs
.
back
()
->
Name
()
==
"e"
)
{
found_node
=
node
;
break
;
}
}
PADDLE_ENFORCE
(
found_node
!=
nullptr
);
for
(
auto
it
=
op_descs
.
begin
();
it
!=
op_descs
.
end
();)
{
if
(
IsSameDesc
(
*
it
,
found_node
->
Op
()))
{
it
=
op_descs
.
erase
(
it
);
}
else
{
++
it
;
}
}
auto
find_node_in_graph
=
[
&
](
std
::
string
s
)
{
ir
::
Node
*
ret
=
nullptr
;
for
(
auto
n
:
graph
.
Nodes
())
{
if
(
n
->
Name
()
==
s
)
{
ret
=
n
;
break
;
}
}
PADDLE_ENFORCE
(
ret
!=
nullptr
);
return
ret
;
};
ir
::
Node
*
e
=
find_node_in_graph
(
"e"
);
ir
::
Node
*
d
=
find_node_in_graph
(
"d"
);
std
::
remove
(
d
->
outputs
.
begin
(),
d
->
outputs
.
end
(),
found_node
);
graph
.
RemoveNode
(
found_node
);
graph
.
RemoveNode
(
e
);
// other node keeps the same order
auto
remain_nodes
=
SortOpLikeDescOrder
(
graph
);
for
(
size_t
i
=
0
;
i
<
remain_nodes
.
size
();
++
i
)
{
auto
node
=
remain_nodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
// 3. add some op_desc
TEST
(
SortOpLikeDescOrder
,
AddOpDesc
)
{
auto
prog
=
FillProgramDesc
();
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
ir
::
Graph
graph
(
prog
);
auto
find_node_in_graph
=
[
&
](
std
::
string
s
)
{
ir
::
Node
*
ret
=
nullptr
;
for
(
auto
n
:
graph
.
Nodes
())
{
if
(
n
->
Name
()
==
s
)
{
ret
=
n
;
break
;
}
}
PADDLE_ENFORCE
(
ret
!=
nullptr
);
return
ret
;
};
// cached desc different with real one
// mimic the intermidiete pass modify the programdesc.
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
auto
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"b"
,
"c"
});
op
->
SetOutput
(
"Out"
,
{
"d1"
});
ir
::
Node
*
node
=
graph
.
CreateOpNode
(
op
);
ir
::
Node
*
d1
=
graph
.
CreateVarNode
(
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
));
ir
::
Node
*
b
=
find_node_in_graph
(
"b"
);
ir
::
Node
*
c
=
find_node_in_graph
(
"c"
);
node
->
outputs
.
emplace_back
(
d1
);
node
->
inputs
.
emplace_back
(
b
);
node
->
inputs
.
emplace_back
(
c
);
d1
->
inputs
.
emplace_back
(
node
);
b
->
outputs
.
emplace_back
(
node
);
c
->
outputs
.
emplace_back
(
node
);
op_descs
.
insert
(
op_descs
.
begin
()
+
4
,
op
);
auto
nodes
=
SortOpLikeDescOrder
(
graph
);
for
(
size_t
i
=
0
;
i
<
nodes
.
size
();
++
i
)
{
auto
node
=
nodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
// 4. add and delete some op_desc
TEST
(
SortOpLikeDescOrder
,
AddAndDeleteOpDesc
)
{
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
find_node_in_graph
=
[
&
](
std
::
string
s
)
{
ir
::
Node
*
ret
=
nullptr
;
for
(
auto
n
:
graph
.
Nodes
())
{
if
(
n
->
Name
()
==
s
)
{
ret
=
n
;
break
;
}
}
PADDLE_ENFORCE
(
ret
!=
nullptr
);
return
ret
;
};
// remove sum node
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
ir
::
Node
*
found_node
=
nullptr
;
auto
nodes
=
graph
.
Nodes
();
for
(
auto
node
:
nodes
)
{
if
(
node
->
Name
()
==
"sum"
)
{
found_node
=
node
;
break
;
}
}
PADDLE_ENFORCE
(
found_node
!=
nullptr
);
for
(
auto
it
=
op_descs
.
begin
();
it
!=
op_descs
.
end
();)
{
if
(
IsSameDesc
(
*
it
,
found_node
->
Op
()))
{
it
=
op_descs
.
erase
(
it
);
}
else
{
++
it
;
}
}
{
ir
::
Node
*
d
=
find_node_in_graph
(
"d"
);
ir
::
Node
*
c
=
find_node_in_graph
(
"c"
);
ir
::
Node
*
e
=
find_node_in_graph
(
"e"
);
std
::
remove
(
d
->
outputs
.
begin
(),
d
->
outputs
.
end
(),
found_node
);
std
::
remove
(
c
->
outputs
.
begin
(),
c
->
outputs
.
end
(),
found_node
);
ir
::
Node
*
pending_op
=
found_node
->
outputs
[
0
]
->
outputs
[
0
];
graph
.
RemoveNode
(
e
);
graph
.
RemoveNode
(
pending_op
);
graph
.
RemoveNode
(
found_node
);
}
// add node
auto
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"b"
,
"c"
});
op
->
SetOutput
(
"Out"
,
{
"d1"
});
{
ir
::
Node
*
node
=
graph
.
CreateOpNode
(
op
);
ir
::
Node
*
d1
=
graph
.
CreateVarNode
(
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
));
ir
::
Node
*
b
=
find_node_in_graph
(
"b"
);
ir
::
Node
*
c
=
find_node_in_graph
(
"c"
);
node
->
outputs
.
emplace_back
(
d1
);
node
->
inputs
.
emplace_back
(
b
);
node
->
inputs
.
emplace_back
(
c
);
b
->
outputs
.
emplace_back
(
node
);
c
->
outputs
.
emplace_back
(
node
);
}
op_descs
.
insert
(
op_descs
.
begin
()
+
2
,
op
);
// check the order
auto
mynodes
=
SortOpLikeDescOrder
(
graph
);
for
(
size_t
i
=
0
;
i
<
mynodes
.
size
();
++
i
)
{
auto
node
=
mynodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
// 5. add and replace some op_desc inplace.
TEST
(
SortOpLikeDescOrder
,
AddAndReplaceOpDescInplace
)
{
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
find_node_in_graph
=
[
&
](
std
::
string
s
)
{
ir
::
Node
*
ret
=
nullptr
;
for
(
auto
n
:
graph
.
Nodes
())
{
if
(
n
->
Name
()
==
s
)
{
ret
=
n
;
break
;
}
}
PADDLE_ENFORCE
(
ret
!=
nullptr
);
return
ret
;
};
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
// add node
auto
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"b"
,
"c"
});
op
->
SetOutput
(
"Out"
,
{
"d1"
});
{
ir
::
Node
*
node
=
graph
.
CreateOpNode
(
op
);
ir
::
Node
*
d1
=
graph
.
CreateVarNode
(
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
));
ir
::
Node
*
b
=
find_node_in_graph
(
"b"
);
ir
::
Node
*
c
=
find_node_in_graph
(
"c"
);
node
->
outputs
.
emplace_back
(
d1
);
node
->
inputs
.
emplace_back
(
b
);
node
->
inputs
.
emplace_back
(
c
);
d1
->
inputs
.
emplace_back
(
node
);
b
->
outputs
.
emplace_back
(
node
);
c
->
outputs
.
emplace_back
(
node
);
}
op_descs
.
emplace_back
(
op
);
// replace op_desc inplace
auto
nodes
=
graph
.
Nodes
();
ir
::
Node
*
found_node
=
nullptr
;
for
(
auto
node
:
nodes
)
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
&&
node
->
Name
()
==
"assign"
)
{
if
(
node
->
outputs
.
size
()
==
1
&&
node
->
outputs
[
0
]
->
Name
()
==
"e"
)
{
found_node
=
node
;
break
;
}
}
}
{
ir
::
Node
*
d
=
find_node_in_graph
(
"d"
);
ir
::
Node
*
e
=
find_node_in_graph
(
"e"
);
std
::
remove
(
d
->
outputs
.
begin
(),
d
->
outputs
.
end
(),
found_node
);
std
::
remove
(
e
->
inputs
.
begin
(),
e
->
inputs
.
end
(),
found_node
);
graph
.
RemoveNode
(
found_node
);
}
op_descs
.
erase
(
op_descs
.
begin
()
+
3
);
auto
replace_op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
replace_op
->
SetType
(
"sum"
);
replace_op
->
SetInput
(
"X"
,
{
"d"
,
"d1"
});
replace_op
->
SetOutput
(
"Out"
,
{
"e"
});
{
ir
::
Node
*
sum2
=
graph
.
CreateOpNode
(
replace_op
);
ir
::
Node
*
e
=
find_node_in_graph
(
"e"
);
ir
::
Node
*
d
=
find_node_in_graph
(
"d"
);
ir
::
Node
*
d1
=
find_node_in_graph
(
"d1"
);
sum2
->
inputs
.
emplace_back
(
d
);
sum2
->
inputs
.
emplace_back
(
d1
);
sum2
->
outputs
.
emplace_back
(
e
);
e
->
inputs
.
emplace_back
(
sum2
);
d
->
outputs
.
emplace_back
(
sum2
);
d1
->
outputs
.
emplace_back
(
sum2
);
}
op_descs
.
emplace_back
(
replace_op
);
// compare op order
auto
graph_nodes
=
SortOpLikeDescOrder
(
graph
);
for
(
size_t
i
=
0
;
i
<
graph_nodes
.
size
();
++
i
)
{
auto
node
=
graph_nodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
}
}
...
...
paddle/fluid/framework/details/memory_optimize_pass.cc
浏览文件 @
381f2015
...
@@ -43,11 +43,6 @@ namespace paddle {
...
@@ -43,11 +43,6 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
static
inline
bool
IsSameDesc
(
OpDesc
*
op1
,
OpDesc
*
op2
)
{
return
op1
->
Type
()
==
op2
->
Type
()
&&
op1
->
Inputs
()
==
op2
->
Inputs
()
&&
op1
->
Outputs
()
==
op2
->
Outputs
();
}
std
::
unique_ptr
<
ir
::
Graph
>
MemoryOptimizePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
MemoryOptimizePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
auto
nodes
=
graph
->
Nodes
();
auto
nodes
=
graph
->
Nodes
();
...
@@ -77,7 +72,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
...
@@ -77,7 +72,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
if
(
!
NodeCanReused
(
var
)
||
cfg_
->
Use
(
op
).
count
(
var
->
Name
())
==
0
||
if
(
!
NodeCanReused
(
var
)
||
cfg_
->
Use
(
op
).
count
(
var
->
Name
())
==
0
||
skip_set_
.
count
(
var
->
Name
()))
skip_set_
.
count
(
var
->
Name
()))
continue
;
continue
;
ir
::
Node
*
cache
=
pool_
.
NodeMatch
(
var
);
ir
::
Node
*
cache
=
pool_
.
FindBestFitNode
(
var
);
if
(
var
->
Name
()
==
FLAGS_memory_optimize_debug
)
{
if
(
var
->
Name
()
==
FLAGS_memory_optimize_debug
)
{
VLOG
(
3
)
<<
"start match var "
<<
DebugString
(
var
)
<<
" of op "
VLOG
(
3
)
<<
"start match var "
<<
DebugString
(
var
)
<<
" of op "
...
@@ -95,11 +90,12 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
...
@@ -95,11 +90,12 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
<<
"replace it again. Skip this candidate."
;
<<
"replace it again. Skip this candidate."
;
continue
;
continue
;
int
node_idx_in_pool
=
pool_
.
Get
Index
(
cache
);
int
node_idx_in_pool
=
pool_
.
Get
NodeIndexInPool
(
cache
);
VLOG
(
3
)
<<
string
::
Sprintf
(
VLOG
(
3
)
<<
string
::
Sprintf
(
"!!! %s, %s => %s, cache idx %d, pool size %d"
,
"!!! %s, %s => %s, cache idx %d, pool size %d"
,
std
::
to_string
(
reuse_id
++
),
DebugString
(
var
),
DebugString
(
cache
),
std
::
to_string
(
reuse_id
++
),
DebugString
(
var
),
DebugString
(
cache
),
node_idx_in_pool
,
static_cast
<
int
>
(
pool_
.
size
()));
node_idx_in_pool
,
static_cast
<
int
>
(
pool_
.
size
()));
// update CFG Graph on the fly.
// update CFG Graph on the fly.
// reused var maybe re-fill into the pool
// reused var maybe re-fill into the pool
cfg_
->
RenameVarInCFGGraph
(
var
->
Name
(),
cache
->
Name
(),
idx
);
cfg_
->
RenameVarInCFGGraph
(
var
->
Name
(),
cache
->
Name
(),
idx
);
...
@@ -112,6 +108,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
...
@@ -112,6 +108,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
pool_
.
Erase
(
cache
);
pool_
.
Erase
(
cache
);
}
}
// fill the pool
// fill the pool
std
::
unordered_set
<
std
::
string
>
unlived_vars
;
std
::
unordered_set
<
std
::
string
>
unlived_vars
;
for
(
auto
var
:
cfg_
->
LiveIn
(
op
))
{
for
(
auto
var
:
cfg_
->
LiveIn
(
op
))
{
...
@@ -120,36 +117,15 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
...
@@ -120,36 +117,15 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
}
}
}
}
for
(
auto
var
:
unlived_vars
)
{
for
(
auto
var
:
unlived_vars
)
{
ir
::
Node
*
var_node
=
cfg_
->
GetNode
FromVar
Name
(
var
,
op
);
ir
::
Node
*
var_node
=
cfg_
->
GetNode
By
Name
(
var
,
op
);
if
(
NodeCanReused
(
var_node
)
&&
!
pool_
.
Has
(
var_node
))
{
if
(
NodeCanReused
(
var_node
)
&&
!
pool_
.
Has
(
var_node
))
{
pool_
.
Insert
(
var_node
,
op
);
pool_
.
Insert
(
var_node
);
}
}
}
}
}
}
}
}
graph
->
ResolveHazard
(
var_nodes_
);
graph
->
ResolveHazard
(
var_nodes_
);
// For early delete pass. use GraphNodePool load the unlived vars.
// 1. find all deps op for each unlived var in memory pool.
for
(
auto
&
op
:
graph
->
Nodes
())
{
for
(
auto
&
var
:
op
->
inputs
)
{
if
(
pool_
.
Has
(
var
))
{
pool_
.
Insert
(
var
,
op
);
}
}
}
// 2. convert ir node based memory pool to graph node
// because Node* maybe released bettwen passes.
auto
&
graph_pool
=
graph
->
Get
<
GraphNodePool
>
(
kGraphNodePool
);
for
(
auto
it
=
pool_
.
begin
();
it
!=
pool_
.
end
();
++
it
)
{
std
::
unordered_set
<
OpDesc
*>
descs
;
for
(
auto
&
op
:
it
->
second
)
{
PADDLE_ENFORCE
(
op
->
IsOp
());
descs
.
insert
(
op
->
Op
());
}
graph_pool
.
push_back
(
std
::
make_pair
(
it
->
first
->
Name
(),
descs
));
}
return
graph
;
return
graph
;
}
}
...
@@ -198,12 +174,12 @@ void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
...
@@ -198,12 +174,12 @@ void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
PADDLE_ENFORCE
(
sub_op
!=
nullptr
);
PADDLE_ENFORCE
(
sub_op
!=
nullptr
);
for
(
auto
*
var
:
sub_op
->
outputs
)
{
for
(
auto
*
var
:
sub_op
->
outputs
)
{
if
(
NodeCanReused
(
var
))
{
if
(
NodeCanReused
(
var
))
{
ir
::
Node
*
cache
=
pool_
.
NodeMatch
(
var
);
ir
::
Node
*
cache
=
pool_
.
FindBestFitNode
(
var
);
if
(
cache
!=
nullptr
)
{
if
(
cache
!=
nullptr
)
{
if
(
var
->
Var
()
->
GetDataType
()
!=
cache
->
Var
()
->
GetDataType
())
{
if
(
var
->
Var
()
->
GetDataType
()
!=
cache
->
Var
()
->
GetDataType
())
{
continue
;
continue
;
}
}
int
node_idx_in_pool
=
pool_
.
Get
Index
(
cache
);
int
node_idx_in_pool
=
pool_
.
Get
NodeIndexInPool
(
cache
);
VLOG
(
3
)
<<
string
::
Sprintf
(
VLOG
(
3
)
<<
string
::
Sprintf
(
"!!! %s, %s => %s, cache idx %d, pool size %d"
,
"!!! %s, %s => %s, cache idx %d, pool size %d"
,
std
::
to_string
(
sub_reuse_id
++
),
DebugString
(
var
),
std
::
to_string
(
sub_reuse_id
++
),
DebugString
(
var
),
...
@@ -342,267 +318,10 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
...
@@ -342,267 +318,10 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
var_nodes_
.
at
(
var
).
clear
();
var_nodes_
.
at
(
var
).
clear
();
}
}
std
::
vector
<
ir
::
Node
*>
SortOpLikeDescOrder
(
const
ir
::
Graph
&
graph
)
{
PADDLE_ENFORCE
(
graph
.
Has
(
kAllOpDescs
),
"Graph has no attribute of kAllOpDescs."
);
// 1. get op desc order
auto
&
op_descs
=
graph
.
Get
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
);
// 2. topology sort order
auto
nodes
=
graph
.
Nodes
();
std
::
deque
<
ir
::
Node
*>
ops
;
FilterVariables
(
nodes
,
[
&
](
ir
::
Node
*
op
)
{
if
(
op
->
IsOp
()
&&
op
->
Op
()
!=
nullptr
)
{
ops
.
emplace_back
(
op
);
}
});
std
::
unordered_map
<
ir
::
Node
*
,
size_t
>
op_deps
;
std
::
list
<
ir
::
Node
*>
ready_ops
;
std
::
unordered_map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
pending_ops
;
for
(
auto
*
op
:
ops
)
{
std
::
unordered_set
<
ir
::
Node
*>
preceding_op
;
for
(
auto
*
in
:
op
->
inputs
)
{
if
(
in
->
inputs
.
empty
())
continue
;
PADDLE_ENFORCE
(
in
->
inputs
.
size
()
==
1
&&
in
->
inputs
[
0
]
->
IsOp
());
preceding_op
.
emplace
(
in
->
inputs
[
0
]);
pending_ops
[
in
->
inputs
[
0
]].
emplace
(
op
);
}
op_deps
[
op
]
=
preceding_op
.
size
();
if
(
preceding_op
.
empty
())
{
ready_ops
.
emplace_back
(
op
);
}
}
// 3. generated op list based desc order and the topology order
std
::
vector
<
ir
::
Node
*>
ret
;
std
::
list
<
OpDesc
*>
op_descs_list
(
op_descs
.
begin
(),
op_descs
.
end
());
auto
update_by_found_node
=
[
&
](
ir
::
Node
*
found_node
)
{
for
(
auto
*
pending_op
:
pending_ops
[
found_node
])
{
if
(
--
op_deps
[
pending_op
]
==
0
)
{
ready_ops
.
emplace_back
(
pending_op
);
}
}
ready_ops
.
remove
(
found_node
);
ret
.
emplace_back
(
found_node
);
};
while
(
!
ready_ops
.
empty
())
{
bool
all_of_ready_op_unmatched
=
true
;
for
(
auto
it
=
op_descs_list
.
begin
();
it
!=
op_descs_list
.
end
();)
{
auto
op_desc
=
*
it
;
ir
::
Node
*
found_node
=
nullptr
;
for
(
auto
*
op
:
ready_ops
)
{
if
(
IsSameDesc
(
op
->
Op
(),
op_desc
))
{
found_node
=
op
;
break
;
}
}
// 3.1 op desc deleted by other pass
if
(
found_node
==
nullptr
)
{
++
it
;
continue
;
}
else
{
all_of_ready_op_unmatched
=
false
;
it
=
op_descs_list
.
erase
(
it
);
}
update_by_found_node
(
found_node
);
}
// 3.2 op descs are added by other pass
// preceding op non empty means some new op descs are
// created, but not contained in return node list.
// these new op desc may depend on each other.
std
::
list
<
ir
::
Node
*>
prev_ready_ops
(
ready_ops
);
if
(
all_of_ready_op_unmatched
)
{
for
(
auto
op
:
prev_ready_ops
)
{
update_by_found_node
(
op
);
}
}
}
PADDLE_ENFORCE
(
std
::
all_of
(
op_deps
.
begin
(),
op_deps
.
end
(),
[
&
](
const
std
::
pair
<
ir
::
Node
*
,
size_t
>&
p
)
{
return
p
.
second
==
0
;
}));
return
ret
;
}
ControlFlowGraph
::
ControlFlowGraph
(
const
ir
::
Graph
&
graph
)
{
ops_
=
SortOpLikeDescOrder
(
graph
);
ConnectNodes
();
}
void
ControlFlowGraph
::
BuildCFGGraph
()
{
// FIXME(dzh): same effect with ConnectNodes, but use the control
// link to build dependency graph, it goes wrong in transformer.
for
(
ir
::
Node
*
op
:
ops_
)
{
for
(
auto
&
input_var
:
op
->
inputs
)
{
if
(
!
input_var
->
inputs
.
empty
())
{
PADDLE_ENFORCE
(
input_var
->
inputs
.
size
()
==
1
&&
input_var
->
inputs
[
0
]
->
IsOp
(),
"Preceding Op Node of Var Node must be unique"
);
auto
*
pred_op
=
input_var
->
inputs
[
0
];
if
(
pred_op
->
Op
()
!=
nullptr
)
{
predecessors_
[
op
].
insert
(
pred_op
);
successors_
[
pred_op
].
insert
(
op
);
}
}
if
(
input_var
->
IsVar
()
&&
!
input_var
->
IsCtrlVar
())
{
uses_
[
op
].
insert
(
input_var
->
Name
());
}
}
for
(
auto
&
output_var
:
op
->
outputs
)
{
// output var may be used by many op
for
(
auto
*
succ_op
:
output_var
->
outputs
)
{
if
(
succ_op
->
Op
()
!=
nullptr
)
{
successors_
[
op
].
insert
(
succ_op
);
predecessors_
[
succ_op
].
insert
(
op
);
}
}
if
(
output_var
->
IsVar
()
&&
!
output_var
->
IsCtrlVar
())
{
defs_
[
op
].
insert
(
output_var
->
Name
());
}
}
}
}
void
ControlFlowGraph
::
ConnectNodes
()
{
for
(
size_t
i
=
0
;
i
<
ops_
.
size
();
++
i
)
{
auto
&
op
=
ops_
[
i
];
try
{
auto
&
next_op
=
ops_
.
at
(
i
+
1
);
successors_
[
op
].
insert
(
next_op
);
predecessors_
[
next_op
].
insert
(
op
);
}
catch
(...)
{
// do nothing
}
FilterVariables
(
op
->
inputs
,
[
&
](
ir
::
Node
*
var
)
{
uses_
[
op
].
emplace
(
var
->
Name
());
});
FilterVariables
(
op
->
outputs
,
[
&
](
ir
::
Node
*
var
)
{
defs_
[
op
].
emplace
(
var
->
Name
());
});
}
}
void
ControlFlowGraph
::
LiveVariableAnalysis
()
{
// NOTE(dzh): variable liveless analysis (a.k.a reversed_ops algorithm)
// compute the liveness of for each variable though reversed_ops algorithm.
// It iterates the operators from end to begin, compute the live in/live out
// variable set for each op, then the diff between in/out will be used for
// the variable reuse. For detail refer to
// http://www.cs.cornell.edu/courses/cs4120/2013fa/lectures/lec26-fa13.pdf
std
::
list
<
ir
::
Node
*>
work_list
(
ops_
.
rbegin
(),
ops_
.
rend
());
while
(
!
work_list
.
empty
())
{
ir
::
Node
*
op
=
work_list
.
front
();
work_list
.
pop_front
();
// get the live_in calculated before. Empty if first.
auto
prev_live_in
=
std
::
move
(
live_in_
[
op
]);
for
(
auto
&
s
:
successors_
[
op
])
{
for
(
auto
&
var
:
live_in_
[
s
])
{
live_out_
[
op
].
insert
(
var
);
}
}
for
(
auto
&
var
:
uses_
[
op
])
{
live_in_
[
op
].
insert
(
var
);
}
for
(
auto
&
var
:
live_out_
[
op
])
{
live_in_
[
op
].
insert
(
var
);
}
for
(
auto
&
var
:
defs_
[
op
])
{
live_in_
[
op
].
erase
(
var
);
}
// If the live_in is not changed, then the liveness analysis of
// predecessors is completed.
//
// Otherwise, recalculate the predecessors liveness
if
(
live_in_
[
op
]
!=
prev_live_in
)
{
for
(
auto
&
pre
:
predecessors_
[
op
])
{
work_list
.
push_back
(
pre
);
}
}
}
}
void
ControlFlowGraph
::
RenameVarInCFGGraph
(
const
std
::
string
&
old_node
,
const
std
::
string
&
new_node
,
int
begin_idx
)
{
// update graph from begin idx to the end
for
(
size_t
i
=
begin_idx
;
i
!=
ops_
.
size
();
++
i
)
{
auto
*
op
=
ops_
[
i
];
if
(
uses_
[
op
].
find
(
old_node
)
!=
uses_
[
op
].
end
())
{
uses_
[
op
].
erase
(
old_node
);
uses_
[
op
].
insert
(
new_node
);
}
if
(
defs_
[
op
].
find
(
old_node
)
!=
defs_
[
op
].
end
())
{
defs_
[
op
].
erase
(
old_node
);
defs_
[
op
].
insert
(
new_node
);
}
if
(
live_in_
[
op
].
find
(
old_node
)
!=
live_in_
[
op
].
end
())
{
live_in_
[
op
].
erase
(
old_node
);
live_in_
[
op
].
insert
(
new_node
);
}
if
(
live_out_
[
op
].
find
(
old_node
)
!=
live_out_
[
op
].
end
())
{
live_out_
[
op
].
erase
(
old_node
);
live_out_
[
op
].
insert
(
new_node
);
}
}
}
const
std
::
set
<
std
::
string
>
ControlFlowGraph
::
LiveIn
(
ir
::
Node
*
op
)
const
{
auto
it
=
live_in_
.
find
(
op
);
PADDLE_ENFORCE
(
it
!=
live_in_
.
end
(),
string
::
Sprintf
(
"Expect %s in live_in, but Not Found."
,
op
->
Name
()));
return
it
->
second
;
}
const
std
::
set
<
std
::
string
>
ControlFlowGraph
::
LiveOut
(
ir
::
Node
*
op
)
const
{
auto
it
=
live_out_
.
find
(
op
);
PADDLE_ENFORCE
(
it
!=
live_out_
.
end
(),
string
::
Sprintf
(
"Expect %s in live_out, but Not Found."
,
op
->
Name
()));
return
it
->
second
;
}
const
std
::
set
<
std
::
string
>
ControlFlowGraph
::
Use
(
ir
::
Node
*
op
)
const
{
auto
it
=
uses_
.
find
(
op
);
PADDLE_ENFORCE
(
it
!=
uses_
.
end
(),
string
::
Sprintf
(
"Expect %s in live_out, but Not Found."
,
op
->
Name
()));
return
it
->
second
;
}
const
std
::
vector
<
ir
::
Node
*>
ControlFlowGraph
::
Ops
()
const
{
return
ops_
;
}
std
::
vector
<
ir
::
Node
*>&
ControlFlowGraph
::
Ops
()
{
return
ops_
;
}
ir
::
Node
*
ControlFlowGraph
::
GetNodeFromVarName
(
const
std
::
string
&
name
,
ir
::
Node
*
op
)
const
{
// in ssa-graph, different version nodes have same name,
// this function get the latest version var before target op
// It may return nullptr, such as data node.
ir
::
Node
*
found_node
=
nullptr
;
for
(
auto
*
node
:
ops_
)
{
if
(
node
==
op
)
break
;
for
(
auto
&
output
:
node
->
outputs
)
{
if
(
output
->
Name
()
==
name
)
{
found_node
=
output
;
}
}
}
return
found_node
;
}
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
memory_optimize_pass
,
REGISTER_PASS
(
memory_optimize_pass
,
paddle
::
framework
::
details
::
MemoryOptimizePass
)
paddle
::
framework
::
details
::
MemoryOptimizePass
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphNodePool
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kAllOpDescs
);
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kAllOpDescs
);
paddle/fluid/framework/details/memory_optimize_pass.h
浏览文件 @
381f2015
...
@@ -32,20 +32,15 @@
...
@@ -32,20 +32,15 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
constexpr
char
kAllOpDescs
[]
=
"all_op_descs"
;
std
::
vector
<
ir
::
Node
*>
SortOpLikeDescOrder
(
const
ir
::
Graph
&
graph
);
class
ControlFlowGraph
;
class
MemoryOptimizePass
:
public
ir
::
Pass
{
class
MemoryOptimizePass
:
public
ir
::
Pass
{
protected:
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
private:
// fill the variable map(var_nodes) by version.
// fill the variable map(var_nodes) by version.
void
InitSSAGraphNodes
()
const
;
void
InitSSAGraphNodes
()
const
;
private:
// update program descs
// update program descs
void
RenameVarInGraphDesc
(
const
std
::
string
&
var
,
void
RenameVarInGraphDesc
(
const
std
::
string
&
var
,
const
std
::
string
&
cache_var
,
size_t
idx
)
const
;
const
std
::
string
&
cache_var
,
size_t
idx
)
const
;
...
@@ -62,7 +57,7 @@ class MemoryOptimizePass : public ir::Pass {
...
@@ -62,7 +57,7 @@ class MemoryOptimizePass : public ir::Pass {
private:
private:
// Reuse Node Pool, Owned.
// Reuse Node Pool, Owned.
mutable
Ordered
NodeLis
t
pool_
;
mutable
Ordered
Se
t
pool_
;
// controlflow Graph
// controlflow Graph
mutable
std
::
unique_ptr
<
ControlFlowGraph
>
cfg_
;
mutable
std
::
unique_ptr
<
ControlFlowGraph
>
cfg_
;
// skip set
// skip set
...
@@ -71,45 +66,6 @@ class MemoryOptimizePass : public ir::Pass {
...
@@ -71,45 +66,6 @@ class MemoryOptimizePass : public ir::Pass {
mutable
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
var_nodes_
;
mutable
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
var_nodes_
;
};
};
class
ControlFlowGraph
{
public:
ControlFlowGraph
()
=
default
;
// For IR Graph in parallelexecutor
explicit
ControlFlowGraph
(
const
ir
::
Graph
&
graph
);
void
LiveVariableAnalysis
();
void
RenameVarInCFGGraph
(
const
std
::
string
&
old_node
,
const
std
::
string
&
new_node
,
int
begin_idx
);
const
std
::
set
<
std
::
string
>
LiveIn
(
ir
::
Node
*
op
)
const
;
const
std
::
set
<
std
::
string
>
LiveOut
(
ir
::
Node
*
op
)
const
;
const
std
::
set
<
std
::
string
>
Use
(
ir
::
Node
*
op
)
const
;
const
std
::
vector
<
ir
::
Node
*>
Ops
()
const
;
std
::
vector
<
ir
::
Node
*>&
Ops
();
// for ssa-graph nodes
ir
::
Node
*
GetNodeFromVarName
(
const
std
::
string
&
name
,
ir
::
Node
*
op
)
const
;
private:
void
BuildCFGGraph
();
void
ConnectNodes
();
using
NodeListMap
=
std
::
unordered_map
<
ir
::
Node
*
,
std
::
set
<
ir
::
Node
*>>
;
using
VarSetMap
=
std
::
map
<
ir
::
Node
*
,
std
::
set
<
std
::
string
>>
;
// successors ops use the output variables.
NodeListMap
successors_
;
// predecessors ops generated input variables.
NodeListMap
predecessors_
;
// variables lived before run current op.
VarSetMap
live_in_
;
// variables lived after run current op.
VarSetMap
live_out_
;
VarSetMap
uses_
;
// op inputs
VarSetMap
defs_
;
// op outputs
std
::
vector
<
ir
::
Node
*>
ops_
;
// op sequence by topology sort
};
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/details/memory_optimize_pass_test.cc
已删除
100644 → 0
浏览文件 @
6492ea9c
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include <algorithm>
#include <iostream>
#include <iterator>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/graph_test_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
REGISTER_OPERATOR
(
sum
,
paddle
::
framework
::
DummyOp
,
paddle
::
framework
::
SumOpMaker
,
paddle
::
framework
::
DummyVarTypeInference
);
REGISTER_OPERATOR
(
assign
,
paddle
::
framework
::
DummyOp
,
paddle
::
framework
::
AssignOpMaker
,
paddle
::
framework
::
DummyVarTypeInference
);
REGISTER_OPERATOR
(
dummy
,
paddle
::
framework
::
DummyOp
,
paddle
::
framework
::
SumOpMaker
,
paddle
::
framework
::
DummyVarTypeInference
);
/*
https://en.wikipedia.org/wiki/Live_variable_analysis
Create a customed classical dependency graph, left row is the instruction
number.
1. a = 1
2. b = a
3. c = a
4. d = b + c
5. e = d
a--------+
| |
b c
| |
d--------+
|
e
Then analysis these variable's liveness range
*/
namespace
paddle
{
namespace
framework
{
namespace
details
{
static
inline
bool
IsSameDesc
(
OpDesc
*
op1
,
OpDesc
*
op2
)
{
return
op1
->
Type
()
==
op2
->
Type
()
&&
op1
->
Inputs
()
==
op2
->
Inputs
()
&&
op1
->
Outputs
()
==
op2
->
Outputs
();
}
inline
static
ProgramDesc
FillProgramDesc
()
{
ProgramDesc
prog
;
prog
.
MutableBlock
(
0
)
->
Var
(
"a"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"b"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"c"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"d"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"e"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
{
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"assign"
);
op
->
SetInput
(
"X"
,
{
"a"
});
op
->
SetOutput
(
"Out"
,
{
"b"
});
}
{
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"assign"
);
op
->
SetInput
(
"X"
,
{
"a"
});
op
->
SetOutput
(
"Out"
,
{
"c"
});
}
{
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"b"
,
"c"
});
op
->
SetOutput
(
"Out"
,
{
"d"
});
}
{
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"assign"
);
op
->
SetInput
(
"X"
,
{
"d"
});
op
->
SetOutput
(
"Out"
,
{
"e"
});
}
return
prog
;
}
TEST
(
CFGGraph
,
IRGraph
)
{
// prepare ir graph
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
ControlFlowGraph
cfg
(
graph
);
cfg
.
LiveVariableAnalysis
();
// test assign op
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"a"
}
==
cfg
.
LiveIn
(
cfg
.
Ops
()[
0
])));
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"a"
,
"b"
}
==
cfg
.
LiveOut
(
cfg
.
Ops
()[
0
])));
// test assign op
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"a"
,
"b"
}
==
cfg
.
LiveIn
(
cfg
.
Ops
()[
1
])));
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"b"
,
"c"
}
==
cfg
.
LiveOut
(
cfg
.
Ops
()[
1
])));
// test sum op
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"b"
,
"c"
}
==
cfg
.
LiveIn
(
cfg
.
Ops
()[
2
])));
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"d"
}
==
cfg
.
LiveOut
(
cfg
.
Ops
()[
2
])));
// test assign op
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"d"
}
==
cfg
.
LiveIn
(
cfg
.
Ops
()[
3
])));
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{}
==
cfg
.
LiveOut
(
cfg
.
Ops
()[
3
])));
}
// 1. normal test
TEST
(
SortOpLikeDescOrder
,
NormalTest
)
{
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
nodes
=
SortOpLikeDescOrder
(
graph
);
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
for
(
size_t
i
=
0
;
i
<
nodes
.
size
();
++
i
)
{
auto
node
=
nodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
// 2. remove some op_desc
TEST
(
SortOpLikeDescOrder
,
RemoveOpDesc
)
{
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
nodes
=
graph
.
Nodes
();
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
ir
::
Node
*
found_node
=
nullptr
;
for
(
auto
node
:
nodes
)
{
if
(
node
->
IsOp
()
&&
node
->
outputs
.
back
()
->
Name
()
==
"e"
)
{
found_node
=
node
;
break
;
}
}
PADDLE_ENFORCE
(
found_node
!=
nullptr
);
for
(
auto
it
=
op_descs
.
begin
();
it
!=
op_descs
.
end
();)
{
if
(
IsSameDesc
(
*
it
,
found_node
->
Op
()))
{
it
=
op_descs
.
erase
(
it
);
}
else
{
++
it
;
}
}
auto
find_node_in_graph
=
[
&
](
std
::
string
s
)
{
ir
::
Node
*
ret
=
nullptr
;
for
(
auto
n
:
graph
.
Nodes
())
{
if
(
n
->
Name
()
==
s
)
{
ret
=
n
;
break
;
}
}
PADDLE_ENFORCE
(
ret
!=
nullptr
);
return
ret
;
};
ir
::
Node
*
e
=
find_node_in_graph
(
"e"
);
ir
::
Node
*
d
=
find_node_in_graph
(
"d"
);
std
::
remove
(
d
->
outputs
.
begin
(),
d
->
outputs
.
end
(),
found_node
);
graph
.
RemoveNode
(
found_node
);
graph
.
RemoveNode
(
e
);
// other node keeps the same order
auto
remain_nodes
=
SortOpLikeDescOrder
(
graph
);
for
(
size_t
i
=
0
;
i
<
remain_nodes
.
size
();
++
i
)
{
auto
node
=
remain_nodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
// 3. add some op_desc
TEST
(
SortOpLikeDescOrder
,
AddOpDesc
)
{
auto
prog
=
FillProgramDesc
();
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
ir
::
Graph
graph
(
prog
);
auto
find_node_in_graph
=
[
&
](
std
::
string
s
)
{
ir
::
Node
*
ret
=
nullptr
;
for
(
auto
n
:
graph
.
Nodes
())
{
if
(
n
->
Name
()
==
s
)
{
ret
=
n
;
break
;
}
}
PADDLE_ENFORCE
(
ret
!=
nullptr
);
return
ret
;
};
// cached desc different with real one
// mimic the intermidiete pass modify the programdesc.
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
auto
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"b"
,
"c"
});
op
->
SetOutput
(
"Out"
,
{
"d1"
});
ir
::
Node
*
node
=
graph
.
CreateOpNode
(
op
);
ir
::
Node
*
d1
=
graph
.
CreateVarNode
(
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
));
ir
::
Node
*
b
=
find_node_in_graph
(
"b"
);
ir
::
Node
*
c
=
find_node_in_graph
(
"c"
);
node
->
outputs
.
emplace_back
(
d1
);
node
->
inputs
.
emplace_back
(
b
);
node
->
inputs
.
emplace_back
(
c
);
d1
->
inputs
.
emplace_back
(
node
);
b
->
outputs
.
emplace_back
(
node
);
c
->
outputs
.
emplace_back
(
node
);
op_descs
.
insert
(
op_descs
.
begin
()
+
4
,
op
);
auto
nodes
=
SortOpLikeDescOrder
(
graph
);
for
(
size_t
i
=
0
;
i
<
nodes
.
size
();
++
i
)
{
auto
node
=
nodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
// 4. add and delete some op_desc
TEST
(
SortOpLikeDescOrder
,
AddAndDeleteOpDesc
)
{
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
find_node_in_graph
=
[
&
](
std
::
string
s
)
{
ir
::
Node
*
ret
=
nullptr
;
for
(
auto
n
:
graph
.
Nodes
())
{
if
(
n
->
Name
()
==
s
)
{
ret
=
n
;
break
;
}
}
PADDLE_ENFORCE
(
ret
!=
nullptr
);
return
ret
;
};
// remove sum node
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
ir
::
Node
*
found_node
=
nullptr
;
auto
nodes
=
graph
.
Nodes
();
for
(
auto
node
:
nodes
)
{
if
(
node
->
Name
()
==
"sum"
)
{
found_node
=
node
;
break
;
}
}
PADDLE_ENFORCE
(
found_node
!=
nullptr
);
for
(
auto
it
=
op_descs
.
begin
();
it
!=
op_descs
.
end
();)
{
if
(
IsSameDesc
(
*
it
,
found_node
->
Op
()))
{
it
=
op_descs
.
erase
(
it
);
}
else
{
++
it
;
}
}
{
ir
::
Node
*
d
=
find_node_in_graph
(
"d"
);
ir
::
Node
*
c
=
find_node_in_graph
(
"c"
);
ir
::
Node
*
e
=
find_node_in_graph
(
"e"
);
std
::
remove
(
d
->
outputs
.
begin
(),
d
->
outputs
.
end
(),
found_node
);
std
::
remove
(
c
->
outputs
.
begin
(),
c
->
outputs
.
end
(),
found_node
);
ir
::
Node
*
pending_op
=
found_node
->
outputs
[
0
]
->
outputs
[
0
];
graph
.
RemoveNode
(
e
);
graph
.
RemoveNode
(
pending_op
);
graph
.
RemoveNode
(
found_node
);
}
// add node
auto
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"b"
,
"c"
});
op
->
SetOutput
(
"Out"
,
{
"d1"
});
{
ir
::
Node
*
node
=
graph
.
CreateOpNode
(
op
);
ir
::
Node
*
d1
=
graph
.
CreateVarNode
(
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
));
ir
::
Node
*
b
=
find_node_in_graph
(
"b"
);
ir
::
Node
*
c
=
find_node_in_graph
(
"c"
);
node
->
outputs
.
emplace_back
(
d1
);
node
->
inputs
.
emplace_back
(
b
);
node
->
inputs
.
emplace_back
(
c
);
b
->
outputs
.
emplace_back
(
node
);
c
->
outputs
.
emplace_back
(
node
);
}
op_descs
.
insert
(
op_descs
.
begin
()
+
2
,
op
);
// check the order
auto
mynodes
=
SortOpLikeDescOrder
(
graph
);
for
(
size_t
i
=
0
;
i
<
mynodes
.
size
();
++
i
)
{
auto
node
=
mynodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
// 5. add and replace some op_desc inplace.
TEST
(
SortOpLikeDescOrder
,
AddAndReplaceOpDescInplace
)
{
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
find_node_in_graph
=
[
&
](
std
::
string
s
)
{
ir
::
Node
*
ret
=
nullptr
;
for
(
auto
n
:
graph
.
Nodes
())
{
if
(
n
->
Name
()
==
s
)
{
ret
=
n
;
break
;
}
}
PADDLE_ENFORCE
(
ret
!=
nullptr
);
return
ret
;
};
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
// add node
auto
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"b"
,
"c"
});
op
->
SetOutput
(
"Out"
,
{
"d1"
});
{
ir
::
Node
*
node
=
graph
.
CreateOpNode
(
op
);
ir
::
Node
*
d1
=
graph
.
CreateVarNode
(
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
));
ir
::
Node
*
b
=
find_node_in_graph
(
"b"
);
ir
::
Node
*
c
=
find_node_in_graph
(
"c"
);
node
->
outputs
.
emplace_back
(
d1
);
node
->
inputs
.
emplace_back
(
b
);
node
->
inputs
.
emplace_back
(
c
);
d1
->
inputs
.
emplace_back
(
node
);
b
->
outputs
.
emplace_back
(
node
);
c
->
outputs
.
emplace_back
(
node
);
}
op_descs
.
emplace_back
(
op
);
// replace op_desc inplace
auto
nodes
=
graph
.
Nodes
();
ir
::
Node
*
found_node
=
nullptr
;
for
(
auto
node
:
nodes
)
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
&&
node
->
Name
()
==
"assign"
)
{
if
(
node
->
outputs
.
size
()
==
1
&&
node
->
outputs
[
0
]
->
Name
()
==
"e"
)
{
found_node
=
node
;
break
;
}
}
}
{
ir
::
Node
*
d
=
find_node_in_graph
(
"d"
);
ir
::
Node
*
e
=
find_node_in_graph
(
"e"
);
std
::
remove
(
d
->
outputs
.
begin
(),
d
->
outputs
.
end
(),
found_node
);
std
::
remove
(
e
->
inputs
.
begin
(),
e
->
inputs
.
end
(),
found_node
);
graph
.
RemoveNode
(
found_node
);
}
op_descs
.
erase
(
op_descs
.
begin
()
+
3
);
auto
replace_op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
replace_op
->
SetType
(
"sum"
);
replace_op
->
SetInput
(
"X"
,
{
"d"
,
"d1"
});
replace_op
->
SetOutput
(
"Out"
,
{
"e"
});
{
ir
::
Node
*
sum2
=
graph
.
CreateOpNode
(
replace_op
);
ir
::
Node
*
e
=
find_node_in_graph
(
"e"
);
ir
::
Node
*
d
=
find_node_in_graph
(
"d"
);
ir
::
Node
*
d1
=
find_node_in_graph
(
"d1"
);
sum2
->
inputs
.
emplace_back
(
d
);
sum2
->
inputs
.
emplace_back
(
d1
);
sum2
->
outputs
.
emplace_back
(
e
);
e
->
inputs
.
emplace_back
(
sum2
);
d
->
outputs
.
emplace_back
(
sum2
);
d1
->
outputs
.
emplace_back
(
sum2
);
}
op_descs
.
emplace_back
(
replace_op
);
// compare op order
auto
graph_nodes
=
SortOpLikeDescOrder
(
graph
);
for
(
size_t
i
=
0
;
i
<
graph_nodes
.
size
();
++
i
)
{
auto
node
=
graph_nodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/sequential_execution_pass.cc
浏览文件 @
381f2015
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace
paddle
{
namespace
paddle
{
...
...
paddle/fluid/framework/details/sequential_execution_pass.h
浏览文件 @
381f2015
...
@@ -21,8 +21,6 @@ namespace paddle {
...
@@ -21,8 +21,6 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
constexpr
char
kAllOpDescs
[]
=
"all_op_descs"
;
class
SequentialExecutionPass
:
public
ir
::
Pass
{
class
SequentialExecutionPass
:
public
ir
::
Pass
{
protected:
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
...
...
paddle/fluid/framework/inplace_op_inference.h
浏览文件 @
381f2015
...
@@ -69,7 +69,7 @@ class InplaceInToOut : public InplaceOpInference {
...
@@ -69,7 +69,7 @@ class InplaceInToOut : public InplaceOpInference {
bool
TryInplaceInputOutput
(
const
VarDesc
&
in
,
const
VarDesc
&
out
)
const
{
bool
TryInplaceInputOutput
(
const
VarDesc
&
in
,
const
VarDesc
&
out
)
const
{
return
in
.
Name
()
!=
out
.
Name
()
&&
details
::
NodeCanReused
(
in
)
&&
return
in
.
Name
()
!=
out
.
Name
()
&&
details
::
NodeCanReused
(
in
)
&&
details
::
NodeCanReused
(
out
)
&&
details
::
NodeCanReused
(
out
)
&&
details
::
NodeSize
InBytes
(
out
)
<=
details
::
NodeSizeInBytes
(
in
);
details
::
NodeSize
(
out
)
<=
details
::
NodeSize
(
in
);
}
}
};
};
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
381f2015
...
@@ -171,14 +171,6 @@ std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
...
@@ -171,14 +171,6 @@ std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
eager_deletion_pass
->
SetNotOwned
(
details
::
kAllPlaces
,
&
places_
);
eager_deletion_pass
->
SetNotOwned
(
details
::
kAllPlaces
,
&
places_
);
graph
=
eager_deletion_pass
->
Apply
(
std
::
move
(
graph
));
graph
=
eager_deletion_pass
->
Apply
(
std
::
move
(
graph
));
VLOG
(
10
)
<<
"EagerDeletionPass Applied"
;
VLOG
(
10
)
<<
"EagerDeletionPass Applied"
;
if
(
build_strategy_
.
memory_early_delete_
)
{
auto
early_delete_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"memory_early_delete_pass"
);
early_delete_pass
->
SetNotOwned
(
details
::
kGarbageCollector
,
&
gcs_
);
graph
=
early_delete_pass
->
Apply
(
std
::
move
(
graph
));
}
VLOG
(
10
)
<<
"MemoryEarlyDeletePass Applied."
;
}
}
return
graph
;
return
graph
;
...
@@ -288,6 +280,8 @@ ParallelExecutor::ParallelExecutor(
...
@@ -288,6 +280,8 @@ ParallelExecutor::ParallelExecutor(
graphs
.
push_back
(
std
::
move
(
graph
));
graphs
.
push_back
(
std
::
move
(
graph
));
#endif
#endif
auto
max_memory_size
=
GetEagerDeletionThreshold
();
auto
max_memory_size
=
GetEagerDeletionThreshold
();
VLOG
(
10
)
<<
"Eager Deletion Threshold "
<<
static_cast
<
float
>
(
max_memory_size
)
/
(
1
<<
30
);
if
(
max_memory_size
>=
0
)
{
if
(
max_memory_size
>=
0
)
{
for
(
size_t
i
=
0
;
i
<
graphs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
graphs
.
size
();
++
i
)
{
graphs
[
i
]
=
member_
->
PrepareGCAndRefCnts
(
graphs
[
i
]
=
member_
->
PrepareGCAndRefCnts
(
...
@@ -506,6 +500,5 @@ ParallelExecutor::~ParallelExecutor() {
...
@@ -506,6 +500,5 @@ ParallelExecutor::~ParallelExecutor() {
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
USE_PASS
(
memory_early_delete_pass
);
USE_PASS
(
reference_count_pass
);
USE_PASS
(
reference_count_pass
);
USE_PASS
(
eager_deletion_pass
);
USE_PASS
(
eager_deletion_pass
);
paddle/fluid/framework/scope.cc
浏览文件 @
381f2015
...
@@ -22,11 +22,7 @@ limitations under the License. */
...
@@ -22,11 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/printf.h"
DEFINE_bool
(
benchmark
,
false
,
DECLARE_bool
(
benchmark
);
"Doing memory benchmark. It will make deleting scope synchronized, "
"and add some memory usage logs."
"Default cuda is asynchronous device, set to True will"
"force op run in synchronous mode."
);
DEFINE_bool
(
DEFINE_bool
(
eager_delete_scope
,
true
,
eager_delete_scope
,
true
,
...
...
paddle/fluid/memory/allocation/legacy_allocator.cc
浏览文件 @
381f2015
...
@@ -36,6 +36,7 @@ DEFINE_bool(init_allocated_mem, false,
...
@@ -36,6 +36,7 @@ DEFINE_bool(init_allocated_mem, false,
"that initializing the allocated memory with a small value "
"that initializing the allocated memory with a small value "
"during unit testing."
);
"during unit testing."
);
DECLARE_double
(
fraction_of_gpu_memory_to_use
);
DECLARE_double
(
fraction_of_gpu_memory_to_use
);
DECLARE_bool
(
benchmark
);
namespace
paddle
{
namespace
paddle
{
namespace
memory
{
namespace
memory
{
...
@@ -198,7 +199,7 @@ void *Alloc<platform::CUDAPlace>(const platform::CUDAPlace &place,
...
@@ -198,7 +199,7 @@ void *Alloc<platform::CUDAPlace>(const platform::CUDAPlace &place,
<<
string
::
HumanReadableSize
(
Used
<
platform
::
CUDAPlace
>
(
place
));
<<
string
::
HumanReadableSize
(
Used
<
platform
::
CUDAPlace
>
(
place
));
platform
::
SetDeviceId
(
cur_dev
);
platform
::
SetDeviceId
(
cur_dev
);
}
else
{
}
else
{
if
(
VLOG_IS_ON
(
3
)
)
{
if
(
FLAGS_benchmark
)
{
allocation
::
GPUMemMonitor
.
Add
(
place
.
device
,
size
);
allocation
::
GPUMemMonitor
.
Add
(
place
.
device
,
size
);
}
}
if
(
FLAGS_init_allocated_mem
)
{
if
(
FLAGS_init_allocated_mem
)
{
...
@@ -216,7 +217,7 @@ void Free<platform::CUDAPlace>(const platform::CUDAPlace &place, void *p,
...
@@ -216,7 +217,7 @@ void Free<platform::CUDAPlace>(const platform::CUDAPlace &place, void *p,
size_t
size
)
{
size_t
size
)
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
GetGPUBuddyAllocator
(
place
.
device
)
->
Free
(
p
);
GetGPUBuddyAllocator
(
place
.
device
)
->
Free
(
p
);
if
(
VLOG_IS_ON
(
3
)
)
{
if
(
FLAGS_benchmark
)
{
allocation
::
GPUMemMonitor
.
Minus
(
place
.
device
,
size
);
allocation
::
GPUMemMonitor
.
Minus
(
place
.
device
,
size
);
}
}
#else
#else
...
...
paddle/fluid/platform/place.cc
浏览文件 @
381f2015
...
@@ -14,6 +14,12 @@ limitations under the License. */
...
@@ -14,6 +14,12 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
DEFINE_bool
(
benchmark
,
false
,
"Doing memory benchmark. It will make deleting scope synchronized, "
"and add some memory usage logs."
"Default cuda is asynchronous device, set to True will"
"force op run in synchronous mode."
);
namespace
paddle
{
namespace
paddle
{
namespace
platform
{
namespace
platform
{
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
381f2015
...
@@ -1099,10 +1099,6 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -1099,10 +1099,6 @@ All parameter, weight, gradient are variables in Paddle.
"is_distribution"
,
"is_distribution"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
is_distribution_
;
},
[](
const
BuildStrategy
&
self
)
{
return
self
.
is_distribution_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
is_distribution_
=
b
;
})
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
is_distribution_
=
b
;
})
.
def_property
(
"memory_early_delete"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
memory_early_delete_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
memory_early_delete_
=
b
;
})
.
def_property
(
.
def_property
(
"enable_inplace"
,
"enable_inplace"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
enable_inplace_
;
},
[](
const
BuildStrategy
&
self
)
{
return
self
.
enable_inplace_
;
},
...
...
python/paddle/fluid/parallel_executor.py
浏览文件 @
381f2015
...
@@ -148,6 +148,7 @@ class ParallelExecutor(object):
...
@@ -148,6 +148,7 @@ class ParallelExecutor(object):
else
framework
.
default_main_program
()
else
framework
.
default_main_program
()
# FIXME(dzhwinter): enable_inplace should be after memory_optimize
# FIXME(dzhwinter): enable_inplace should be after memory_optimize
# if turn on python memory optimize, turn off the inplace_pass.
# if turn on python memory optimize, turn off the inplace_pass.
if
build_strategy
.
enable_inplace
is
None
:
build_strategy
.
enable_inplace
=
False
if
main
.
_is_mem_optimized
else
True
build_strategy
.
enable_inplace
=
False
if
main
.
_is_mem_optimized
else
True
scope
=
scope
if
scope
is
not
None
else
executor
.
global_scope
()
scope
=
scope
if
scope
is
not
None
else
executor
.
global_scope
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录