Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
381f2015
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
381f2015
编写于
2月 12, 2019
作者:
D
dzhwinter
提交者:
GitHub
2月 12, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #15665 from dzhwinter/experiment/refactor_memory
refactor optimize pass.
上级
6492ea9c
04e9776a
变更
23
显示空白变更内容
内联
并排
Showing
23 changed file
with
842 addition
and
1027 deletion
+842
-1027
cmake/flags.cmake
cmake/flags.cmake
+2
-1
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+2
-6
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+0
-2
paddle/fluid/framework/details/build_strategy.h
paddle/fluid/framework/details/build_strategy.h
+0
-3
paddle/fluid/framework/details/inplace_op_pass.cc
paddle/fluid/framework/details/inplace_op_pass.cc
+6
-7
paddle/fluid/framework/details/inplace_op_pass.h
paddle/fluid/framework/details/inplace_op_pass.h
+8
-7
paddle/fluid/framework/details/memory_early_delete_pass.cc
paddle/fluid/framework/details/memory_early_delete_pass.cc
+0
-117
paddle/fluid/framework/details/memory_early_delete_pass.h
paddle/fluid/framework/details/memory_early_delete_pass.h
+0
-32
paddle/fluid/framework/details/memory_optimize_helper.cc
paddle/fluid/framework/details/memory_optimize_helper.cc
+306
-30
paddle/fluid/framework/details/memory_optimize_helper.h
paddle/fluid/framework/details/memory_optimize_helper.h
+83
-36
paddle/fluid/framework/details/memory_optimize_helper_test.cc
...le/fluid/framework/details/memory_optimize_helper_test.cc
+408
-9
paddle/fluid/framework/details/memory_optimize_pass.cc
paddle/fluid/framework/details/memory_optimize_pass.cc
+8
-289
paddle/fluid/framework/details/memory_optimize_pass.h
paddle/fluid/framework/details/memory_optimize_pass.h
+3
-47
paddle/fluid/framework/details/memory_optimize_pass_test.cc
paddle/fluid/framework/details/memory_optimize_pass_test.cc
+0
-417
paddle/fluid/framework/details/sequential_execution_pass.cc
paddle/fluid/framework/details/sequential_execution_pass.cc
+1
-0
paddle/fluid/framework/details/sequential_execution_pass.h
paddle/fluid/framework/details/sequential_execution_pass.h
+0
-2
paddle/fluid/framework/inplace_op_inference.h
paddle/fluid/framework/inplace_op_inference.h
+1
-1
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+2
-9
paddle/fluid/framework/scope.cc
paddle/fluid/framework/scope.cc
+1
-5
paddle/fluid/memory/allocation/legacy_allocator.cc
paddle/fluid/memory/allocation/legacy_allocator.cc
+3
-2
paddle/fluid/platform/place.cc
paddle/fluid/platform/place.cc
+6
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+0
-4
python/paddle/fluid/parallel_executor.py
python/paddle/fluid/parallel_executor.py
+2
-1
未找到文件。
cmake/flags.cmake
浏览文件 @
381f2015
...
@@ -27,6 +27,7 @@ endfunction()
...
@@ -27,6 +27,7 @@ endfunction()
CheckCompilerCXX11Flag
()
CheckCompilerCXX11Flag
()
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-std=c++11"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-std=c++11"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-m64"
)
# safe_set_flag
# safe_set_flag
#
#
# Set a compile flag only if compiler is support
# Set a compile flag only if compiler is support
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
381f2015
...
@@ -54,8 +54,6 @@ cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph grap
...
@@ -54,8 +54,6 @@ cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph grap
cc_library
(
memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass
)
cc_library
(
memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass
)
cc_library
(
inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info
)
cc_library
(
inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info
)
cc_library
(
modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper
)
cc_library
(
modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper
)
cc_library
(
memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass
)
cc_library
(
reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle
)
cc_library
(
reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle
)
cc_library
(
eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper
)
cc_library
(
eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper
)
cc_library
(
eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass
)
cc_library
(
eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass
)
...
@@ -67,13 +65,11 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
...
@@ -67,13 +65,11 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
cc_library
(
multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
cc_library
(
multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle
)
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle
)
set
(
SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass
memory_early_delete_pass
inplace_op_pass
)
set
(
SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass
)
if
(
WITH_GPU
)
if
(
WITH_GPU
)
list
(
APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass
)
list
(
APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass
)
endif
()
endif
()
cc_test
(
memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph
)
cc_test
(
memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry
)
cc_test
(
memory_optimize_pass_test SRCS memory_optimize_pass_test.cc memory_optimize_pass.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry pass
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS
${
SSA_GRAPH_EXECUTOR_DEPS
}
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS
${
SSA_GRAPH_EXECUTOR_DEPS
}
)
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
...
...
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
381f2015
...
@@ -206,8 +206,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
...
@@ -206,8 +206,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
new
std
::
vector
<
OpDesc
*>
(
main_program
.
Block
(
0
).
AllOps
());
new
std
::
vector
<
OpDesc
*>
(
main_program
.
Block
(
0
).
AllOps
());
graph
->
Set
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
,
graph
->
Set
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
,
all_op_descs
);
// take ownership
all_op_descs
);
// take ownership
graph
->
Set
<
GraphNodePool
>
(
kGraphNodePool
,
new
GraphNodePool
);
// take ownership
pass
->
Erase
(
kAllOpDescs
);
pass
->
Erase
(
kAllOpDescs
);
pass
->
SetNotOwned
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
,
all_op_descs
);
pass
->
SetNotOwned
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
,
all_op_descs
);
...
...
paddle/fluid/framework/details/build_strategy.h
浏览文件 @
381f2015
...
@@ -77,9 +77,6 @@ struct BuildStrategy {
...
@@ -77,9 +77,6 @@ struct BuildStrategy {
bool
fuse_relu_depthwise_conv_
{
false
};
bool
fuse_relu_depthwise_conv_
{
false
};
bool
memory_optimize_
{
false
};
bool
memory_optimize_
{
false
};
bool
memory_early_delete_
{
false
};
// TODO(dzhwinter):
// TODO(dzhwinter):
// make enable_inplace, memory_optimize_
// make enable_inplace, memory_optimize_
// memory_early_delete_ true by default
// memory_early_delete_ true by default
...
...
paddle/fluid/framework/details/inplace_op_pass.cc
浏览文件 @
381f2015
...
@@ -171,16 +171,15 @@ void InplacePass::InplaceModifyDesc(const std::string& var,
...
@@ -171,16 +171,15 @@ void InplacePass::InplaceModifyDesc(const std::string& var,
}
}
}
}
const
SSANodePair
InplacePass
::
TryInplaceModifyVar
(
const
std
::
string
&
var
,
const
NodeSwapQueue
InplacePass
::
TryInplaceModifyVar
(
const
std
::
string
&
cache_var
,
const
std
::
string
&
var
,
const
std
::
string
&
cache_var
,
const
size_t
&
idx
,
const
size_t
&
idx
,
ir
::
Graph
*
graph
)
const
{
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE
(
var_nodes_
[
var
].
size
()
>=
1
&&
PADDLE_ENFORCE
(
var_nodes_
[
var
].
size
()
>=
1
&&
var_nodes_
[
var
].
at
(
0
)
->
Var
()
!=
nullptr
);
var_nodes_
[
var
].
at
(
0
)
->
Var
()
!=
nullptr
);
std
::
unique_ptr
<
VarDesc
>
var_desc
(
new
VarDesc
(
*
var_nodes_
[
var
].
at
(
0
)
->
Var
()));
std
::
unique_ptr
<
VarDesc
>
var_desc
(
new
VarDesc
(
*
var_nodes_
[
var
].
at
(
0
)
->
Var
()));
var_desc
->
SetName
(
cache_var
);
var_desc
->
SetName
(
cache_var
);
SSANodePair
swap_nodes
;
NodeSwapQueue
swap_nodes
;
for
(
size_t
i
=
idx
;
i
<
view_
.
AllOps
().
size
();
++
i
)
{
for
(
size_t
i
=
idx
;
i
<
view_
.
AllOps
().
size
();
++
i
)
{
auto
*
op
=
view_
.
AllOps
()[
i
];
auto
*
op
=
view_
.
AllOps
()[
i
];
...
@@ -230,7 +229,7 @@ const SSANodePair InplacePass::TryInplaceModifyVar(const std::string& var,
...
@@ -230,7 +229,7 @@ const SSANodePair InplacePass::TryInplaceModifyVar(const std::string& var,
return
swap_nodes
;
return
swap_nodes
;
}
}
void
InplacePass
::
CommitModify
(
const
SSANodePair
&
swap_nodes
,
void
InplacePass
::
CommitModify
(
const
NodeSwapQueue
&
swap_nodes
,
ir
::
Graph
*
graph
)
const
{
ir
::
Graph
*
graph
)
const
{
for
(
auto
&
pair
:
swap_nodes
)
{
for
(
auto
&
pair
:
swap_nodes
)
{
auto
*
node
=
pair
.
first
,
*
cache_node
=
pair
.
second
;
auto
*
node
=
pair
.
first
,
*
cache_node
=
pair
.
second
;
...
@@ -245,7 +244,7 @@ void InplacePass::CommitModify(const SSANodePair& swap_nodes,
...
@@ -245,7 +244,7 @@ void InplacePass::CommitModify(const SSANodePair& swap_nodes,
}
}
}
}
void
InplacePass
::
WithdrawModify
(
const
SSANodePair
&
nodes
,
void
InplacePass
::
WithdrawModify
(
const
NodeSwapQueue
&
nodes
,
ir
::
Graph
*
graph
)
const
{
ir
::
Graph
*
graph
)
const
{
for
(
auto
&
pair
:
nodes
)
{
for
(
auto
&
pair
:
nodes
)
{
auto
*
node
=
pair
.
first
,
*
cache_node
=
pair
.
second
;
auto
*
node
=
pair
.
first
,
*
cache_node
=
pair
.
second
;
...
...
paddle/fluid/framework/details/inplace_op_pass.h
浏览文件 @
381f2015
...
@@ -56,7 +56,8 @@ class GraphView {
...
@@ -56,7 +56,8 @@ class GraphView {
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list_
;
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list_
;
};
};
typedef
std
::
vector
<
std
::
pair
<
ir
::
Node
*
,
ir
::
Node
*>>
SSANodePair
;
// swap pairs in sequence
typedef
std
::
vector
<
std
::
pair
<
ir
::
Node
*
,
ir
::
Node
*>>
NodeSwapQueue
;
class
InplacePass
:
public
ir
::
Pass
{
class
InplacePass
:
public
ir
::
Pass
{
public:
public:
InplacePass
();
InplacePass
();
...
@@ -68,14 +69,14 @@ class InplacePass : public ir::Pass {
...
@@ -68,14 +69,14 @@ class InplacePass : public ir::Pass {
void
InitSSAGraphNodes
()
const
;
void
InitSSAGraphNodes
()
const
;
private:
private:
const
SSANodePair
TryInplaceModifyVar
(
const
std
::
string
&
var
,
const
NodeSwapQueue
TryInplaceModifyVar
(
const
std
::
string
&
var
,
const
std
::
string
&
cache_var
,
const
std
::
string
&
cache_var
,
const
size_t
&
idx
,
const
size_t
&
idx
,
ir
::
Graph
*
graph
)
const
;
ir
::
Graph
*
graph
)
const
;
void
CommitModify
(
const
SSANodePair
&
,
ir
::
Graph
*
graph
)
const
;
void
CommitModify
(
const
NodeSwapQueue
&
,
ir
::
Graph
*
graph
)
const
;
void
WithdrawModify
(
const
SSANodePair
&
nodes
,
ir
::
Graph
*
graph
)
const
;
void
WithdrawModify
(
const
NodeSwapQueue
&
nodes
,
ir
::
Graph
*
graph
)
const
;
void
InplaceModifyDesc
(
const
std
::
string
&
in_var
,
const
std
::
string
&
out_var
,
void
InplaceModifyDesc
(
const
std
::
string
&
in_var
,
const
std
::
string
&
out_var
,
const
size_t
&
idx
)
const
;
const
size_t
&
idx
)
const
;
...
...
paddle/fluid/framework/details/memory_early_delete_pass.cc
已删除
100644 → 0
浏览文件 @
6492ea9c
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_early_delete_pass.h"
#include <queue>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
static
ComputationOpHandle
*
FindNextComputationOpHandle
(
VarHandle
*
var_in
)
{
std
::
queue
<
VarHandleBase
*>
queue
;
queue
.
push
(
var_in
);
do
{
auto
*
var
=
queue
.
front
();
queue
.
pop
();
for
(
auto
*
op
:
var
->
PendingOps
())
{
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
);
if
(
compute_op
!=
nullptr
&&
compute_op
->
GetPlace
()
==
var_in
->
place
())
{
return
compute_op
;
}
for
(
auto
*
out_var
:
op
->
Outputs
())
{
queue
.
push
(
out_var
);
}
}
}
while
(
!
queue
.
empty
());
return
nullptr
;
}
std
::
unique_ptr
<
ir
::
Graph
>
MemoryEarlyDeletePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
auto
&
graph_pool
=
Get
<
GraphNodePool
>
(
kGraphNodePool
);
auto
&
gcs
=
Get
<
GarbageCollectorMap
>
(
kGarbageCollector
);
std
::
unordered_map
<
std
::
string
,
std
::
unordered_set
<
OpDesc
*>>
unlived_vars
;
unlived_vars
.
reserve
(
graph_pool
.
size
());
for
(
auto
&
pair
:
graph_pool
)
{
unlived_vars
.
insert
(
std
::
make_pair
(
pair
.
first
,
pair
.
second
));
}
auto
compare_and_insert_early_delete_op
=
[
&
](
OpHandleBase
*
op
,
const
std
::
vector
<
VarHandleBase
*>&
vars
)
{
if
(
unlived_vars
.
empty
())
return
;
// unlived vars can be deleted after the last used op has finished.
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
);
const
auto
&
places
=
Get
<
std
::
vector
<
platform
::
Place
>>
(
kAllPlaces
);
for
(
auto
&
var
:
vars
)
{
auto
*
var_handle
=
dynamic_cast
<
VarHandle
*>
(
var
);
auto
var_name
=
var
->
Node
()
->
Name
();
auto
&
var_place
=
var_handle
->
place
();
if
(
unlived_vars
.
count
(
var_name
)
==
0
)
continue
;
if
(
!
unlived_vars
[
var_name
].
empty
())
{
if
(
compute_op
!=
nullptr
&&
unlived_vars
[
var_name
].
count
(
compute_op
->
Node
()
->
Op
())
!=
0
)
{
unlived_vars
[
var_name
].
erase
(
compute_op
->
Node
()
->
Op
());
}
continue
;
}
if
(
var_handle
==
nullptr
||
!
var_handle
->
Node
()
->
IsVar
()
||
var_handle
->
Node
()
->
IsCtrlVar
())
continue
;
// shameless copyed from reference count pass.
if
(
compute_op
==
nullptr
)
{
// use next computation op scope
compute_op
=
FindNextComputationOpHandle
(
var_handle
);
}
auto
*
early_delete_node
=
graph
->
CreateEmptyNode
(
"early_delete"
,
ir
::
Node
::
Type
::
kOperation
);
GarbageCollector
*
gc
=
gcs
.
at
(
places
[
compute_op
->
GetScopeIdx
()]).
get
();
auto
*
early_delete_handle
=
new
EarlyDeleteOpHandle
(
early_delete_node
,
compute_op
->
GetScope
(),
var_place
,
{
var_name
},
gc
);
if
(
compute_op
->
Outputs
().
empty
())
{
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
compute_op
->
AddOutput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
}
early_delete_handle
->
AddInput
(
compute_op
->
Outputs
().
front
());
VLOG
(
5
)
<<
"Add early delete op "
<<
var_name
<<
" to Operator"
<<
compute_op
->
Name
();
}
};
auto
all_ops
=
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
*
graph
);
for
(
auto
&
op
:
all_ops
)
{
compare_and_insert_early_delete_op
(
op
,
op
->
Inputs
());
compare_and_insert_early_delete_op
(
op
,
op
->
Outputs
());
}
return
graph
;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
memory_early_delete_pass
,
paddle
::
framework
::
details
::
MemoryEarlyDeletePass
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphNodePool
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGarbageCollector
);
paddle/fluid/framework/details/memory_early_delete_pass.h
已删除
100644 → 0
浏览文件 @
6492ea9c
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/early_delete_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
class
MemoryEarlyDeletePass
:
public
ir
::
Pass
{
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/memory_optimize_helper.cc
浏览文件 @
381f2015
...
@@ -13,17 +13,108 @@
...
@@ -13,17 +13,108 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include <deque>
#include <functional>
#include <functional>
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
#include <sstream>
#include <sstream>
#include <string>
#include <string>
#include "paddle/fluid/framework/var_desc.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
using
paddle
::
framework
::
VarDesc
;
size_t
NodeSizeInBytes
(
const
VarDesc
&
node
)
{
std
::
vector
<
ir
::
Node
*>
SortOpLikeDescOrder
(
const
ir
::
Graph
&
graph
)
{
PADDLE_ENFORCE
(
graph
.
Has
(
kAllOpDescs
),
"Graph has no attribute of kAllOpDescs."
);
// 1. get op desc order
auto
&
op_descs
=
graph
.
Get
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
);
// 2. topology sort order
auto
nodes
=
graph
.
Nodes
();
std
::
deque
<
ir
::
Node
*>
ops
;
FilterVariables
(
nodes
,
[
&
](
ir
::
Node
*
op
)
{
if
(
op
->
IsOp
()
&&
op
->
Op
()
!=
nullptr
)
{
ops
.
emplace_back
(
op
);
}
});
std
::
unordered_map
<
ir
::
Node
*
,
size_t
>
op_deps
;
std
::
list
<
ir
::
Node
*>
ready_ops
;
std
::
unordered_map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
pending_ops
;
for
(
auto
*
op
:
ops
)
{
std
::
unordered_set
<
ir
::
Node
*>
preceding_op
;
for
(
auto
*
in
:
op
->
inputs
)
{
if
(
in
->
inputs
.
empty
())
continue
;
PADDLE_ENFORCE
(
in
->
inputs
.
size
()
==
1
&&
in
->
inputs
[
0
]
->
IsOp
());
preceding_op
.
emplace
(
in
->
inputs
[
0
]);
pending_ops
[
in
->
inputs
[
0
]].
emplace
(
op
);
}
op_deps
[
op
]
=
preceding_op
.
size
();
if
(
preceding_op
.
empty
())
{
ready_ops
.
emplace_back
(
op
);
}
}
// 3. generated op list based desc order and the topology order
std
::
vector
<
ir
::
Node
*>
ret
;
std
::
list
<
OpDesc
*>
op_descs_list
(
op_descs
.
begin
(),
op_descs
.
end
());
auto
update_by_found_node
=
[
&
](
ir
::
Node
*
found_node
)
{
for
(
auto
*
pending_op
:
pending_ops
[
found_node
])
{
if
(
--
op_deps
[
pending_op
]
==
0
)
{
ready_ops
.
emplace_back
(
pending_op
);
}
}
ready_ops
.
remove
(
found_node
);
ret
.
emplace_back
(
found_node
);
};
while
(
!
ready_ops
.
empty
())
{
bool
all_of_ready_op_unmatched
=
true
;
for
(
auto
it
=
op_descs_list
.
begin
();
it
!=
op_descs_list
.
end
();)
{
auto
op_desc
=
*
it
;
ir
::
Node
*
found_node
=
nullptr
;
for
(
auto
*
op
:
ready_ops
)
{
if
(
IsSameDesc
(
op
->
Op
(),
op_desc
))
{
found_node
=
op
;
break
;
}
}
// 3.1 op desc deleted by other pass
if
(
found_node
==
nullptr
)
{
++
it
;
continue
;
}
else
{
all_of_ready_op_unmatched
=
false
;
it
=
op_descs_list
.
erase
(
it
);
}
update_by_found_node
(
found_node
);
}
// 3.2 op descs are added by other pass
// preceding op non empty means some new op descs are
// created, but not contained in return node list.
// these new op desc may depend on each other.
std
::
list
<
ir
::
Node
*>
prev_ready_ops
(
ready_ops
);
if
(
all_of_ready_op_unmatched
)
{
for
(
auto
op
:
prev_ready_ops
)
{
update_by_found_node
(
op
);
}
}
}
PADDLE_ENFORCE
(
std
::
all_of
(
op_deps
.
begin
(),
op_deps
.
end
(),
[
&
](
const
std
::
pair
<
ir
::
Node
*
,
size_t
>&
p
)
{
return
p
.
second
==
0
;
}));
return
ret
;
}
size_t
NodeSize
(
const
VarDesc
&
node
)
{
auto
shape
=
node
.
GetShape
();
auto
shape
=
node
.
GetShape
();
int
size
=
int
size
=
std
::
accumulate
(
shape
.
begin
(),
shape
.
end
(),
1
,
std
::
multiplies
<
int
>
());
std
::
accumulate
(
shape
.
begin
(),
shape
.
end
(),
1
,
std
::
multiplies
<
int
>
());
...
@@ -31,9 +122,9 @@ size_t NodeSizeInBytes(const VarDesc& node) {
...
@@ -31,9 +122,9 @@ size_t NodeSizeInBytes(const VarDesc& node) {
return
type_size
*
std
::
abs
(
size
);
return
type_size
*
std
::
abs
(
size
);
}
}
size_t
NodeSize
InBytes
(
ir
::
Node
*
n
)
{
size_t
NodeSize
(
ir
::
Node
*
n
)
{
auto
*
desc
=
FindVarDescInBlock
(
n
);
auto
*
desc
=
FindVarDescInBlock
(
n
);
return
NodeSize
InBytes
(
*
desc
);
return
NodeSize
(
*
desc
);
}
}
std
::
string
DebugStringImpl
(
VarDesc
*
var
)
{
std
::
string
DebugStringImpl
(
VarDesc
*
var
)
{
...
@@ -59,7 +150,6 @@ std::string DebugStringImpl(VarDesc* var) {
...
@@ -59,7 +150,6 @@ std::string DebugStringImpl(VarDesc* var) {
std
::
string
DebugString
(
ir
::
Node
*
var
)
{
std
::
string
DebugString
(
ir
::
Node
*
var
)
{
return
DebugStringImpl
(
FindVarDescInBlock
(
var
));
return
DebugStringImpl
(
FindVarDescInBlock
(
var
));
}
}
// return DebugString(var->Var()); }
// NOTE(dzh): based ir node, if a large node has been reused
// NOTE(dzh): based ir node, if a large node has been reused
// by a small size node, then next time it appear in pool, it will
// by a small size node, then next time it appear in pool, it will
...
@@ -80,18 +170,17 @@ struct NodeComparator {
...
@@ -80,18 +170,17 @@ struct NodeComparator {
auto
rhs_shape
=
rhs_desc
->
GetShape
();
auto
rhs_shape
=
rhs_desc
->
GetShape
();
if
((
lhs_shape
[
0
]
==
-
1
&&
rhs_shape
[
0
]
==
-
1
)
||
if
((
lhs_shape
[
0
]
==
-
1
&&
rhs_shape
[
0
]
==
-
1
)
||
(
lhs_shape
[
0
]
!=
-
1
&&
rhs_shape
[
0
]
!=
-
1
))
{
(
lhs_shape
[
0
]
!=
-
1
&&
rhs_shape
[
0
]
!=
-
1
))
{
return
NodeSize
InBytes
(
lhs
)
<=
NodeSizeInBytes
(
rhs
);
return
NodeSize
(
lhs
)
<=
NodeSize
(
rhs
);
}
else
{
}
else
{
return
false
;
return
false
;
}
}
}
}
};
};
void
Ordered
NodeList
::
Insert
(
ir
::
Node
*
var
,
ir
::
Node
*
op
)
{
void
Ordered
Set
::
Insert
(
ir
::
Node
*
var
)
{
PADDLE_ENFORCE
(
var
->
IsVar
()
&&
!
var
->
IsCtrlVar
());
PADDLE_ENFORCE
(
var
->
IsVar
()
&&
!
var
->
IsCtrlVar
());
PADDLE_ENFORCE
(
op
->
IsOp
());
if
(
mark_table_
.
count
(
var
->
Name
())
!=
0
)
{
if
(
mark_table_
.
count
(
var
->
Name
())
!=
0
)
{
mark_table_
[
var
->
Name
()]
->
second
.
insert
(
op
);
mark_table_
[
var
->
Name
()]
->
emplace_back
(
var
);
return
;
return
;
}
}
...
@@ -99,14 +188,15 @@ void OrderedNodeList::Insert(ir::Node* var, ir::Node* op) {
...
@@ -99,14 +188,15 @@ void OrderedNodeList::Insert(ir::Node* var, ir::Node* op) {
auto
var_shape
=
var_desc
->
GetShape
();
auto
var_shape
=
var_desc
->
GetShape
();
int
batch_size
=
static_cast
<
int
>
(
var_shape
[
0
]);
int
batch_size
=
static_cast
<
int
>
(
var_shape
[
0
]);
NodeComparator
compare_node
;
NodeComparator
functor
;
Iter
it
=
nodes_
.
begin
();
Iter
it
=
nodes_
.
begin
();
while
(
it
!=
nodes_
.
end
())
{
while
(
it
!=
nodes_
.
end
())
{
auto
*
cache_desc
=
FindVarDescInBlock
(
it
->
first
);
auto
&
prev
=
it
->
front
();
auto
*
cache_desc
=
FindVarDescInBlock
(
prev
);
int
cache_batch_size
=
cache_desc
->
GetShape
()[
0
];
int
cache_batch_size
=
cache_desc
->
GetShape
()[
0
];
if
((
cache_batch_size
==
-
1
&&
batch_size
==
-
1
)
||
if
((
cache_batch_size
==
-
1
&&
batch_size
==
-
1
)
||
(
cache_batch_size
!=
-
1
&&
batch_size
!=
-
1
))
{
(
cache_batch_size
!=
-
1
&&
batch_size
!=
-
1
))
{
if
(
compare_node
(
it
->
first
,
var
))
{
if
(
functor
(
prev
,
var
))
{
++
it
;
++
it
;
}
else
{
}
else
{
break
;
break
;
...
@@ -118,62 +208,80 @@ void OrderedNodeList::Insert(ir::Node* var, ir::Node* op) {
...
@@ -118,62 +208,80 @@ void OrderedNodeList::Insert(ir::Node* var, ir::Node* op) {
}
}
}
}
it
=
it
=
nodes_
.
insert
(
it
,
{
var
});
nodes_
.
insert
(
it
,
std
::
make_pair
(
var
,
std
::
unordered_set
<
ir
::
Node
*>
{
op
}));
mark_table_
[
var
->
Name
()]
=
it
;
mark_table_
[
var
->
Name
()]
=
it
;
}
}
int
Ordered
NodeList
::
GetIndex
(
ir
::
Node
*
var
)
{
int
Ordered
Set
::
GetNodeIndexInPool
(
ir
::
Node
*
var
)
{
return
std
::
distance
(
nodes_
.
begin
(),
mark_table_
[
var
->
Name
()]);
return
std
::
distance
(
nodes_
.
begin
(),
mark_table_
[
var
->
Name
()]);
}
}
ir
::
Node
*
Ordered
NodeList
::
NodeMatch
(
ir
::
Node
*
var
)
const
{
ir
::
Node
*
Ordered
Set
::
FindBestFitNode
(
ir
::
Node
*
var
)
const
{
ir
::
Node
*
found_node
=
nullptr
;
ir
::
Node
*
found_node
=
nullptr
;
NodeComparator
compare_node
;
NodeComparator
functor
;
for
(
auto
it
=
nodes_
.
begin
();
it
!=
nodes_
.
end
();
++
it
)
{
for
(
auto
it
=
nodes_
.
begin
();
it
!=
nodes_
.
end
();
++
it
)
{
if
(
compare_node
(
var
,
it
->
first
))
{
auto
&
candidate
=
it
->
front
();
found_node
=
it
->
first
;
if
(
functor
(
var
,
candidate
))
{
found_node
=
candidate
;
break
;
break
;
}
}
}
}
return
found_node
;
return
found_node
;
}
}
void
OrderedNodeList
::
Erase
(
ir
::
Node
*
var
)
{
Erase
(
var
->
Name
());
}
bool
OrderedSet
::
Has
(
ir
::
Node
*
var
)
const
{
if
(
mark_table_
.
count
(
var
->
Name
()))
{
auto
&
node_in_samename
=
mark_table_
.
at
(
var
->
Name
());
auto
iter
=
std
::
find_if
(
node_in_samename
->
begin
(),
node_in_samename
->
end
(),
[
&
](
ir
::
Node
*
n
)
{
return
n
->
Name
()
==
var
->
Name
();
});
return
iter
!=
node_in_samename
->
end
();
}
return
false
;
}
void
Ordered
NodeList
::
Erase
(
const
std
::
string
&
var
)
{
void
Ordered
Set
::
Erase
(
ir
::
Node
*
var
)
{
PADDLE_ENFORCE
(
mark_table_
.
count
(
var
));
PADDLE_ENFORCE
(
mark_table_
.
count
(
var
->
Name
()
));
nodes_
.
erase
(
mark_table_
[
var
]);
nodes_
.
erase
(
mark_table_
[
var
->
Name
()
]);
mark_table_
.
erase
(
var
);
mark_table_
.
erase
(
var
->
Name
()
);
}
}
std
::
string
Ordered
NodeLis
t
::
ToString
()
const
{
std
::
string
Ordered
Se
t
::
ToString
()
const
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
for
(
auto
it
=
nodes_
.
begin
();
it
!=
nodes_
.
end
();
++
it
)
{
for
(
auto
it
=
nodes_
.
begin
();
it
!=
nodes_
.
end
();
++
it
)
{
ss
<<
DebugString
(
it
->
first
)
<<
" "
;
for
(
auto
&
node
:
*
it
)
{
ss
<<
DebugString
(
node
)
<<
" "
;
}
}
}
return
ss
.
str
();
return
ss
.
str
();
}
}
bool
NodeCanReused
(
ir
::
Node
*
node
)
{
bool
NodeCanReused
(
ir
::
Node
*
node
)
{
// valid the node is a var node
if
(
node
==
nullptr
||
!
node
->
IsVar
()
||
node
->
IsCtrlVar
())
return
false
;
if
(
node
==
nullptr
||
!
node
->
IsVar
()
||
node
->
IsCtrlVar
())
return
false
;
// auto* desc = node->Var();
bool
flag
=
NodeCanReused
(
*
node
->
Var
());
bool
flag
=
true
;
// op output force generated in cpu, can not be reused.
for
(
auto
*
op
:
node
->
inputs
)
{
for
(
auto
*
op
:
node
->
inputs
)
{
if
(
op
->
Op
()
->
HasAttr
(
"force_cpu"
))
{
if
(
op
->
Op
()
->
HasAttr
(
"force_cpu"
))
{
// op output force generated in cpu, can not be reused.
flag
&=
framework
::
AttrReader
(
op
->
Op
()
->
GetAttrMap
())
flag
&=
framework
::
AttrReader
(
op
->
Op
()
->
GetAttrMap
())
.
Get
<
bool
>
(
"force_cpu"
)
==
0
;
.
Get
<
bool
>
(
"force_cpu"
)
==
0
;
}
}
}
}
// var desc validation.
flag
&=
NodeCanReused
(
*
node
->
Var
());
return
flag
;
return
flag
;
}
}
bool
NodeCanReused
(
const
VarDesc
&
node
)
{
bool
NodeCanReused
(
const
VarDesc
&
node
)
{
auto
type
=
node
.
GetType
();
auto
type
=
node
.
GetType
();
if
(
node
.
Persistable
()
||
type
!=
proto
::
VarType
::
LOD_TENSOR
||
if
(
!
(
type
==
proto
::
VarType
::
LOD_TENSOR
||
node
.
GetShape
().
empty
())
{
type
==
proto
::
VarType
::
SELECTED_ROWS
||
type
==
proto
::
VarType
::
LOD_TENSOR_ARRAY
))
{
return
false
;
}
if
(
node
.
Persistable
()
||
node
.
GetShape
().
empty
())
{
return
false
;
return
false
;
}
}
// vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
// vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
...
@@ -193,6 +301,174 @@ bool OpHasSubBlock(OpDesc* desc) {
...
@@ -193,6 +301,174 @@ bool OpHasSubBlock(OpDesc* desc) {
return
false
;
return
false
;
}
}
ControlFlowGraph
::
ControlFlowGraph
(
const
ir
::
Graph
&
graph
)
{
ops_
=
SortOpLikeDescOrder
(
graph
);
ConnectNodes
();
}
void
ControlFlowGraph
::
BuildCFGGraph
()
{
// FIXME(dzh): same effect with ConnectNodes, but use the control
// link to build dependency graph, it goes wrong in transformer.
for
(
ir
::
Node
*
op
:
ops_
)
{
for
(
auto
&
input_var
:
op
->
inputs
)
{
if
(
!
input_var
->
inputs
.
empty
())
{
PADDLE_ENFORCE
(
input_var
->
inputs
.
size
()
==
1
&&
input_var
->
inputs
[
0
]
->
IsOp
(),
"Preceding Op Node of Var Node must be unique"
);
auto
*
pred_op
=
input_var
->
inputs
[
0
];
if
(
pred_op
->
Op
()
!=
nullptr
)
{
predecessors_
[
op
].
insert
(
pred_op
);
successors_
[
pred_op
].
insert
(
op
);
}
}
if
(
input_var
->
IsVar
()
&&
!
input_var
->
IsCtrlVar
())
{
uses_
[
op
].
insert
(
input_var
->
Name
());
}
}
for
(
auto
&
output_var
:
op
->
outputs
)
{
// output var may be used by many op
for
(
auto
*
succ_op
:
output_var
->
outputs
)
{
if
(
succ_op
->
Op
()
!=
nullptr
)
{
successors_
[
op
].
insert
(
succ_op
);
predecessors_
[
succ_op
].
insert
(
op
);
}
}
if
(
output_var
->
IsVar
()
&&
!
output_var
->
IsCtrlVar
())
{
defs_
[
op
].
insert
(
output_var
->
Name
());
}
}
}
}
void
ControlFlowGraph
::
ConnectNodes
()
{
for
(
size_t
i
=
0
;
i
<
ops_
.
size
();
++
i
)
{
auto
&
op
=
ops_
[
i
];
try
{
auto
&
next_op
=
ops_
.
at
(
i
+
1
);
successors_
[
op
].
insert
(
next_op
);
predecessors_
[
next_op
].
insert
(
op
);
}
catch
(...)
{
// do nothing
}
FilterVariables
(
op
->
inputs
,
[
&
](
ir
::
Node
*
var
)
{
uses_
[
op
].
emplace
(
var
->
Name
());
});
FilterVariables
(
op
->
outputs
,
[
&
](
ir
::
Node
*
var
)
{
defs_
[
op
].
emplace
(
var
->
Name
());
});
}
}
void
ControlFlowGraph
::
LiveVariableAnalysis
()
{
// NOTE(dzh): variable liveless analysis (a.k.a reversed_ops algorithm)
// compute the liveness of for each variable though reversed_ops algorithm.
// It iterates the operators from end to begin, compute the live in/live out
// variable set for each op, then the diff between in/out will be used for
// the variable reuse. For detail refer to
// http://www.cs.cornell.edu/courses/cs4120/2013fa/lectures/lec26-fa13.pdf
std
::
list
<
ir
::
Node
*>
work_list
(
ops_
.
rbegin
(),
ops_
.
rend
());
while
(
!
work_list
.
empty
())
{
ir
::
Node
*
op
=
work_list
.
front
();
work_list
.
pop_front
();
// get the live_in calculated before. Empty if first.
auto
prev_live_in
=
std
::
move
(
live_in_
[
op
]);
for
(
auto
&
s
:
successors_
[
op
])
{
for
(
auto
&
var
:
live_in_
[
s
])
{
live_out_
[
op
].
insert
(
var
);
}
}
for
(
auto
&
var
:
uses_
[
op
])
{
live_in_
[
op
].
insert
(
var
);
}
for
(
auto
&
var
:
live_out_
[
op
])
{
live_in_
[
op
].
insert
(
var
);
}
for
(
auto
&
var
:
defs_
[
op
])
{
live_in_
[
op
].
erase
(
var
);
}
// If the live_in is not changed, then the liveness analysis of
// predecessors is completed.
//
// Otherwise, recalculate the predecessors liveness
if
(
live_in_
[
op
]
!=
prev_live_in
)
{
for
(
auto
&
pre
:
predecessors_
[
op
])
{
work_list
.
push_back
(
pre
);
}
}
}
}
void
ControlFlowGraph
::
RenameVarInCFGGraph
(
const
std
::
string
&
old_node
,
const
std
::
string
&
new_node
,
int
begin_idx
)
{
// update graph from begin idx to the end
for
(
size_t
i
=
begin_idx
;
i
!=
ops_
.
size
();
++
i
)
{
auto
*
op
=
ops_
[
i
];
if
(
uses_
[
op
].
find
(
old_node
)
!=
uses_
[
op
].
end
())
{
uses_
[
op
].
erase
(
old_node
);
uses_
[
op
].
insert
(
new_node
);
}
if
(
defs_
[
op
].
find
(
old_node
)
!=
defs_
[
op
].
end
())
{
defs_
[
op
].
erase
(
old_node
);
defs_
[
op
].
insert
(
new_node
);
}
if
(
live_in_
[
op
].
find
(
old_node
)
!=
live_in_
[
op
].
end
())
{
live_in_
[
op
].
erase
(
old_node
);
live_in_
[
op
].
insert
(
new_node
);
}
if
(
live_out_
[
op
].
find
(
old_node
)
!=
live_out_
[
op
].
end
())
{
live_out_
[
op
].
erase
(
old_node
);
live_out_
[
op
].
insert
(
new_node
);
}
}
}
const
std
::
set
<
std
::
string
>
ControlFlowGraph
::
LiveIn
(
ir
::
Node
*
op
)
const
{
auto
it
=
live_in_
.
find
(
op
);
PADDLE_ENFORCE
(
it
!=
live_in_
.
end
(),
string
::
Sprintf
(
"Expect %s in live_in, but Not Found."
,
op
->
Name
()));
return
it
->
second
;
}
const
std
::
set
<
std
::
string
>
ControlFlowGraph
::
LiveOut
(
ir
::
Node
*
op
)
const
{
auto
it
=
live_out_
.
find
(
op
);
PADDLE_ENFORCE
(
it
!=
live_out_
.
end
(),
string
::
Sprintf
(
"Expect %s in live_out, but Not Found."
,
op
->
Name
()));
return
it
->
second
;
}
const
std
::
set
<
std
::
string
>
ControlFlowGraph
::
Use
(
ir
::
Node
*
op
)
const
{
auto
it
=
uses_
.
find
(
op
);
PADDLE_ENFORCE
(
it
!=
uses_
.
end
(),
string
::
Sprintf
(
"Expect %s in live_out, but Not Found."
,
op
->
Name
()));
return
it
->
second
;
}
const
std
::
vector
<
ir
::
Node
*>
ControlFlowGraph
::
Ops
()
const
{
return
ops_
;
}
std
::
vector
<
ir
::
Node
*>&
ControlFlowGraph
::
Ops
()
{
return
ops_
;
}
ir
::
Node
*
ControlFlowGraph
::
GetNodeByName
(
const
std
::
string
&
name
,
ir
::
Node
*
op
)
const
{
// in ssa-graph, different version nodes have same name,
// this function get the latest version var before target op
// It may return nullptr, such as data node.
ir
::
Node
*
found_node
=
nullptr
;
for
(
auto
*
node
:
ops_
)
{
if
(
node
==
op
)
break
;
for
(
auto
&
output
:
node
->
outputs
)
{
if
(
output
->
Name
()
==
name
)
{
found_node
=
output
;
}
}
}
return
found_node
;
}
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/details/memory_optimize_helper.h
浏览文件 @
381f2015
...
@@ -17,6 +17,8 @@
...
@@ -17,6 +17,8 @@
#include <iostream>
#include <iostream>
#include <iterator>
#include <iterator>
#include <list>
#include <list>
#include <map>
#include <set>
#include <string>
#include <string>
#include <utility>
#include <utility>
#include <vector>
#include <vector>
...
@@ -27,41 +29,41 @@ namespace paddle {
...
@@ -27,41 +29,41 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
constexpr
char
kFetchedVars
[]
=
"fetched_vars"
;
constexpr
char
kAllOpDescs
[]
=
"all_op_descs"
;
constexpr
char
kGraphNodePool
[]
=
"graph_node_pool"
;
// NOTE(dzh): Variable and the operators use the var.
std
::
vector
<
ir
::
Node
*>
SortOpLikeDescOrder
(
const
ir
::
Graph
&
graph
);
// for early delete pass.
// Because analysis var pass build base on ir::Node, which maybe released
// or modified between passes, so we use OpDesc* to mark ops.
using
GraphNodePool
=
std
::
vector
<
std
::
pair
<
std
::
string
/*var node*/
,
std
::
unordered_set
<
OpDesc
*>
/* ops */
>>
;
// NOTE(dzh): by default, it sort node in ascend order(by node bytes size).
// NOTE(dzh): A ordered set for node reuse in memory optimize.
// in fluid, -1 means the batch_size is determined in runtime.
// the orderedset sort node in ascend order(by node bytes size).
// the node batch_size equal -1 always ranking in the front than the node not.
// in fluid, -1 means the batch_size, which is determined in runtime.
// So the reuse happens between nodes who's batch_size both are -1
// simultaneously or not.
//
// sort rule:
// rule 0 : smaller node ranking in front.
// rule 1 : batch_size equal -1 ranking in the front than the node not.
//
// For example,
// For example,
// node0[-1, 1] node1[-1, 1, 1], node2[1,1], node3[1,1024], ..
// node0[-1, 1] node1[-1, 1, 1], node2[1,1], node3[1,1024], ..
// O(1) insert, delete
class
OrderedNodeList
{
public:
using
NodePair
=
std
::
pair
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
;
using
Iter
=
typename
std
::
list
<
NodePair
>::
iterator
;
using
ConstIter
=
typename
std
::
list
<
NodePair
>::
const_iterator
;
void
Insert
(
ir
::
Node
*
var
,
ir
::
Node
*
op
);
class
OrderedSet
{
public:
// nodes with same name exists in pool.
using
NodeVector
=
std
::
vector
<
ir
::
Node
*>
;
using
Iter
=
typename
std
::
list
<
NodeVector
>::
iterator
;
using
ConstIter
=
typename
std
::
list
<
NodeVector
>::
const_iterator
;
void
Insert
(
ir
::
Node
*
var
);
void
Erase
(
ir
::
Node
*
var
);
void
Erase
(
ir
::
Node
*
var
);
bool
Has
(
ir
::
Node
*
var
)
const
;
void
Erase
(
const
std
::
string
&
var
);
void
Clear
()
{
mark_table_
.
clear
();
bool
Has
(
ir
::
Node
*
var
)
{
return
mark_table_
.
count
(
var
->
Name
());
}
nodes_
.
clear
();
}
bool
Has
(
const
std
::
string
&
var
)
{
return
mark_table_
.
count
(
var
);
}
// find the bestfit shape node block with var.
ir
::
Node
*
FindBestFitNode
(
ir
::
Node
*
var
)
const
;
ir
::
Node
*
NodeMatch
(
ir
::
Node
*
var
)
const
;
// map store non-const iterator, can not promise const
// map store non-const iterator, can not promise const
int
Get
Index
(
ir
::
Node
*
var
);
int
Get
NodeIndexInPool
(
ir
::
Node
*
var
);
// pool all node to string
// pool all node to string
std
::
string
ToString
()
const
;
std
::
string
ToString
()
const
;
...
@@ -69,18 +71,54 @@ class OrderedNodeList {
...
@@ -69,18 +71,54 @@ class OrderedNodeList {
Iter
end
()
{
return
nodes_
.
end
();
}
Iter
end
()
{
return
nodes_
.
end
();
}
ConstIter
begin
()
const
{
return
nodes_
.
begin
();
}
ConstIter
begin
()
const
{
return
nodes_
.
begin
();
}
ConstIter
end
()
const
{
return
nodes_
.
end
();
}
ConstIter
end
()
const
{
return
nodes_
.
end
();
}
size_t
size
()
const
{
return
nodes_
.
size
();
}
void
Clear
()
{
size_t
size
()
const
{
return
nodes_
.
size
();
}
mark_table_
.
clear
();
nodes_
.
clear
();
}
private:
private:
// for searching.
// for searching.
std
::
unordered_map
<
std
::
string
,
Iter
>
mark_table_
;
std
::
unordered_map
<
std
::
string
,
Iter
>
mark_table_
;
// node swap pairs. var -> ops dep var
// node pool
std
::
list
<
NodePair
>
nodes_
;
std
::
list
<
NodeVector
>
nodes_
;
};
class
ControlFlowGraph
{
public:
ControlFlowGraph
()
=
default
;
// IR Graph
explicit
ControlFlowGraph
(
const
ir
::
Graph
&
graph
);
void
LiveVariableAnalysis
();
void
RenameVarInCFGGraph
(
const
std
::
string
&
old_node
,
const
std
::
string
&
new_node
,
int
begin_idx
);
const
std
::
set
<
std
::
string
>
LiveIn
(
ir
::
Node
*
op
)
const
;
const
std
::
set
<
std
::
string
>
LiveOut
(
ir
::
Node
*
op
)
const
;
const
std
::
set
<
std
::
string
>
Use
(
ir
::
Node
*
op
)
const
;
const
std
::
vector
<
ir
::
Node
*>
Ops
()
const
;
std
::
vector
<
ir
::
Node
*>&
Ops
();
// for ssa-graph nodes
ir
::
Node
*
GetNodeByName
(
const
std
::
string
&
name
,
ir
::
Node
*
op
)
const
;
private:
void
BuildCFGGraph
();
void
ConnectNodes
();
using
NodeListMap
=
std
::
unordered_map
<
ir
::
Node
*
,
std
::
set
<
ir
::
Node
*>>
;
using
VarSetMap
=
std
::
map
<
ir
::
Node
*
,
std
::
set
<
std
::
string
>>
;
// successors ops use the output variables.
NodeListMap
successors_
;
// predecessors ops generated input variables.
NodeListMap
predecessors_
;
// variables lived before run current op.
VarSetMap
live_in_
;
// variables lived after run current op.
VarSetMap
live_out_
;
VarSetMap
uses_
;
// op inputs
VarSetMap
defs_
;
// op outputs
std
::
vector
<
ir
::
Node
*>
ops_
;
// op sequence by topology sort
};
};
// valid a tensor can be reuse or not
// valid a tensor can be reuse or not
...
@@ -93,15 +131,24 @@ bool NodeCanReused(const VarDesc& node);
...
@@ -93,15 +131,24 @@ bool NodeCanReused(const VarDesc& node);
bool
OpHasSubBlock
(
OpDesc
*
desc
);
bool
OpHasSubBlock
(
OpDesc
*
desc
);
// node memory size in bytes
// node memory size in bytes
size_t
NodeSize
InBytes
(
ir
::
Node
*
n
);
size_t
NodeSize
(
ir
::
Node
*
n
);
// node memory size in bytes
// node memory size in bytes
size_t
NodeSize
InBytes
(
const
VarDesc
&
);
size_t
NodeSize
(
const
VarDesc
&
);
std
::
string
DebugString
(
ir
::
Node
*
var
);
std
::
string
DebugString
(
ir
::
Node
*
var
);
// NOTE(dzhwinter)
// after node reuse, the replaced node shape is
// different with its VarDesc. So need to find the
// correct VarDesc in Block.
VarDesc
*
FindVarDescInBlock
(
ir
::
Node
*
n
);
VarDesc
*
FindVarDescInBlock
(
ir
::
Node
*
n
);
static
inline
bool
IsSameDesc
(
OpDesc
*
op1
,
OpDesc
*
op2
)
{
return
op1
->
Type
()
==
op2
->
Type
()
&&
op1
->
Inputs
()
==
op2
->
Inputs
()
&&
op1
->
Outputs
()
==
op2
->
Outputs
();
}
template
<
typename
Container
,
typename
Callback
>
template
<
typename
Container
,
typename
Callback
>
class
FilterVariableImpl
{
class
FilterVariableImpl
{
public:
public:
...
...
paddle/fluid/framework/details/memory_optimize_helper_test.cc
浏览文件 @
381f2015
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include <algorithm>
#include <algorithm>
#include <iostream>
#include <iostream>
#include <iterator>
#include <memory>
#include <memory>
#include <sstream>
#include <sstream>
#include <string>
#include <string>
...
@@ -22,13 +23,19 @@
...
@@ -22,13 +23,19 @@
#include <vector>
#include <vector>
#include "glog/logging.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/graph_test_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
TEST
(
Ordered
NodeLis
t
,
Normal
)
{
TEST
(
Ordered
Se
t
,
Normal
)
{
Ordered
NodeLis
t
pool
;
Ordered
Se
t
pool
;
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
nodes
;
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
nodes
;
// clang-format off
// clang-format off
...
@@ -56,8 +63,15 @@ TEST(OrderedNodeList, Normal) {
...
@@ -56,8 +63,15 @@ TEST(OrderedNodeList, Normal) {
nodes
.
emplace_back
(
std
::
move
(
node
));
nodes
.
emplace_back
(
std
::
move
(
node
));
}
}
// Insert
for
(
auto
&
node
:
nodes
)
{
for
(
auto
&
node
:
nodes
)
{
pool
.
Insert
(
node
.
get
(),
op
.
get
());
pool
.
Insert
(
node
.
get
());
}
// Has/size
ASSERT_EQ
(
pool
.
size
(),
shapes
.
size
());
for
(
auto
&
node
:
nodes
)
{
ASSERT_TRUE
(
pool
.
Has
(
node
.
get
()));
}
}
// assert its order and interface.
// assert its order and interface.
...
@@ -66,14 +80,14 @@ TEST(OrderedNodeList, Normal) {
...
@@ -66,14 +80,14 @@ TEST(OrderedNodeList, Normal) {
std
::
cout
<<
pool
.
ToString
()
<<
std
::
endl
;
std
::
cout
<<
pool
.
ToString
()
<<
std
::
endl
;
ASSERT_EQ
(
pool
.
size
(),
static_cast
<
size_t
>
(
COUNT
-
1
));
ASSERT_EQ
(
pool
.
size
(),
static_cast
<
size_t
>
(
COUNT
-
1
));
ASSERT_EQ
(
pool
.
Get
Index
(
nodes
.
back
().
get
()),
0
);
ASSERT_EQ
(
pool
.
Get
NodeIndexInPool
(
nodes
.
back
().
get
()),
0
);
{
{
auto
v1
=
block_desc
->
Var
(
"11"
);
auto
v1
=
block_desc
->
Var
(
"11"
);
v1
->
SetShape
({
-
1
,
256
,
56
,
56
});
v1
->
SetShape
({
-
1
,
256
,
56
,
56
});
std
::
unique_ptr
<
ir
::
Node
>
node1
=
ir
::
CreateNodeForTest
(
v1
);
std
::
unique_ptr
<
ir
::
Node
>
node1
=
ir
::
CreateNodeForTest
(
v1
);
node1
->
inputs
.
emplace_back
(
op
.
get
());
node1
->
inputs
.
emplace_back
(
op
.
get
());
auto
*
cache
=
pool
.
NodeMatch
(
node1
.
get
());
auto
*
cache
=
pool
.
FindBestFitNode
(
node1
.
get
());
ASSERT_EQ
(
cache
,
nullptr
);
ASSERT_EQ
(
cache
,
nullptr
);
}
}
{
{
...
@@ -81,16 +95,401 @@ TEST(OrderedNodeList, Normal) {
...
@@ -81,16 +95,401 @@ TEST(OrderedNodeList, Normal) {
v2
->
SetShape
({
-
1
,
2
,
5
});
v2
->
SetShape
({
-
1
,
2
,
5
});
std
::
unique_ptr
<
ir
::
Node
>
node1
=
ir
::
CreateNodeForTest
(
v2
);
std
::
unique_ptr
<
ir
::
Node
>
node1
=
ir
::
CreateNodeForTest
(
v2
);
node1
->
inputs
.
emplace_back
(
op
.
get
());
node1
->
inputs
.
emplace_back
(
op
.
get
());
auto
*
cache
=
pool
.
NodeMatch
(
node1
.
get
());
auto
*
cache
=
pool
.
FindBestFitNode
(
node1
.
get
());
ASSERT_EQ
(
pool
.
Get
Index
(
cache
),
2
);
// match 6:[-1,2,5]
ASSERT_EQ
(
pool
.
Get
NodeIndexInPool
(
cache
),
2
);
// match 6:[-1,2,5]
}
}
{
{
auto
v3
=
block_desc
->
Var
(
"13"
);
auto
v3
=
block_desc
->
Var
(
"13"
);
v3
->
SetShape
({
2
,
5
});
v3
->
SetShape
({
2
,
5
});
std
::
unique_ptr
<
ir
::
Node
>
node1
=
ir
::
CreateNodeForTest
(
v3
);
std
::
unique_ptr
<
ir
::
Node
>
node1
=
ir
::
CreateNodeForTest
(
v3
);
node1
->
inputs
.
emplace_back
(
op
.
get
());
node1
->
inputs
.
emplace_back
(
op
.
get
());
auto
*
cache
=
pool
.
NodeMatch
(
node1
.
get
());
auto
*
cache
=
pool
.
FindBestFitNode
(
node1
.
get
());
ASSERT_EQ
(
pool
.
GetIndex
(
cache
),
5
);
// match 4:[5,2]
ASSERT_EQ
(
pool
.
GetNodeIndexInPool
(
cache
),
5
);
// match 4:[5,2]
}
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
REGISTER_OPERATOR
(
sum
,
paddle
::
framework
::
DummyOp
,
paddle
::
framework
::
SumOpMaker
,
paddle
::
framework
::
DummyVarTypeInference
);
REGISTER_OPERATOR
(
assign
,
paddle
::
framework
::
DummyOp
,
paddle
::
framework
::
AssignOpMaker
,
paddle
::
framework
::
DummyVarTypeInference
);
REGISTER_OPERATOR
(
dummy
,
paddle
::
framework
::
DummyOp
,
paddle
::
framework
::
SumOpMaker
,
paddle
::
framework
::
DummyVarTypeInference
);
/*
https://en.wikipedia.org/wiki/Live_variable_analysis
Create a customed classical dependency graph, left row is the instruction
number.
1. a = 1
2. b = a
3. c = a
4. d = b + c
5. e = d
a--------+
| |
b c
| |
d--------+
|
e
Then analysis these variable's liveness range
*/
namespace
paddle
{
namespace
framework
{
namespace
details
{
inline
static
ProgramDesc
FillProgramDesc
()
{
ProgramDesc
prog
;
prog
.
MutableBlock
(
0
)
->
Var
(
"a"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"b"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"c"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"d"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"e"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
{
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"assign"
);
op
->
SetInput
(
"X"
,
{
"a"
});
op
->
SetOutput
(
"Out"
,
{
"b"
});
}
{
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"assign"
);
op
->
SetInput
(
"X"
,
{
"a"
});
op
->
SetOutput
(
"Out"
,
{
"c"
});
}
{
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"b"
,
"c"
});
op
->
SetOutput
(
"Out"
,
{
"d"
});
}
{
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"assign"
);
op
->
SetInput
(
"X"
,
{
"d"
});
op
->
SetOutput
(
"Out"
,
{
"e"
});
}
return
prog
;
}
TEST
(
CFGGraph
,
IRGraph
)
{
// prepare ir graph
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
ControlFlowGraph
cfg
(
graph
);
cfg
.
LiveVariableAnalysis
();
// test assign op
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"a"
}
==
cfg
.
LiveIn
(
cfg
.
Ops
()[
0
])));
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"a"
,
"b"
}
==
cfg
.
LiveOut
(
cfg
.
Ops
()[
0
])));
// test assign op
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"a"
,
"b"
}
==
cfg
.
LiveIn
(
cfg
.
Ops
()[
1
])));
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"b"
,
"c"
}
==
cfg
.
LiveOut
(
cfg
.
Ops
()[
1
])));
// test sum op
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"b"
,
"c"
}
==
cfg
.
LiveIn
(
cfg
.
Ops
()[
2
])));
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"d"
}
==
cfg
.
LiveOut
(
cfg
.
Ops
()[
2
])));
// test assign op
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"d"
}
==
cfg
.
LiveIn
(
cfg
.
Ops
()[
3
])));
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{}
==
cfg
.
LiveOut
(
cfg
.
Ops
()[
3
])));
}
// 1. normal test
TEST
(
SortOpLikeDescOrder
,
NormalTest
)
{
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
nodes
=
SortOpLikeDescOrder
(
graph
);
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
for
(
size_t
i
=
0
;
i
<
nodes
.
size
();
++
i
)
{
auto
node
=
nodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
// 2. remove some op_desc
TEST
(
SortOpLikeDescOrder
,
RemoveOpDesc
)
{
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
nodes
=
graph
.
Nodes
();
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
ir
::
Node
*
found_node
=
nullptr
;
for
(
auto
node
:
nodes
)
{
if
(
node
->
IsOp
()
&&
node
->
outputs
.
back
()
->
Name
()
==
"e"
)
{
found_node
=
node
;
break
;
}
}
PADDLE_ENFORCE
(
found_node
!=
nullptr
);
for
(
auto
it
=
op_descs
.
begin
();
it
!=
op_descs
.
end
();)
{
if
(
IsSameDesc
(
*
it
,
found_node
->
Op
()))
{
it
=
op_descs
.
erase
(
it
);
}
else
{
++
it
;
}
}
auto
find_node_in_graph
=
[
&
](
std
::
string
s
)
{
ir
::
Node
*
ret
=
nullptr
;
for
(
auto
n
:
graph
.
Nodes
())
{
if
(
n
->
Name
()
==
s
)
{
ret
=
n
;
break
;
}
}
PADDLE_ENFORCE
(
ret
!=
nullptr
);
return
ret
;
};
ir
::
Node
*
e
=
find_node_in_graph
(
"e"
);
ir
::
Node
*
d
=
find_node_in_graph
(
"d"
);
std
::
remove
(
d
->
outputs
.
begin
(),
d
->
outputs
.
end
(),
found_node
);
graph
.
RemoveNode
(
found_node
);
graph
.
RemoveNode
(
e
);
// other node keeps the same order
auto
remain_nodes
=
SortOpLikeDescOrder
(
graph
);
for
(
size_t
i
=
0
;
i
<
remain_nodes
.
size
();
++
i
)
{
auto
node
=
remain_nodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
// 3. add some op_desc
TEST
(
SortOpLikeDescOrder
,
AddOpDesc
)
{
auto
prog
=
FillProgramDesc
();
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
ir
::
Graph
graph
(
prog
);
auto
find_node_in_graph
=
[
&
](
std
::
string
s
)
{
ir
::
Node
*
ret
=
nullptr
;
for
(
auto
n
:
graph
.
Nodes
())
{
if
(
n
->
Name
()
==
s
)
{
ret
=
n
;
break
;
}
}
PADDLE_ENFORCE
(
ret
!=
nullptr
);
return
ret
;
};
// cached desc different with real one
// mimic the intermidiete pass modify the programdesc.
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
auto
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"b"
,
"c"
});
op
->
SetOutput
(
"Out"
,
{
"d1"
});
ir
::
Node
*
node
=
graph
.
CreateOpNode
(
op
);
ir
::
Node
*
d1
=
graph
.
CreateVarNode
(
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
));
ir
::
Node
*
b
=
find_node_in_graph
(
"b"
);
ir
::
Node
*
c
=
find_node_in_graph
(
"c"
);
node
->
outputs
.
emplace_back
(
d1
);
node
->
inputs
.
emplace_back
(
b
);
node
->
inputs
.
emplace_back
(
c
);
d1
->
inputs
.
emplace_back
(
node
);
b
->
outputs
.
emplace_back
(
node
);
c
->
outputs
.
emplace_back
(
node
);
op_descs
.
insert
(
op_descs
.
begin
()
+
4
,
op
);
auto
nodes
=
SortOpLikeDescOrder
(
graph
);
for
(
size_t
i
=
0
;
i
<
nodes
.
size
();
++
i
)
{
auto
node
=
nodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
// 4. add and delete some op_desc
TEST
(
SortOpLikeDescOrder
,
AddAndDeleteOpDesc
)
{
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
find_node_in_graph
=
[
&
](
std
::
string
s
)
{
ir
::
Node
*
ret
=
nullptr
;
for
(
auto
n
:
graph
.
Nodes
())
{
if
(
n
->
Name
()
==
s
)
{
ret
=
n
;
break
;
}
}
PADDLE_ENFORCE
(
ret
!=
nullptr
);
return
ret
;
};
// remove sum node
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
ir
::
Node
*
found_node
=
nullptr
;
auto
nodes
=
graph
.
Nodes
();
for
(
auto
node
:
nodes
)
{
if
(
node
->
Name
()
==
"sum"
)
{
found_node
=
node
;
break
;
}
}
PADDLE_ENFORCE
(
found_node
!=
nullptr
);
for
(
auto
it
=
op_descs
.
begin
();
it
!=
op_descs
.
end
();)
{
if
(
IsSameDesc
(
*
it
,
found_node
->
Op
()))
{
it
=
op_descs
.
erase
(
it
);
}
else
{
++
it
;
}
}
{
ir
::
Node
*
d
=
find_node_in_graph
(
"d"
);
ir
::
Node
*
c
=
find_node_in_graph
(
"c"
);
ir
::
Node
*
e
=
find_node_in_graph
(
"e"
);
std
::
remove
(
d
->
outputs
.
begin
(),
d
->
outputs
.
end
(),
found_node
);
std
::
remove
(
c
->
outputs
.
begin
(),
c
->
outputs
.
end
(),
found_node
);
ir
::
Node
*
pending_op
=
found_node
->
outputs
[
0
]
->
outputs
[
0
];
graph
.
RemoveNode
(
e
);
graph
.
RemoveNode
(
pending_op
);
graph
.
RemoveNode
(
found_node
);
}
// add node
auto
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"b"
,
"c"
});
op
->
SetOutput
(
"Out"
,
{
"d1"
});
{
ir
::
Node
*
node
=
graph
.
CreateOpNode
(
op
);
ir
::
Node
*
d1
=
graph
.
CreateVarNode
(
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
));
ir
::
Node
*
b
=
find_node_in_graph
(
"b"
);
ir
::
Node
*
c
=
find_node_in_graph
(
"c"
);
node
->
outputs
.
emplace_back
(
d1
);
node
->
inputs
.
emplace_back
(
b
);
node
->
inputs
.
emplace_back
(
c
);
b
->
outputs
.
emplace_back
(
node
);
c
->
outputs
.
emplace_back
(
node
);
}
op_descs
.
insert
(
op_descs
.
begin
()
+
2
,
op
);
// check the order
auto
mynodes
=
SortOpLikeDescOrder
(
graph
);
for
(
size_t
i
=
0
;
i
<
mynodes
.
size
();
++
i
)
{
auto
node
=
mynodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
// 5. add and replace some op_desc inplace.
TEST
(
SortOpLikeDescOrder
,
AddAndReplaceOpDescInplace
)
{
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
find_node_in_graph
=
[
&
](
std
::
string
s
)
{
ir
::
Node
*
ret
=
nullptr
;
for
(
auto
n
:
graph
.
Nodes
())
{
if
(
n
->
Name
()
==
s
)
{
ret
=
n
;
break
;
}
}
PADDLE_ENFORCE
(
ret
!=
nullptr
);
return
ret
;
};
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
// add node
auto
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"b"
,
"c"
});
op
->
SetOutput
(
"Out"
,
{
"d1"
});
{
ir
::
Node
*
node
=
graph
.
CreateOpNode
(
op
);
ir
::
Node
*
d1
=
graph
.
CreateVarNode
(
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
));
ir
::
Node
*
b
=
find_node_in_graph
(
"b"
);
ir
::
Node
*
c
=
find_node_in_graph
(
"c"
);
node
->
outputs
.
emplace_back
(
d1
);
node
->
inputs
.
emplace_back
(
b
);
node
->
inputs
.
emplace_back
(
c
);
d1
->
inputs
.
emplace_back
(
node
);
b
->
outputs
.
emplace_back
(
node
);
c
->
outputs
.
emplace_back
(
node
);
}
op_descs
.
emplace_back
(
op
);
// replace op_desc inplace
auto
nodes
=
graph
.
Nodes
();
ir
::
Node
*
found_node
=
nullptr
;
for
(
auto
node
:
nodes
)
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
&&
node
->
Name
()
==
"assign"
)
{
if
(
node
->
outputs
.
size
()
==
1
&&
node
->
outputs
[
0
]
->
Name
()
==
"e"
)
{
found_node
=
node
;
break
;
}
}
}
{
ir
::
Node
*
d
=
find_node_in_graph
(
"d"
);
ir
::
Node
*
e
=
find_node_in_graph
(
"e"
);
std
::
remove
(
d
->
outputs
.
begin
(),
d
->
outputs
.
end
(),
found_node
);
std
::
remove
(
e
->
inputs
.
begin
(),
e
->
inputs
.
end
(),
found_node
);
graph
.
RemoveNode
(
found_node
);
}
op_descs
.
erase
(
op_descs
.
begin
()
+
3
);
auto
replace_op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
replace_op
->
SetType
(
"sum"
);
replace_op
->
SetInput
(
"X"
,
{
"d"
,
"d1"
});
replace_op
->
SetOutput
(
"Out"
,
{
"e"
});
{
ir
::
Node
*
sum2
=
graph
.
CreateOpNode
(
replace_op
);
ir
::
Node
*
e
=
find_node_in_graph
(
"e"
);
ir
::
Node
*
d
=
find_node_in_graph
(
"d"
);
ir
::
Node
*
d1
=
find_node_in_graph
(
"d1"
);
sum2
->
inputs
.
emplace_back
(
d
);
sum2
->
inputs
.
emplace_back
(
d1
);
sum2
->
outputs
.
emplace_back
(
e
);
e
->
inputs
.
emplace_back
(
sum2
);
d
->
outputs
.
emplace_back
(
sum2
);
d1
->
outputs
.
emplace_back
(
sum2
);
}
op_descs
.
emplace_back
(
replace_op
);
// compare op order
auto
graph_nodes
=
SortOpLikeDescOrder
(
graph
);
for
(
size_t
i
=
0
;
i
<
graph_nodes
.
size
();
++
i
)
{
auto
node
=
graph_nodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
}
}
...
...
paddle/fluid/framework/details/memory_optimize_pass.cc
浏览文件 @
381f2015
...
@@ -43,11 +43,6 @@ namespace paddle {
...
@@ -43,11 +43,6 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
static
inline
bool
IsSameDesc
(
OpDesc
*
op1
,
OpDesc
*
op2
)
{
return
op1
->
Type
()
==
op2
->
Type
()
&&
op1
->
Inputs
()
==
op2
->
Inputs
()
&&
op1
->
Outputs
()
==
op2
->
Outputs
();
}
std
::
unique_ptr
<
ir
::
Graph
>
MemoryOptimizePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
MemoryOptimizePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
auto
nodes
=
graph
->
Nodes
();
auto
nodes
=
graph
->
Nodes
();
...
@@ -77,7 +72,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
...
@@ -77,7 +72,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
if
(
!
NodeCanReused
(
var
)
||
cfg_
->
Use
(
op
).
count
(
var
->
Name
())
==
0
||
if
(
!
NodeCanReused
(
var
)
||
cfg_
->
Use
(
op
).
count
(
var
->
Name
())
==
0
||
skip_set_
.
count
(
var
->
Name
()))
skip_set_
.
count
(
var
->
Name
()))
continue
;
continue
;
ir
::
Node
*
cache
=
pool_
.
NodeMatch
(
var
);
ir
::
Node
*
cache
=
pool_
.
FindBestFitNode
(
var
);
if
(
var
->
Name
()
==
FLAGS_memory_optimize_debug
)
{
if
(
var
->
Name
()
==
FLAGS_memory_optimize_debug
)
{
VLOG
(
3
)
<<
"start match var "
<<
DebugString
(
var
)
<<
" of op "
VLOG
(
3
)
<<
"start match var "
<<
DebugString
(
var
)
<<
" of op "
...
@@ -95,11 +90,12 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
...
@@ -95,11 +90,12 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
<<
"replace it again. Skip this candidate."
;
<<
"replace it again. Skip this candidate."
;
continue
;
continue
;
int
node_idx_in_pool
=
pool_
.
Get
Index
(
cache
);
int
node_idx_in_pool
=
pool_
.
Get
NodeIndexInPool
(
cache
);
VLOG
(
3
)
<<
string
::
Sprintf
(
VLOG
(
3
)
<<
string
::
Sprintf
(
"!!! %s, %s => %s, cache idx %d, pool size %d"
,
"!!! %s, %s => %s, cache idx %d, pool size %d"
,
std
::
to_string
(
reuse_id
++
),
DebugString
(
var
),
DebugString
(
cache
),
std
::
to_string
(
reuse_id
++
),
DebugString
(
var
),
DebugString
(
cache
),
node_idx_in_pool
,
static_cast
<
int
>
(
pool_
.
size
()));
node_idx_in_pool
,
static_cast
<
int
>
(
pool_
.
size
()));
// update CFG Graph on the fly.
// update CFG Graph on the fly.
// reused var maybe re-fill into the pool
// reused var maybe re-fill into the pool
cfg_
->
RenameVarInCFGGraph
(
var
->
Name
(),
cache
->
Name
(),
idx
);
cfg_
->
RenameVarInCFGGraph
(
var
->
Name
(),
cache
->
Name
(),
idx
);
...
@@ -112,6 +108,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
...
@@ -112,6 +108,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
pool_
.
Erase
(
cache
);
pool_
.
Erase
(
cache
);
}
}
// fill the pool
// fill the pool
std
::
unordered_set
<
std
::
string
>
unlived_vars
;
std
::
unordered_set
<
std
::
string
>
unlived_vars
;
for
(
auto
var
:
cfg_
->
LiveIn
(
op
))
{
for
(
auto
var
:
cfg_
->
LiveIn
(
op
))
{
...
@@ -120,36 +117,15 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
...
@@ -120,36 +117,15 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
}
}
}
}
for
(
auto
var
:
unlived_vars
)
{
for
(
auto
var
:
unlived_vars
)
{
ir
::
Node
*
var_node
=
cfg_
->
GetNode
FromVar
Name
(
var
,
op
);
ir
::
Node
*
var_node
=
cfg_
->
GetNode
By
Name
(
var
,
op
);
if
(
NodeCanReused
(
var_node
)
&&
!
pool_
.
Has
(
var_node
))
{
if
(
NodeCanReused
(
var_node
)
&&
!
pool_
.
Has
(
var_node
))
{
pool_
.
Insert
(
var_node
,
op
);
pool_
.
Insert
(
var_node
);
}
}
}
}
}
}
}
}
graph
->
ResolveHazard
(
var_nodes_
);
graph
->
ResolveHazard
(
var_nodes_
);
// For early delete pass. use GraphNodePool load the unlived vars.
// 1. find all deps op for each unlived var in memory pool.
for
(
auto
&
op
:
graph
->
Nodes
())
{
for
(
auto
&
var
:
op
->
inputs
)
{
if
(
pool_
.
Has
(
var
))
{
pool_
.
Insert
(
var
,
op
);
}
}
}
// 2. convert ir node based memory pool to graph node
// because Node* maybe released bettwen passes.
auto
&
graph_pool
=
graph
->
Get
<
GraphNodePool
>
(
kGraphNodePool
);
for
(
auto
it
=
pool_
.
begin
();
it
!=
pool_
.
end
();
++
it
)
{
std
::
unordered_set
<
OpDesc
*>
descs
;
for
(
auto
&
op
:
it
->
second
)
{
PADDLE_ENFORCE
(
op
->
IsOp
());
descs
.
insert
(
op
->
Op
());
}
graph_pool
.
push_back
(
std
::
make_pair
(
it
->
first
->
Name
(),
descs
));
}
return
graph
;
return
graph
;
}
}
...
@@ -198,12 +174,12 @@ void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
...
@@ -198,12 +174,12 @@ void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
PADDLE_ENFORCE
(
sub_op
!=
nullptr
);
PADDLE_ENFORCE
(
sub_op
!=
nullptr
);
for
(
auto
*
var
:
sub_op
->
outputs
)
{
for
(
auto
*
var
:
sub_op
->
outputs
)
{
if
(
NodeCanReused
(
var
))
{
if
(
NodeCanReused
(
var
))
{
ir
::
Node
*
cache
=
pool_
.
NodeMatch
(
var
);
ir
::
Node
*
cache
=
pool_
.
FindBestFitNode
(
var
);
if
(
cache
!=
nullptr
)
{
if
(
cache
!=
nullptr
)
{
if
(
var
->
Var
()
->
GetDataType
()
!=
cache
->
Var
()
->
GetDataType
())
{
if
(
var
->
Var
()
->
GetDataType
()
!=
cache
->
Var
()
->
GetDataType
())
{
continue
;
continue
;
}
}
int
node_idx_in_pool
=
pool_
.
Get
Index
(
cache
);
int
node_idx_in_pool
=
pool_
.
Get
NodeIndexInPool
(
cache
);
VLOG
(
3
)
<<
string
::
Sprintf
(
VLOG
(
3
)
<<
string
::
Sprintf
(
"!!! %s, %s => %s, cache idx %d, pool size %d"
,
"!!! %s, %s => %s, cache idx %d, pool size %d"
,
std
::
to_string
(
sub_reuse_id
++
),
DebugString
(
var
),
std
::
to_string
(
sub_reuse_id
++
),
DebugString
(
var
),
...
@@ -342,267 +318,10 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
...
@@ -342,267 +318,10 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
var_nodes_
.
at
(
var
).
clear
();
var_nodes_
.
at
(
var
).
clear
();
}
}
std
::
vector
<
ir
::
Node
*>
SortOpLikeDescOrder
(
const
ir
::
Graph
&
graph
)
{
PADDLE_ENFORCE
(
graph
.
Has
(
kAllOpDescs
),
"Graph has no attribute of kAllOpDescs."
);
// 1. get op desc order
auto
&
op_descs
=
graph
.
Get
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
);
// 2. topology sort order
auto
nodes
=
graph
.
Nodes
();
std
::
deque
<
ir
::
Node
*>
ops
;
FilterVariables
(
nodes
,
[
&
](
ir
::
Node
*
op
)
{
if
(
op
->
IsOp
()
&&
op
->
Op
()
!=
nullptr
)
{
ops
.
emplace_back
(
op
);
}
});
std
::
unordered_map
<
ir
::
Node
*
,
size_t
>
op_deps
;
std
::
list
<
ir
::
Node
*>
ready_ops
;
std
::
unordered_map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
pending_ops
;
for
(
auto
*
op
:
ops
)
{
std
::
unordered_set
<
ir
::
Node
*>
preceding_op
;
for
(
auto
*
in
:
op
->
inputs
)
{
if
(
in
->
inputs
.
empty
())
continue
;
PADDLE_ENFORCE
(
in
->
inputs
.
size
()
==
1
&&
in
->
inputs
[
0
]
->
IsOp
());
preceding_op
.
emplace
(
in
->
inputs
[
0
]);
pending_ops
[
in
->
inputs
[
0
]].
emplace
(
op
);
}
op_deps
[
op
]
=
preceding_op
.
size
();
if
(
preceding_op
.
empty
())
{
ready_ops
.
emplace_back
(
op
);
}
}
// 3. generated op list based desc order and the topology order
std
::
vector
<
ir
::
Node
*>
ret
;
std
::
list
<
OpDesc
*>
op_descs_list
(
op_descs
.
begin
(),
op_descs
.
end
());
auto
update_by_found_node
=
[
&
](
ir
::
Node
*
found_node
)
{
for
(
auto
*
pending_op
:
pending_ops
[
found_node
])
{
if
(
--
op_deps
[
pending_op
]
==
0
)
{
ready_ops
.
emplace_back
(
pending_op
);
}
}
ready_ops
.
remove
(
found_node
);
ret
.
emplace_back
(
found_node
);
};
while
(
!
ready_ops
.
empty
())
{
bool
all_of_ready_op_unmatched
=
true
;
for
(
auto
it
=
op_descs_list
.
begin
();
it
!=
op_descs_list
.
end
();)
{
auto
op_desc
=
*
it
;
ir
::
Node
*
found_node
=
nullptr
;
for
(
auto
*
op
:
ready_ops
)
{
if
(
IsSameDesc
(
op
->
Op
(),
op_desc
))
{
found_node
=
op
;
break
;
}
}
// 3.1 op desc deleted by other pass
if
(
found_node
==
nullptr
)
{
++
it
;
continue
;
}
else
{
all_of_ready_op_unmatched
=
false
;
it
=
op_descs_list
.
erase
(
it
);
}
update_by_found_node
(
found_node
);
}
// 3.2 op descs are added by other pass
// preceding op non empty means some new op descs are
// created, but not contained in return node list.
// these new op desc may depend on each other.
std
::
list
<
ir
::
Node
*>
prev_ready_ops
(
ready_ops
);
if
(
all_of_ready_op_unmatched
)
{
for
(
auto
op
:
prev_ready_ops
)
{
update_by_found_node
(
op
);
}
}
}
PADDLE_ENFORCE
(
std
::
all_of
(
op_deps
.
begin
(),
op_deps
.
end
(),
[
&
](
const
std
::
pair
<
ir
::
Node
*
,
size_t
>&
p
)
{
return
p
.
second
==
0
;
}));
return
ret
;
}
ControlFlowGraph
::
ControlFlowGraph
(
const
ir
::
Graph
&
graph
)
{
ops_
=
SortOpLikeDescOrder
(
graph
);
ConnectNodes
();
}
void
ControlFlowGraph
::
BuildCFGGraph
()
{
// FIXME(dzh): same effect with ConnectNodes, but use the control
// link to build dependency graph, it goes wrong in transformer.
for
(
ir
::
Node
*
op
:
ops_
)
{
for
(
auto
&
input_var
:
op
->
inputs
)
{
if
(
!
input_var
->
inputs
.
empty
())
{
PADDLE_ENFORCE
(
input_var
->
inputs
.
size
()
==
1
&&
input_var
->
inputs
[
0
]
->
IsOp
(),
"Preceding Op Node of Var Node must be unique"
);
auto
*
pred_op
=
input_var
->
inputs
[
0
];
if
(
pred_op
->
Op
()
!=
nullptr
)
{
predecessors_
[
op
].
insert
(
pred_op
);
successors_
[
pred_op
].
insert
(
op
);
}
}
if
(
input_var
->
IsVar
()
&&
!
input_var
->
IsCtrlVar
())
{
uses_
[
op
].
insert
(
input_var
->
Name
());
}
}
for
(
auto
&
output_var
:
op
->
outputs
)
{
// output var may be used by many op
for
(
auto
*
succ_op
:
output_var
->
outputs
)
{
if
(
succ_op
->
Op
()
!=
nullptr
)
{
successors_
[
op
].
insert
(
succ_op
);
predecessors_
[
succ_op
].
insert
(
op
);
}
}
if
(
output_var
->
IsVar
()
&&
!
output_var
->
IsCtrlVar
())
{
defs_
[
op
].
insert
(
output_var
->
Name
());
}
}
}
}
void
ControlFlowGraph
::
ConnectNodes
()
{
for
(
size_t
i
=
0
;
i
<
ops_
.
size
();
++
i
)
{
auto
&
op
=
ops_
[
i
];
try
{
auto
&
next_op
=
ops_
.
at
(
i
+
1
);
successors_
[
op
].
insert
(
next_op
);
predecessors_
[
next_op
].
insert
(
op
);
}
catch
(...)
{
// do nothing
}
FilterVariables
(
op
->
inputs
,
[
&
](
ir
::
Node
*
var
)
{
uses_
[
op
].
emplace
(
var
->
Name
());
});
FilterVariables
(
op
->
outputs
,
[
&
](
ir
::
Node
*
var
)
{
defs_
[
op
].
emplace
(
var
->
Name
());
});
}
}
void
ControlFlowGraph
::
LiveVariableAnalysis
()
{
// NOTE(dzh): variable liveless analysis (a.k.a reversed_ops algorithm)
// compute the liveness of for each variable though reversed_ops algorithm.
// It iterates the operators from end to begin, compute the live in/live out
// variable set for each op, then the diff between in/out will be used for
// the variable reuse. For detail refer to
// http://www.cs.cornell.edu/courses/cs4120/2013fa/lectures/lec26-fa13.pdf
std
::
list
<
ir
::
Node
*>
work_list
(
ops_
.
rbegin
(),
ops_
.
rend
());
while
(
!
work_list
.
empty
())
{
ir
::
Node
*
op
=
work_list
.
front
();
work_list
.
pop_front
();
// get the live_in calculated before. Empty if first.
auto
prev_live_in
=
std
::
move
(
live_in_
[
op
]);
for
(
auto
&
s
:
successors_
[
op
])
{
for
(
auto
&
var
:
live_in_
[
s
])
{
live_out_
[
op
].
insert
(
var
);
}
}
for
(
auto
&
var
:
uses_
[
op
])
{
live_in_
[
op
].
insert
(
var
);
}
for
(
auto
&
var
:
live_out_
[
op
])
{
live_in_
[
op
].
insert
(
var
);
}
for
(
auto
&
var
:
defs_
[
op
])
{
live_in_
[
op
].
erase
(
var
);
}
// If the live_in is not changed, then the liveness analysis of
// predecessors is completed.
//
// Otherwise, recalculate the predecessors liveness
if
(
live_in_
[
op
]
!=
prev_live_in
)
{
for
(
auto
&
pre
:
predecessors_
[
op
])
{
work_list
.
push_back
(
pre
);
}
}
}
}
void
ControlFlowGraph
::
RenameVarInCFGGraph
(
const
std
::
string
&
old_node
,
const
std
::
string
&
new_node
,
int
begin_idx
)
{
// update graph from begin idx to the end
for
(
size_t
i
=
begin_idx
;
i
!=
ops_
.
size
();
++
i
)
{
auto
*
op
=
ops_
[
i
];
if
(
uses_
[
op
].
find
(
old_node
)
!=
uses_
[
op
].
end
())
{
uses_
[
op
].
erase
(
old_node
);
uses_
[
op
].
insert
(
new_node
);
}
if
(
defs_
[
op
].
find
(
old_node
)
!=
defs_
[
op
].
end
())
{
defs_
[
op
].
erase
(
old_node
);
defs_
[
op
].
insert
(
new_node
);
}
if
(
live_in_
[
op
].
find
(
old_node
)
!=
live_in_
[
op
].
end
())
{
live_in_
[
op
].
erase
(
old_node
);
live_in_
[
op
].
insert
(
new_node
);
}
if
(
live_out_
[
op
].
find
(
old_node
)
!=
live_out_
[
op
].
end
())
{
live_out_
[
op
].
erase
(
old_node
);
live_out_
[
op
].
insert
(
new_node
);
}
}
}
const
std
::
set
<
std
::
string
>
ControlFlowGraph
::
LiveIn
(
ir
::
Node
*
op
)
const
{
auto
it
=
live_in_
.
find
(
op
);
PADDLE_ENFORCE
(
it
!=
live_in_
.
end
(),
string
::
Sprintf
(
"Expect %s in live_in, but Not Found."
,
op
->
Name
()));
return
it
->
second
;
}
const
std
::
set
<
std
::
string
>
ControlFlowGraph
::
LiveOut
(
ir
::
Node
*
op
)
const
{
auto
it
=
live_out_
.
find
(
op
);
PADDLE_ENFORCE
(
it
!=
live_out_
.
end
(),
string
::
Sprintf
(
"Expect %s in live_out, but Not Found."
,
op
->
Name
()));
return
it
->
second
;
}
const
std
::
set
<
std
::
string
>
ControlFlowGraph
::
Use
(
ir
::
Node
*
op
)
const
{
auto
it
=
uses_
.
find
(
op
);
PADDLE_ENFORCE
(
it
!=
uses_
.
end
(),
string
::
Sprintf
(
"Expect %s in live_out, but Not Found."
,
op
->
Name
()));
return
it
->
second
;
}
const
std
::
vector
<
ir
::
Node
*>
ControlFlowGraph
::
Ops
()
const
{
return
ops_
;
}
std
::
vector
<
ir
::
Node
*>&
ControlFlowGraph
::
Ops
()
{
return
ops_
;
}
ir
::
Node
*
ControlFlowGraph
::
GetNodeFromVarName
(
const
std
::
string
&
name
,
ir
::
Node
*
op
)
const
{
// in ssa-graph, different version nodes have same name,
// this function get the latest version var before target op
// It may return nullptr, such as data node.
ir
::
Node
*
found_node
=
nullptr
;
for
(
auto
*
node
:
ops_
)
{
if
(
node
==
op
)
break
;
for
(
auto
&
output
:
node
->
outputs
)
{
if
(
output
->
Name
()
==
name
)
{
found_node
=
output
;
}
}
}
return
found_node
;
}
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
memory_optimize_pass
,
REGISTER_PASS
(
memory_optimize_pass
,
paddle
::
framework
::
details
::
MemoryOptimizePass
)
paddle
::
framework
::
details
::
MemoryOptimizePass
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphNodePool
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kAllOpDescs
);
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kAllOpDescs
);
paddle/fluid/framework/details/memory_optimize_pass.h
浏览文件 @
381f2015
...
@@ -32,20 +32,15 @@
...
@@ -32,20 +32,15 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
constexpr
char
kAllOpDescs
[]
=
"all_op_descs"
;
std
::
vector
<
ir
::
Node
*>
SortOpLikeDescOrder
(
const
ir
::
Graph
&
graph
);
class
ControlFlowGraph
;
class
MemoryOptimizePass
:
public
ir
::
Pass
{
class
MemoryOptimizePass
:
public
ir
::
Pass
{
protected:
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
private:
// fill the variable map(var_nodes) by version.
// fill the variable map(var_nodes) by version.
void
InitSSAGraphNodes
()
const
;
void
InitSSAGraphNodes
()
const
;
private:
// update program descs
// update program descs
void
RenameVarInGraphDesc
(
const
std
::
string
&
var
,
void
RenameVarInGraphDesc
(
const
std
::
string
&
var
,
const
std
::
string
&
cache_var
,
size_t
idx
)
const
;
const
std
::
string
&
cache_var
,
size_t
idx
)
const
;
...
@@ -62,7 +57,7 @@ class MemoryOptimizePass : public ir::Pass {
...
@@ -62,7 +57,7 @@ class MemoryOptimizePass : public ir::Pass {
private:
private:
// Reuse Node Pool, Owned.
// Reuse Node Pool, Owned.
mutable
Ordered
NodeLis
t
pool_
;
mutable
Ordered
Se
t
pool_
;
// controlflow Graph
// controlflow Graph
mutable
std
::
unique_ptr
<
ControlFlowGraph
>
cfg_
;
mutable
std
::
unique_ptr
<
ControlFlowGraph
>
cfg_
;
// skip set
// skip set
...
@@ -71,45 +66,6 @@ class MemoryOptimizePass : public ir::Pass {
...
@@ -71,45 +66,6 @@ class MemoryOptimizePass : public ir::Pass {
mutable
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
var_nodes_
;
mutable
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
var_nodes_
;
};
};
class
ControlFlowGraph
{
public:
ControlFlowGraph
()
=
default
;
// For IR Graph in parallelexecutor
explicit
ControlFlowGraph
(
const
ir
::
Graph
&
graph
);
void
LiveVariableAnalysis
();
void
RenameVarInCFGGraph
(
const
std
::
string
&
old_node
,
const
std
::
string
&
new_node
,
int
begin_idx
);
const
std
::
set
<
std
::
string
>
LiveIn
(
ir
::
Node
*
op
)
const
;
const
std
::
set
<
std
::
string
>
LiveOut
(
ir
::
Node
*
op
)
const
;
const
std
::
set
<
std
::
string
>
Use
(
ir
::
Node
*
op
)
const
;
const
std
::
vector
<
ir
::
Node
*>
Ops
()
const
;
std
::
vector
<
ir
::
Node
*>&
Ops
();
// for ssa-graph nodes
ir
::
Node
*
GetNodeFromVarName
(
const
std
::
string
&
name
,
ir
::
Node
*
op
)
const
;
private:
void
BuildCFGGraph
();
void
ConnectNodes
();
using
NodeListMap
=
std
::
unordered_map
<
ir
::
Node
*
,
std
::
set
<
ir
::
Node
*>>
;
using
VarSetMap
=
std
::
map
<
ir
::
Node
*
,
std
::
set
<
std
::
string
>>
;
// successors ops use the output variables.
NodeListMap
successors_
;
// predecessors ops generated input variables.
NodeListMap
predecessors_
;
// variables lived before run current op.
VarSetMap
live_in_
;
// variables lived after run current op.
VarSetMap
live_out_
;
VarSetMap
uses_
;
// op inputs
VarSetMap
defs_
;
// op outputs
std
::
vector
<
ir
::
Node
*>
ops_
;
// op sequence by topology sort
};
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/details/memory_optimize_pass_test.cc
已删除
100644 → 0
浏览文件 @
6492ea9c
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include <algorithm>
#include <iostream>
#include <iterator>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/graph_test_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
REGISTER_OPERATOR
(
sum
,
paddle
::
framework
::
DummyOp
,
paddle
::
framework
::
SumOpMaker
,
paddle
::
framework
::
DummyVarTypeInference
);
REGISTER_OPERATOR
(
assign
,
paddle
::
framework
::
DummyOp
,
paddle
::
framework
::
AssignOpMaker
,
paddle
::
framework
::
DummyVarTypeInference
);
REGISTER_OPERATOR
(
dummy
,
paddle
::
framework
::
DummyOp
,
paddle
::
framework
::
SumOpMaker
,
paddle
::
framework
::
DummyVarTypeInference
);
/*
https://en.wikipedia.org/wiki/Live_variable_analysis
Create a customed classical dependency graph, left row is the instruction
number.
1. a = 1
2. b = a
3. c = a
4. d = b + c
5. e = d
a--------+
| |
b c
| |
d--------+
|
e
Then analysis these variable's liveness range
*/
namespace
paddle
{
namespace
framework
{
namespace
details
{
static
inline
bool
IsSameDesc
(
OpDesc
*
op1
,
OpDesc
*
op2
)
{
return
op1
->
Type
()
==
op2
->
Type
()
&&
op1
->
Inputs
()
==
op2
->
Inputs
()
&&
op1
->
Outputs
()
==
op2
->
Outputs
();
}
inline
static
ProgramDesc
FillProgramDesc
()
{
ProgramDesc
prog
;
prog
.
MutableBlock
(
0
)
->
Var
(
"a"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"b"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"c"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"d"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"e"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
{
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"assign"
);
op
->
SetInput
(
"X"
,
{
"a"
});
op
->
SetOutput
(
"Out"
,
{
"b"
});
}
{
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"assign"
);
op
->
SetInput
(
"X"
,
{
"a"
});
op
->
SetOutput
(
"Out"
,
{
"c"
});
}
{
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"b"
,
"c"
});
op
->
SetOutput
(
"Out"
,
{
"d"
});
}
{
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"assign"
);
op
->
SetInput
(
"X"
,
{
"d"
});
op
->
SetOutput
(
"Out"
,
{
"e"
});
}
return
prog
;
}
TEST
(
CFGGraph
,
IRGraph
)
{
// prepare ir graph
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
ControlFlowGraph
cfg
(
graph
);
cfg
.
LiveVariableAnalysis
();
// test assign op
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"a"
}
==
cfg
.
LiveIn
(
cfg
.
Ops
()[
0
])));
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"a"
,
"b"
}
==
cfg
.
LiveOut
(
cfg
.
Ops
()[
0
])));
// test assign op
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"a"
,
"b"
}
==
cfg
.
LiveIn
(
cfg
.
Ops
()[
1
])));
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"b"
,
"c"
}
==
cfg
.
LiveOut
(
cfg
.
Ops
()[
1
])));
// test sum op
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"b"
,
"c"
}
==
cfg
.
LiveIn
(
cfg
.
Ops
()[
2
])));
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"d"
}
==
cfg
.
LiveOut
(
cfg
.
Ops
()[
2
])));
// test assign op
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{
"d"
}
==
cfg
.
LiveIn
(
cfg
.
Ops
()[
3
])));
ASSERT_TRUE
((
std
::
set
<
std
::
string
>
{}
==
cfg
.
LiveOut
(
cfg
.
Ops
()[
3
])));
}
// 1. normal test
TEST
(
SortOpLikeDescOrder
,
NormalTest
)
{
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
nodes
=
SortOpLikeDescOrder
(
graph
);
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
for
(
size_t
i
=
0
;
i
<
nodes
.
size
();
++
i
)
{
auto
node
=
nodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
// 2. remove some op_desc
TEST
(
SortOpLikeDescOrder
,
RemoveOpDesc
)
{
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
nodes
=
graph
.
Nodes
();
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
ir
::
Node
*
found_node
=
nullptr
;
for
(
auto
node
:
nodes
)
{
if
(
node
->
IsOp
()
&&
node
->
outputs
.
back
()
->
Name
()
==
"e"
)
{
found_node
=
node
;
break
;
}
}
PADDLE_ENFORCE
(
found_node
!=
nullptr
);
for
(
auto
it
=
op_descs
.
begin
();
it
!=
op_descs
.
end
();)
{
if
(
IsSameDesc
(
*
it
,
found_node
->
Op
()))
{
it
=
op_descs
.
erase
(
it
);
}
else
{
++
it
;
}
}
auto
find_node_in_graph
=
[
&
](
std
::
string
s
)
{
ir
::
Node
*
ret
=
nullptr
;
for
(
auto
n
:
graph
.
Nodes
())
{
if
(
n
->
Name
()
==
s
)
{
ret
=
n
;
break
;
}
}
PADDLE_ENFORCE
(
ret
!=
nullptr
);
return
ret
;
};
ir
::
Node
*
e
=
find_node_in_graph
(
"e"
);
ir
::
Node
*
d
=
find_node_in_graph
(
"d"
);
std
::
remove
(
d
->
outputs
.
begin
(),
d
->
outputs
.
end
(),
found_node
);
graph
.
RemoveNode
(
found_node
);
graph
.
RemoveNode
(
e
);
// other node keeps the same order
auto
remain_nodes
=
SortOpLikeDescOrder
(
graph
);
for
(
size_t
i
=
0
;
i
<
remain_nodes
.
size
();
++
i
)
{
auto
node
=
remain_nodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
// 3. add some op_desc
TEST
(
SortOpLikeDescOrder
,
AddOpDesc
)
{
auto
prog
=
FillProgramDesc
();
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
ir
::
Graph
graph
(
prog
);
auto
find_node_in_graph
=
[
&
](
std
::
string
s
)
{
ir
::
Node
*
ret
=
nullptr
;
for
(
auto
n
:
graph
.
Nodes
())
{
if
(
n
->
Name
()
==
s
)
{
ret
=
n
;
break
;
}
}
PADDLE_ENFORCE
(
ret
!=
nullptr
);
return
ret
;
};
// cached desc different with real one
// mimic the intermidiete pass modify the programdesc.
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
auto
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"b"
,
"c"
});
op
->
SetOutput
(
"Out"
,
{
"d1"
});
ir
::
Node
*
node
=
graph
.
CreateOpNode
(
op
);
ir
::
Node
*
d1
=
graph
.
CreateVarNode
(
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
));
ir
::
Node
*
b
=
find_node_in_graph
(
"b"
);
ir
::
Node
*
c
=
find_node_in_graph
(
"c"
);
node
->
outputs
.
emplace_back
(
d1
);
node
->
inputs
.
emplace_back
(
b
);
node
->
inputs
.
emplace_back
(
c
);
d1
->
inputs
.
emplace_back
(
node
);
b
->
outputs
.
emplace_back
(
node
);
c
->
outputs
.
emplace_back
(
node
);
op_descs
.
insert
(
op_descs
.
begin
()
+
4
,
op
);
auto
nodes
=
SortOpLikeDescOrder
(
graph
);
for
(
size_t
i
=
0
;
i
<
nodes
.
size
();
++
i
)
{
auto
node
=
nodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
// 4. add and delete some op_desc
TEST
(
SortOpLikeDescOrder
,
AddAndDeleteOpDesc
)
{
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
find_node_in_graph
=
[
&
](
std
::
string
s
)
{
ir
::
Node
*
ret
=
nullptr
;
for
(
auto
n
:
graph
.
Nodes
())
{
if
(
n
->
Name
()
==
s
)
{
ret
=
n
;
break
;
}
}
PADDLE_ENFORCE
(
ret
!=
nullptr
);
return
ret
;
};
// remove sum node
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
ir
::
Node
*
found_node
=
nullptr
;
auto
nodes
=
graph
.
Nodes
();
for
(
auto
node
:
nodes
)
{
if
(
node
->
Name
()
==
"sum"
)
{
found_node
=
node
;
break
;
}
}
PADDLE_ENFORCE
(
found_node
!=
nullptr
);
for
(
auto
it
=
op_descs
.
begin
();
it
!=
op_descs
.
end
();)
{
if
(
IsSameDesc
(
*
it
,
found_node
->
Op
()))
{
it
=
op_descs
.
erase
(
it
);
}
else
{
++
it
;
}
}
{
ir
::
Node
*
d
=
find_node_in_graph
(
"d"
);
ir
::
Node
*
c
=
find_node_in_graph
(
"c"
);
ir
::
Node
*
e
=
find_node_in_graph
(
"e"
);
std
::
remove
(
d
->
outputs
.
begin
(),
d
->
outputs
.
end
(),
found_node
);
std
::
remove
(
c
->
outputs
.
begin
(),
c
->
outputs
.
end
(),
found_node
);
ir
::
Node
*
pending_op
=
found_node
->
outputs
[
0
]
->
outputs
[
0
];
graph
.
RemoveNode
(
e
);
graph
.
RemoveNode
(
pending_op
);
graph
.
RemoveNode
(
found_node
);
}
// add node
auto
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"b"
,
"c"
});
op
->
SetOutput
(
"Out"
,
{
"d1"
});
{
ir
::
Node
*
node
=
graph
.
CreateOpNode
(
op
);
ir
::
Node
*
d1
=
graph
.
CreateVarNode
(
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
));
ir
::
Node
*
b
=
find_node_in_graph
(
"b"
);
ir
::
Node
*
c
=
find_node_in_graph
(
"c"
);
node
->
outputs
.
emplace_back
(
d1
);
node
->
inputs
.
emplace_back
(
b
);
node
->
inputs
.
emplace_back
(
c
);
b
->
outputs
.
emplace_back
(
node
);
c
->
outputs
.
emplace_back
(
node
);
}
op_descs
.
insert
(
op_descs
.
begin
()
+
2
,
op
);
// check the order
auto
mynodes
=
SortOpLikeDescOrder
(
graph
);
for
(
size_t
i
=
0
;
i
<
mynodes
.
size
();
++
i
)
{
auto
node
=
mynodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
// 5. add and replace some op_desc inplace.
TEST
(
SortOpLikeDescOrder
,
AddAndReplaceOpDescInplace
)
{
auto
prog
=
FillProgramDesc
();
ir
::
Graph
graph
(
prog
);
const
std
::
vector
<
OpDesc
*>*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
prog
.
Block
(
0
).
AllOps
());
graph
.
Set
(
details
::
kAllOpDescs
,
all_op_descs
);
// take ownership
auto
find_node_in_graph
=
[
&
](
std
::
string
s
)
{
ir
::
Node
*
ret
=
nullptr
;
for
(
auto
n
:
graph
.
Nodes
())
{
if
(
n
->
Name
()
==
s
)
{
ret
=
n
;
break
;
}
}
PADDLE_ENFORCE
(
ret
!=
nullptr
);
return
ret
;
};
auto
op_descs
=
prog
.
Block
(
0
).
AllOps
();
// add node
auto
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"b"
,
"c"
});
op
->
SetOutput
(
"Out"
,
{
"d1"
});
{
ir
::
Node
*
node
=
graph
.
CreateOpNode
(
op
);
ir
::
Node
*
d1
=
graph
.
CreateVarNode
(
prog
.
MutableBlock
(
0
)
->
Var
(
"d1"
));
ir
::
Node
*
b
=
find_node_in_graph
(
"b"
);
ir
::
Node
*
c
=
find_node_in_graph
(
"c"
);
node
->
outputs
.
emplace_back
(
d1
);
node
->
inputs
.
emplace_back
(
b
);
node
->
inputs
.
emplace_back
(
c
);
d1
->
inputs
.
emplace_back
(
node
);
b
->
outputs
.
emplace_back
(
node
);
c
->
outputs
.
emplace_back
(
node
);
}
op_descs
.
emplace_back
(
op
);
// replace op_desc inplace
auto
nodes
=
graph
.
Nodes
();
ir
::
Node
*
found_node
=
nullptr
;
for
(
auto
node
:
nodes
)
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
&&
node
->
Name
()
==
"assign"
)
{
if
(
node
->
outputs
.
size
()
==
1
&&
node
->
outputs
[
0
]
->
Name
()
==
"e"
)
{
found_node
=
node
;
break
;
}
}
}
{
ir
::
Node
*
d
=
find_node_in_graph
(
"d"
);
ir
::
Node
*
e
=
find_node_in_graph
(
"e"
);
std
::
remove
(
d
->
outputs
.
begin
(),
d
->
outputs
.
end
(),
found_node
);
std
::
remove
(
e
->
inputs
.
begin
(),
e
->
inputs
.
end
(),
found_node
);
graph
.
RemoveNode
(
found_node
);
}
op_descs
.
erase
(
op_descs
.
begin
()
+
3
);
auto
replace_op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
replace_op
->
SetType
(
"sum"
);
replace_op
->
SetInput
(
"X"
,
{
"d"
,
"d1"
});
replace_op
->
SetOutput
(
"Out"
,
{
"e"
});
{
ir
::
Node
*
sum2
=
graph
.
CreateOpNode
(
replace_op
);
ir
::
Node
*
e
=
find_node_in_graph
(
"e"
);
ir
::
Node
*
d
=
find_node_in_graph
(
"d"
);
ir
::
Node
*
d1
=
find_node_in_graph
(
"d1"
);
sum2
->
inputs
.
emplace_back
(
d
);
sum2
->
inputs
.
emplace_back
(
d1
);
sum2
->
outputs
.
emplace_back
(
e
);
e
->
inputs
.
emplace_back
(
sum2
);
d
->
outputs
.
emplace_back
(
sum2
);
d1
->
outputs
.
emplace_back
(
sum2
);
}
op_descs
.
emplace_back
(
replace_op
);
// compare op order
auto
graph_nodes
=
SortOpLikeDescOrder
(
graph
);
for
(
size_t
i
=
0
;
i
<
graph_nodes
.
size
();
++
i
)
{
auto
node
=
graph_nodes
[
i
];
auto
op_desc
=
op_descs
[
i
];
ASSERT_TRUE
(
IsSameDesc
(
node
->
Op
(),
op_desc
));
}
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/sequential_execution_pass.cc
浏览文件 @
381f2015
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace
paddle
{
namespace
paddle
{
...
...
paddle/fluid/framework/details/sequential_execution_pass.h
浏览文件 @
381f2015
...
@@ -21,8 +21,6 @@ namespace paddle {
...
@@ -21,8 +21,6 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
constexpr
char
kAllOpDescs
[]
=
"all_op_descs"
;
class
SequentialExecutionPass
:
public
ir
::
Pass
{
class
SequentialExecutionPass
:
public
ir
::
Pass
{
protected:
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
...
...
paddle/fluid/framework/inplace_op_inference.h
浏览文件 @
381f2015
...
@@ -69,7 +69,7 @@ class InplaceInToOut : public InplaceOpInference {
...
@@ -69,7 +69,7 @@ class InplaceInToOut : public InplaceOpInference {
bool
TryInplaceInputOutput
(
const
VarDesc
&
in
,
const
VarDesc
&
out
)
const
{
bool
TryInplaceInputOutput
(
const
VarDesc
&
in
,
const
VarDesc
&
out
)
const
{
return
in
.
Name
()
!=
out
.
Name
()
&&
details
::
NodeCanReused
(
in
)
&&
return
in
.
Name
()
!=
out
.
Name
()
&&
details
::
NodeCanReused
(
in
)
&&
details
::
NodeCanReused
(
out
)
&&
details
::
NodeCanReused
(
out
)
&&
details
::
NodeSize
InBytes
(
out
)
<=
details
::
NodeSizeInBytes
(
in
);
details
::
NodeSize
(
out
)
<=
details
::
NodeSize
(
in
);
}
}
};
};
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
381f2015
...
@@ -171,14 +171,6 @@ std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
...
@@ -171,14 +171,6 @@ std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
eager_deletion_pass
->
SetNotOwned
(
details
::
kAllPlaces
,
&
places_
);
eager_deletion_pass
->
SetNotOwned
(
details
::
kAllPlaces
,
&
places_
);
graph
=
eager_deletion_pass
->
Apply
(
std
::
move
(
graph
));
graph
=
eager_deletion_pass
->
Apply
(
std
::
move
(
graph
));
VLOG
(
10
)
<<
"EagerDeletionPass Applied"
;
VLOG
(
10
)
<<
"EagerDeletionPass Applied"
;
if
(
build_strategy_
.
memory_early_delete_
)
{
auto
early_delete_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"memory_early_delete_pass"
);
early_delete_pass
->
SetNotOwned
(
details
::
kGarbageCollector
,
&
gcs_
);
graph
=
early_delete_pass
->
Apply
(
std
::
move
(
graph
));
}
VLOG
(
10
)
<<
"MemoryEarlyDeletePass Applied."
;
}
}
return
graph
;
return
graph
;
...
@@ -288,6 +280,8 @@ ParallelExecutor::ParallelExecutor(
...
@@ -288,6 +280,8 @@ ParallelExecutor::ParallelExecutor(
graphs
.
push_back
(
std
::
move
(
graph
));
graphs
.
push_back
(
std
::
move
(
graph
));
#endif
#endif
auto
max_memory_size
=
GetEagerDeletionThreshold
();
auto
max_memory_size
=
GetEagerDeletionThreshold
();
VLOG
(
10
)
<<
"Eager Deletion Threshold "
<<
static_cast
<
float
>
(
max_memory_size
)
/
(
1
<<
30
);
if
(
max_memory_size
>=
0
)
{
if
(
max_memory_size
>=
0
)
{
for
(
size_t
i
=
0
;
i
<
graphs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
graphs
.
size
();
++
i
)
{
graphs
[
i
]
=
member_
->
PrepareGCAndRefCnts
(
graphs
[
i
]
=
member_
->
PrepareGCAndRefCnts
(
...
@@ -506,6 +500,5 @@ ParallelExecutor::~ParallelExecutor() {
...
@@ -506,6 +500,5 @@ ParallelExecutor::~ParallelExecutor() {
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
USE_PASS
(
memory_early_delete_pass
);
USE_PASS
(
reference_count_pass
);
USE_PASS
(
reference_count_pass
);
USE_PASS
(
eager_deletion_pass
);
USE_PASS
(
eager_deletion_pass
);
paddle/fluid/framework/scope.cc
浏览文件 @
381f2015
...
@@ -22,11 +22,7 @@ limitations under the License. */
...
@@ -22,11 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/printf.h"
DEFINE_bool
(
benchmark
,
false
,
DECLARE_bool
(
benchmark
);
"Doing memory benchmark. It will make deleting scope synchronized, "
"and add some memory usage logs."
"Default cuda is asynchronous device, set to True will"
"force op run in synchronous mode."
);
DEFINE_bool
(
DEFINE_bool
(
eager_delete_scope
,
true
,
eager_delete_scope
,
true
,
...
...
paddle/fluid/memory/allocation/legacy_allocator.cc
浏览文件 @
381f2015
...
@@ -36,6 +36,7 @@ DEFINE_bool(init_allocated_mem, false,
...
@@ -36,6 +36,7 @@ DEFINE_bool(init_allocated_mem, false,
"that initializing the allocated memory with a small value "
"that initializing the allocated memory with a small value "
"during unit testing."
);
"during unit testing."
);
DECLARE_double
(
fraction_of_gpu_memory_to_use
);
DECLARE_double
(
fraction_of_gpu_memory_to_use
);
DECLARE_bool
(
benchmark
);
namespace
paddle
{
namespace
paddle
{
namespace
memory
{
namespace
memory
{
...
@@ -198,7 +199,7 @@ void *Alloc<platform::CUDAPlace>(const platform::CUDAPlace &place,
...
@@ -198,7 +199,7 @@ void *Alloc<platform::CUDAPlace>(const platform::CUDAPlace &place,
<<
string
::
HumanReadableSize
(
Used
<
platform
::
CUDAPlace
>
(
place
));
<<
string
::
HumanReadableSize
(
Used
<
platform
::
CUDAPlace
>
(
place
));
platform
::
SetDeviceId
(
cur_dev
);
platform
::
SetDeviceId
(
cur_dev
);
}
else
{
}
else
{
if
(
VLOG_IS_ON
(
3
)
)
{
if
(
FLAGS_benchmark
)
{
allocation
::
GPUMemMonitor
.
Add
(
place
.
device
,
size
);
allocation
::
GPUMemMonitor
.
Add
(
place
.
device
,
size
);
}
}
if
(
FLAGS_init_allocated_mem
)
{
if
(
FLAGS_init_allocated_mem
)
{
...
@@ -216,7 +217,7 @@ void Free<platform::CUDAPlace>(const platform::CUDAPlace &place, void *p,
...
@@ -216,7 +217,7 @@ void Free<platform::CUDAPlace>(const platform::CUDAPlace &place, void *p,
size_t
size
)
{
size_t
size
)
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
GetGPUBuddyAllocator
(
place
.
device
)
->
Free
(
p
);
GetGPUBuddyAllocator
(
place
.
device
)
->
Free
(
p
);
if
(
VLOG_IS_ON
(
3
)
)
{
if
(
FLAGS_benchmark
)
{
allocation
::
GPUMemMonitor
.
Minus
(
place
.
device
,
size
);
allocation
::
GPUMemMonitor
.
Minus
(
place
.
device
,
size
);
}
}
#else
#else
...
...
paddle/fluid/platform/place.cc
浏览文件 @
381f2015
...
@@ -14,6 +14,12 @@ limitations under the License. */
...
@@ -14,6 +14,12 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
DEFINE_bool
(
benchmark
,
false
,
"Doing memory benchmark. It will make deleting scope synchronized, "
"and add some memory usage logs."
"Default cuda is asynchronous device, set to True will"
"force op run in synchronous mode."
);
namespace
paddle
{
namespace
paddle
{
namespace
platform
{
namespace
platform
{
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
381f2015
...
@@ -1099,10 +1099,6 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -1099,10 +1099,6 @@ All parameter, weight, gradient are variables in Paddle.
"is_distribution"
,
"is_distribution"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
is_distribution_
;
},
[](
const
BuildStrategy
&
self
)
{
return
self
.
is_distribution_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
is_distribution_
=
b
;
})
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
is_distribution_
=
b
;
})
.
def_property
(
"memory_early_delete"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
memory_early_delete_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
memory_early_delete_
=
b
;
})
.
def_property
(
.
def_property
(
"enable_inplace"
,
"enable_inplace"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
enable_inplace_
;
},
[](
const
BuildStrategy
&
self
)
{
return
self
.
enable_inplace_
;
},
...
...
python/paddle/fluid/parallel_executor.py
浏览文件 @
381f2015
...
@@ -148,6 +148,7 @@ class ParallelExecutor(object):
...
@@ -148,6 +148,7 @@ class ParallelExecutor(object):
else
framework
.
default_main_program
()
else
framework
.
default_main_program
()
# FIXME(dzhwinter): enable_inplace should be after memory_optimize
# FIXME(dzhwinter): enable_inplace should be after memory_optimize
# if turn on python memory optimize, turn off the inplace_pass.
# if turn on python memory optimize, turn off the inplace_pass.
if
build_strategy
.
enable_inplace
is
None
:
build_strategy
.
enable_inplace
=
False
if
main
.
_is_mem_optimized
else
True
build_strategy
.
enable_inplace
=
False
if
main
.
_is_mem_optimized
else
True
scope
=
scope
if
scope
is
not
None
else
executor
.
global_scope
()
scope
=
scope
if
scope
is
not
None
else
executor
.
global_scope
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录