PaddlePaddle / PaddleDetection

Commit eb825246
Authored Dec 07, 2018 by sneaxiy
polish code
add unittest model containing while_op; remove unnecessary codes (test=develop)
Parent: 8a31b2eb
Showing 19 changed files with 516 additions and 268 deletions.
Changed files:

paddle/fluid/framework/CMakeLists.txt (+3 -1)
paddle/fluid/framework/details/CMakeLists.txt (+3 -2)
paddle/fluid/framework/details/eager_deletion_op_handle.cc (+25 -23)
paddle/fluid/framework/details/eager_deletion_op_handle.h (+4 -4)
paddle/fluid/framework/details/eager_deletion_pass.cc (+11 -7)
paddle/fluid/framework/details/op_graph_view.cc (+1 -0)
paddle/fluid/framework/details/reference_count_pass.cc (+93 -63)
paddle/fluid/framework/details/reference_count_pass_helper.cc (+21 -0)
paddle/fluid/framework/details/reference_count_pass_helper.h (+2 -2)
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc (+2 -19)
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h (+0 -6)
paddle/fluid/framework/executor.cc (+31 -25)
paddle/fluid/framework/garbage_collector.cc (+89 -0)
paddle/fluid/framework/garbage_collector.h (+52 -101)
paddle/fluid/framework/parallel_executor.cc (+14 -14)
paddle/fluid/framework/scope.cc (+1 -1)
paddle/fluid/framework/tensor.h (+4 -0)
python/paddle/fluid/tests/unittests/test_eager_deletion_gru_net.py (+49 -0)
python/paddle/fluid/tests/unittests/test_eager_deletion_lstm_net.py (+111 -0)
paddle/fluid/framework/CMakeLists.txt

@@ -72,6 +72,8 @@ cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto
 cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
 nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
+cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memory)
 cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
 cc_test(reader_test SRCS reader_test.cc DEPS reader)

@@ -164,7 +166,7 @@ cc_library(variable_helper SRCS variable_helper.cc DEPS lod_tensor)
 cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
 if(WITH_DISTRIBUTE)
-  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass variable_helper)
+  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass variable_helper garbage_collector)
   set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
   set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 else()
paddle/fluid/framework/details/CMakeLists.txt

@@ -33,9 +33,10 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s
 cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
-cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows op_handle_base)
+cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle)
+cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
 cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass)
-cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view)
+cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
 cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
 cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
paddle/fluid/framework/details/eager_deletion_op_handle.cc

@@ -26,8 +26,8 @@ namespace details {
 EagerDeletionOpHandle::EagerDeletionOpHandle(
     ir::Node *node, const Scope *scope, const platform::Place &place,
-    const std::unordered_set<std::string> &var_names,
-    GarbageCollector<Tensor> *gc, AtomicReferenceCountMap *ref_cnts)
+    const std::unordered_set<std::string> &var_names, GarbageCollector *gc,
+    AtomicReferenceCountMap *ref_cnts)
     : OpHandleBase(node), scope_(scope), var_names_(var_names),

@@ -35,9 +35,9 @@ EagerDeletionOpHandle::EagerDeletionOpHandle(
       ref_cnts_(ref_cnts) {
 #ifdef PADDLE_WITH_CUDA
   if (platform::is_gpu_place(place)) {
-    dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
+    dev_ctx_ = reinterpret_cast<platform::CUDADeviceContext *>(
         platform::DeviceContextPool::Instance().Get(place));
-    if (dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_)) {
+    if (dynamic_cast<StreamGarbageCollector *>(gc_)) {
       platform::CUDADeviceGuard guard(
           boost::get<platform::CUDAPlace>(place).device);
       PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));

@@ -61,10 +61,11 @@ std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
 void EagerDeletionOpHandle::RunImpl() {
   auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
-  std::vector<Tensor *> tensors;
+  std::deque<std::shared_ptr<memory::Allocation>> garbages;
   for (auto &name : var_names_) {
     auto it = ref_cnts_->find(name);
-    if (it == ref_cnts_->end()) {
+    // Var not found, or its reference count has not decreased to 0
+    if (it == ref_cnts_->end() || it->second.fetch_sub(1) != 1) {
       continue;
     }

@@ -73,43 +74,44 @@ void EagerDeletionOpHandle::RunImpl() {
       continue;
     }
     VLOG(2) << "Erase variable " << name;
     if (var->IsType<LoDTensor>()) {
-      if (it->second.fetch_sub(1) == 1) {
-        tensors.emplace_back(var->GetMutable<LoDTensor>());
-      }
+      garbages.emplace_back(var->GetMutable<LoDTensor>()->MoveMemory());
     } else if (var->IsType<SelectedRows>()) {
-      if (it->second.fetch_sub(1) == 1) {
-        tensors.emplace_back(var->GetMutable<SelectedRows>()->mutable_value());
-      }
+      garbages.emplace_back(
+          var->GetMutable<SelectedRows>()->mutable_value()->MoveMemory());
     } else if (var->IsType<LoDTensorArray>()) {
-      if (it->second.fetch_sub(1) == 1) {
-        auto *tensor_arr = var->GetMutable<LoDTensorArray>();
-        for (auto &t : *tensor_arr) {
-          tensors.emplace_back(&t);
-        }
-      }
+      auto *tensor_arr = var->GetMutable<LoDTensorArray>();
+      for (auto &t : *tensor_arr) {
+        garbages.emplace_back(t.MoveMemory());
+      }
     } else {
       PADDLE_THROW("Type %s of %s is not supported eager deletion",
                    var->Type().name(), name);
     }
   }

-  if (!tensors.empty()) {
-    ClearTensors(tensors);
+  if (!garbages.empty()) {
+    ClearGarbages(&garbages);
   }
 }

-void EagerDeletionOpHandle::ClearTensors(const std::vector<Tensor *> &tensors) {
+void EagerDeletionOpHandle::ClearGarbages(
+    std::deque<std::shared_ptr<memory::Allocation>> *garbages) {
 #ifdef PADDLE_WITH_CUDA
   if (event_) {
     auto compute_stream = dev_ctx_->stream();
-    auto callback_stream =
-        static_cast<StreamGarbageCollector<Tensor> *>(gc_)->stream();
+    auto callback_stream =
+        reinterpret_cast<StreamGarbageCollector *>(gc_)->stream();
     auto callback_func = [=]() {
       PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
       PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
     };
-    gc_->Add(tensors, callback_func);
+    gc_->Add(std::move(*garbages), callback_func);
   } else {
 #endif
-    gc_->Add(tensors);
+    gc_->Add(std::move(*garbages));
 #ifdef PADDLE_WITH_CUDA
   }
 #endif
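A note on the RunImpl change above: the old code looked a variable up and then decremented its reference count separately inside each type branch; the new code folds the decrement into the lookup with it->second.fetch_sub(1), whose return value is the counter before the decrement, so exactly one caller observes 1 and reclaims the memory. A minimal standalone sketch of that idiom (plain C++, no Paddle types, illustrative only):

#include <atomic>
#include <cassert>
#include <cstddef>

int main() {
  std::atomic<std::size_t> ref_cnt{3};  // three remaining users of a variable

  // Each user decrements once; fetch_sub returns the PREVIOUS value,
  // so only the user that sees 1 is the last one and may free memory.
  int frees = 0;
  for (int user = 0; user < 3; ++user) {
    if (ref_cnt.fetch_sub(1) == 1) {
      ++frees;  // last user: safe to hand the allocation to the GC
    }
  }
  assert(frees == 1 && ref_cnt.load() == 0);
  return 0;
}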
paddle/fluid/framework/details/eager_deletion_op_handle.h

@@ -14,8 +14,8 @@
 #pragma once

 #include <deque>
 #include <string>
 #include <vector>

 #include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/details/reference_count_pass_helper.h"

@@ -30,7 +30,7 @@ class EagerDeletionOpHandle : public OpHandleBase {
   EagerDeletionOpHandle(ir::Node *node, const Scope *scope,
                         const platform::Place &place,
                         const std::unordered_set<std::string> &var_names,
-                        GarbageCollector<Tensor> *gc,
+                        GarbageCollector *gc,
                         AtomicReferenceCountMap *ref_cnts);

   ~EagerDeletionOpHandle();

@@ -41,11 +41,11 @@ class EagerDeletionOpHandle : public OpHandleBase {
   void RunImpl() override;

  private:
-  void ClearTensors(const std::vector<Tensor *> &tensors);
+  void ClearGarbages(std::deque<std::shared_ptr<memory::Allocation>> *garbages);

   const Scope *scope_;
   std::unordered_set<std::string> var_names_;
-  GarbageCollector<Tensor> *gc_;       // not own
+  GarbageCollector *gc_;               // not own
   AtomicReferenceCountMap *ref_cnts_;  // not own
 #ifdef PADDLE_WITH_CUDA
   platform::CUDADeviceContext *dev_ctx_{nullptr};
paddle/fluid/framework/details/eager_deletion_pass.cc

@@ -28,17 +28,21 @@ namespace details {
 std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
-  const auto &vars = graph->Get<GraphVars>(kGraphVars);
   auto &ref_cnts =
       Get<std::vector<AtomicReferenceCountMap>>(kRuntimeReferenceCount);
+  PADDLE_ENFORCE(ref_cnts.empty(),
+                 "kRuntimeReferenceCount should be initialized here!");
+
+  const auto &vars = graph->Get<GraphVars>(kGraphVars);
+  ref_cnts.resize(vars.size());
+
   const auto &last_live_ops =
       Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
-  auto &gcs = Get<GarbageCollectorMap>(kGarbageCollector);
+  const auto &gcs = Get<GarbageCollectorMap>(kGarbageCollector);
   const auto &places = Get<std::vector<platform::Place>>(kAllPlaces);

-  ref_cnts = std::vector<AtomicReferenceCountMap>(vars.size());
-
   // a reverse map of last_live_ops
   // i.e., last op --> variable names which can be deleted.
   std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>
       op_vars_map;

@@ -58,8 +62,8 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
     auto *eager_deletion_node =
         graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation);
     auto *eager_deletion_op = new EagerDeletionOpHandle(
-        eager_deletion_node, op->GetScope(), op->GetPlace(),
-        std::move(var_names), gcs.at(places[op->GetScopeIdx()]).get(),
+        eager_deletion_node, op->GetScope(), op->GetPlace(), var_names,
+        gcs.at(places[op->GetScopeIdx()]).get(),
         &(ref_cnts[op->GetScopeIdx()]));

     auto it = std::find_if(
paddle/fluid/framework/details/op_graph_view.cc

@@ -42,6 +42,7 @@ void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
 std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
   std::unordered_set<OpHandleBase *> ret;
+  ret.reserve(preceding_ops_.size());
   for (auto &pair : preceding_ops_) {
     ret.insert(pair.first);
   }
paddle/fluid/framework/details/reference_count_pass.cc

@@ -29,15 +29,17 @@ namespace paddle {
 namespace framework {
 namespace details {

-class OpRelationDetector {
- public:
+// A functor to shrink/remove operators who depend on other operators in a set
+class ShrinkDepsOpFunctor {
+ private:
   enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 };

-  explicit OpRelationDetector(const std::vector<OpHandleBase *> &all_ops)
+ public:
+  explicit ShrinkDepsOpFunctor(const std::vector<OpHandleBase *> &all_ops)
       : graph_(all_ops) {}

   template <typename OpSet>
-  OpSet MaxNoDepOps(const OpSet &op_set) const {
+  OpSet operator()(const OpSet &op_set) const {
     using KeyType = typename OpSet::key_type;
     static_assert(std::is_base_of<OpHandleBase,

@@ -51,7 +53,7 @@ class OpRelationDetector {
     auto not_before = [](RelationShip r) { return r != kBefore; };
     for (size_t i = 0; i < rels.size(); ++i) {
       if (std::all_of(rels[i].begin(), rels[i].end(), not_before)) {
-        ret.insert(static_cast<KeyType>(ops[i]));
+        ret.emplace(static_cast<KeyType>(ops[i]));
       }
     }
     return ret;

@@ -59,7 +61,7 @@ class OpRelationDetector {
  private:
   std::vector<std::vector<RelationShip>> GetRelations(
-      const std::vector<OpHandleBase *> ops) const {
+      const std::vector<OpHandleBase *> &ops) const {
     std::unordered_map<OpHandleBase *, size_t> op_to_idx;
     for (size_t i = 0; i < ops.size(); ++i) {
       PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph");

@@ -112,6 +114,10 @@ class OpRelationDetector {
   const OpGraphView graph_;
 };

+/**
+ * Find the nearest downstream computation op handle. If the op is a
+ * computation op, just return itself.
+ */
 static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
     OpHandleBase *op, size_t scope_idx) {
   std::queue<OpHandleBase *> q;

@@ -134,33 +140,87 @@ static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
   return nullptr;
 }

+static std::unordered_set<ComputationOpHandle *>
+ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
+                                     const ShrinkDepsOpFunctor &shrink_func,
+                                     bool *ok) {
+  // stage one. Get the last op for the variable.
+  std::unordered_set<OpHandleBase *> candidates;
+  {
+    if (var->PendingOps().empty() && var->GeneratedOp()) {
+      // No operator depends on this variable. So the last operator is the op
+      // who generates this variable.
+      candidates.emplace(var->GeneratedOp());
+    } else {
+      candidates = var->PendingOps();
+    }
+
+    // No pending ops, or generated op is nullptr
+    if (candidates.empty()) {
+      *ok = false;
+      return {};
+    }
+  }
+
+  // stage two. Try to cast them to computation ops.
+  // return (*ok = false) when failed.
+  //
+  // The reason why we cannot let any type of op handle be the last lived op
+  // is: some op handles may operate on many DeviceContexts, while our garbage
+  // collector can only wait on one DeviceContext for now. So currently, we
+  // wait on the nearest compute op.
+  std::unordered_set<ComputationOpHandle *> computation_op;
+  {
+    for (auto *op : candidates) {
+      auto *compute_op =
+          FindNextComputationOpHandleOrReturnItself(op, scope_idx);
+      if (compute_op == nullptr) {
+        *ok = false;
+        return {};
+      }
+      computation_op.emplace(compute_op);
+    }
+  }
+
+  // stage three. Try to shrink the computation ops if they depend on each
+  // other. Get the smallest set covering the most ops.
+  *ok = true;
+  return shrink_func(computation_op);
+}
+
+static VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars) {
+  VarDesc *var_desc = nullptr;
+  std::find_if(vars.rbegin(), vars.rend(), [&](VarHandle *var_handle) -> bool {
+    var_desc = var_handle->Node()->Var();
+    return var_desc != nullptr;
+  });
+  return var_desc;
+}
+
 std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
-  auto &vars = graph->Get<GraphVars>(kGraphVars);
   auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
   auto &last_live_ops_of_vars =
       Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
-  last_live_ops_of_vars = std::vector<LastLiveOpsOfVars>(vars.size());
-  ref_cnts = std::vector<ReferenceCountMap>(vars.size());
+  PADDLE_ENFORCE(last_live_ops_of_vars.empty() && ref_cnts.empty(),
+                 "Last Live Ops and Reference Counts of vars should be "
+                 "initialized at here.");

-  OpRelationDetector detector(ir::FilterByNodeWrapper<OpHandleBase>(*graph));
+  const auto &vars = graph->Get<GraphVars>(kGraphVars);
+  last_live_ops_of_vars.resize(vars.size());
+  ref_cnts.resize(vars.size());
+
+  ShrinkDepsOpFunctor shrink_func(
+      ir::FilterByNodeWrapper<OpHandleBase>(*graph));

   for (size_t i = 0; i < vars.size(); ++i) {
     for (auto &name_var_pair : vars[i]) {
-      if (name_var_pair.second.empty()) {
-        continue;
-      }
-      const std::string &var_name = name_var_pair.first;
-      auto *last_ver_var = name_var_pair.second.back();
-
-      VarDesc *var_desc = nullptr;
-      std::find_if(name_var_pair.second.rbegin(), name_var_pair.second.rend(),
-                   [&](VarHandle *var_handle) -> bool {
-                     var_desc = var_handle->Node()->Var();
-                     return var_desc != nullptr;
-                   });
+      // Can this variable be reused or deleted? If not, we do not need to
+      // compute reference counts and dependencies.
+      VarDesc *var_desc = TryGetLatestVarDesc(name_var_pair.second);
+
       if (var_desc == nullptr || var_desc->Persistable()) {
         continue;

@@ -170,50 +230,20 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
       if (var_type != proto::VarType::LOD_TENSOR &&
           var_type != proto::VarType::SELECTED_ROWS &&
           var_type != proto::VarType::LOD_TENSOR_ARRAY) {
         // Var type cannot be deleted
         continue;
       }

-      std::unordered_set<ComputationOpHandle *> last_live_op;
-      auto add_last_live_op = [&](OpHandleBase *op) -> bool {
-        auto *compute_op = FindNextComputationOpHandleOrReturnItself(op, i);
-        if (compute_op) {
-          last_live_op.insert(compute_op);
-          return true;
-        } else {
-          return false;
-        }
-      };
-
-      bool can_delete = false;
-      auto &pending_ops = last_ver_var->PendingOps();
-      if (pending_ops.empty()) {
-        auto *generated_op = last_ver_var->GeneratedOp();
-        if (generated_op && add_last_live_op(generated_op)) {
-          can_delete = true;
-        }
-      } else {
-        can_delete = true;
-        for (auto *pending_op : pending_ops) {
-          if (!add_last_live_op(pending_op)) {
-            can_delete = false;
-            break;
-          }
-        }
-      }
-
-      if (can_delete) {
-        size_t original_size = last_live_op.size();
-        last_live_op = detector.MaxNoDepOps(last_live_op);
-        if (last_live_op.size() != original_size) {
-          VLOG(10) << "Shrink last living op number of " << var_name
-                   << " from " << original_size << " to "
-                   << last_live_op.size();
-        }
-
-        PADDLE_ENFORCE(!last_live_op.empty(),
-                       "Last living ops of %s cannot be empty", var_name);
-
-        ref_cnts[i].emplace(var_name, last_live_op.size());
-        last_live_ops_of_vars[i].emplace(var_name, std::move(last_live_op));
-      }
+      bool ok;
+      auto result = ExtractComputationOpFromLastLivedVar(
+          name_var_pair.second.back(), i, shrink_func, &ok);
+
+      if (ok) {
+        auto &var_name = name_var_pair.first;
+        PADDLE_ENFORCE(!result.empty(),
+                       "Last living ops of %s cannot be empty", var_name);
+        ref_cnts[i].emplace(var_name, result.size());
+        last_live_ops_of_vars[i].emplace(var_name, std::move(result));
+      }
     }
   }
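To see what ShrinkDepsOpFunctor buys, consider a set of candidate last-lived ops where some already depend on others: any op that runs before another member of the set can be dropped, because waiting on the downstream op implies the upstream one has finished. A toy standalone illustration of that filtering rule (direct edges only, toy graph, not Paddle code):

#include <iostream>
#include <set>
#include <utility>
#include <vector>

int main() {
  // Toy DAG over ops {0, 1, 2}: edge (a, b) means "b runs after a".
  // With the chain 0 -> 1 -> 2, ops 0 and 1 are "before" another member
  // of the set, so only op 2 needs to be waited on.
  std::vector<std::pair<int, int>> edges = {{0, 1}, {1, 2}};
  std::set<int> ops = {0, 1, 2};

  std::set<int> shrunk;
  for (int op : ops) {
    bool before_other = false;
    for (const auto &e : edges) {
      if (e.first == op && ops.count(e.second)) before_other = true;
    }
    if (!before_other) shrunk.insert(op);
  }
  for (int op : shrunk) std::cout << "last live op: " << op << "\n";  // 2
  return 0;
}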
paddle/fluid/framework/details/reference_count_pass_helper.cc (new file)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/reference_count_pass_helper.h"

namespace paddle {
namespace framework {
namespace details {}  // namespace details
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/details/reference_count_pass_helper.h

@@ -18,10 +18,10 @@
 #include <map>
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>

 #include "paddle/fluid/framework/garbage_collector.h"
 #include "paddle/fluid/framework/tensor.h"

 namespace paddle {
 namespace framework {

@@ -35,7 +35,7 @@ using AtomicReferenceCountMap =
     std::unordered_map<std::string, std::atomic<size_t>>;

 using GarbageCollectorMap =
-    std::map<platform::Place, std::unique_ptr<GarbageCollector<Tensor>>>;
+    std::map<platform::Place, std::unique_ptr<GarbageCollector>>;

 const char kGlobalReferenceCount[] = "global_reference_count";
 const char kRuntimeReferenceCount[] = "runtime_reference_count";
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc

@@ -30,20 +30,7 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
       underlying_executor_(std::move(underlying_executor)),
       local_scopes_(std::move(local_scopes)),
       var_infos_(std::move(var_infos)),
-      places_(std::move(places)) {
-  if (Graph().Has(details::kGarbageCollector)) {
-    gc_ = &(Graph().Get<GarbageCollectorMap>(details::kGarbageCollector));
-  }
-}
-
-void ScopeBufferedSSAGraphExecutor::WaitAllGarbageCollectors() {
-  if (gc_) {
-    for (auto &gc_pair : *gc_) {
-      gc_pair.second->Wait();
-      gc_pair.second->Reset();
-    }
-  }
-}
+      places_(std::move(places)) {}

 FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
     const std::vector<std::string> &fetch_tensors) {

@@ -83,19 +70,15 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
       drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
     drop_scope_counter_ = 0;
     // Wait all computational streams
-    for (auto &p : places_) {
+    for (auto p : places_) {
       platform::DeviceContextPool::Instance().Get(p)->Wait();
     }
-    WaitAllGarbageCollectors();
     for (auto &scope : local_scopes_) {
       auto &local_scope =
           *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
       scope->DeleteScope(local_scope);
     }
-  } else {
-    WaitAllGarbageCollectors();
   }

   if (eptr) {
     std::rethrow_exception(eptr);
   } else {
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h

@@ -21,11 +21,9 @@
 #include "paddle/fluid/framework/details/var_handle.h"
 #include "paddle/fluid/framework/details/execution_strategy.h"
-#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
 #include "paddle/fluid/framework/details/ssa_graph_executor.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/platform/place.h"

 namespace paddle {
 namespace framework {
 namespace details {

@@ -50,8 +48,6 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
   FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;

  private:
-  void WaitAllGarbageCollectors();
-
   size_t drop_scope_counter_{0};

   ExecutionStrategy strategy_;

@@ -59,8 +55,6 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
   std::vector<Scope *> local_scopes_;
   std::vector<VariableInfo> var_infos_;
   std::vector<platform::Place> places_;
-
-  GarbageCollectorMap *gc_{nullptr};
 };
 }  // namespace details
 }  // namespace framework
paddle/fluid/framework/executor.cc

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/framework/executor.h"
+#include <deque>

 #include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/framework/lod_rank_table.h"

@@ -83,31 +84,37 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
 }

 static void DeleteUnusedTensors(
-    const Scope &scope, const OperatorBase *op, GarbageCollector<Tensor> *gc,
+    const Scope &scope, const OperatorBase *op, GarbageCollector *gc,
     std::unordered_map<std::string, size_t> *ref_cnts) {
-  std::unordered_set<Tensor *> erase_tensors;
+  std::deque<std::shared_ptr<memory::Allocation>> garbages;

   auto handler = [&](const VariableNameMap &name_map) {
     for (auto &name_pair : name_map) {
       for (auto &name : name_pair.second) {
         auto it = ref_cnts->find(name);
         if (it == ref_cnts->end()) continue;
-        if (--(it->second) == 0) {
-          auto *var = scope.FindVar(name);
-          if (var != nullptr) {
-            VLOG(2) << "Erase tensor \'" << name << "\'";
-            if (var->IsType<LoDTensor>()) {
-              erase_tensors.insert(var->GetMutable<LoDTensor>());
-            } else if (var->IsType<SelectedRows>()) {
-              erase_tensors.insert(
-                  var->GetMutable<SelectedRows>()->mutable_value());
-            } else if (var->IsType<LoDTensorArray>()) {
-              auto *lod_tensor_arr = var->GetMutable<LoDTensorArray>();
-              for (auto &t : *lod_tensor_arr) {
-                erase_tensors.insert(&t);
-              }
-            }
-          }
-        }
+        if (--(it->second) != 0) {
+          continue;
+        }
+        auto *var = scope.FindVar(name);
+        if (var == nullptr) {
+          continue;
+        }
+
+        VLOG(2) << "Erase variable " << name;
+        if (var->IsType<LoDTensor>()) {
+          garbages.emplace_back(var->GetMutable<LoDTensor>()->MoveMemory());
+        } else if (var->IsType<SelectedRows>()) {
+          garbages.emplace_back(
+              var->GetMutable<SelectedRows>()->mutable_value()->MoveMemory());
+        } else if (var->IsType<LoDTensorArray>()) {
+          auto *lod_tensor_arr = var->GetMutable<LoDTensorArray>();
+          for (auto &t : *lod_tensor_arr) {
+            garbages.emplace_back(t.MoveMemory());
+          }
+        } else {
+          PADDLE_THROW("Type %s of %s is not supported eager deletion",
+                       var->Type().name(), name);
+        }
       }
     }

@@ -116,8 +123,8 @@ static void DeleteUnusedTensors(
   handler(op->Inputs());
   handler(op->Outputs());

-  if (!erase_tensors.empty()) {
-    gc->Add(erase_tensors);
+  if (!garbages.empty()) {
+    gc->Add(std::move(garbages));
   }
 }

@@ -411,22 +418,22 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
   }

   int64_t max_memory_size = GetEagerDeletionThreshold();
-  std::unique_ptr<GarbageCollector<Tensor>> gc;
+  std::unique_ptr<GarbageCollector> gc;
   if (max_memory_size >= 0) {
     ctx->ResetReferenceCount();
 #ifdef PADDLE_WITH_CUDA
     if (platform::is_gpu_place(place_)) {
       if (IsFastEagerDeletionModeEnabled()) {
-        gc.reset(new UnsafeFastGPUGarbageCollector<Tensor>(
+        gc.reset(new UnsafeFastGPUGarbageCollector(
             boost::get<platform::CUDAPlace>(place_), max_memory_size));
       } else {
-        gc.reset(new DefaultStreamGarbageCollector<Tensor>(
+        gc.reset(new DefaultStreamGarbageCollector(
            boost::get<platform::CUDAPlace>(place_), max_memory_size));
       }
     } else if (platform::is_cpu_place(place_)) {
 #endif
-      gc.reset(new CPUGarbageCollector<Tensor>(
+      gc.reset(new CPUGarbageCollector(
           boost::get<platform::CPUPlace>(place_), max_memory_size));
 #ifdef PADDLE_WITH_CUDA
     }
 #endif

@@ -442,7 +449,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
   }

   platform::DeviceContextPool::Instance().Get(place_)->Wait();
-  if (gc) gc->Wait();

   if (local_scope != scope) {
     scope->DeleteScope(local_scope);
paddle/fluid/framework/garbage_collector.cc (new file)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#include "paddle/fluid/framework/garbage_collector.h"

namespace paddle {
namespace framework {

GarbageCollector::GarbageCollector(const platform::Place &place,
                                   size_t max_memory_size)
    : max_memory_size_((std::max)(max_memory_size, static_cast<size_t>(1))) {
  garbages_.reset(new GarbageQueue());
  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place);
}

CPUGarbageCollector::CPUGarbageCollector(const platform::CPUPlace &place,
                                         size_t max_memory_size)
    : GarbageCollector(place, max_memory_size) {}

void CPUGarbageCollector::ClearCallback(
    const std::function<void()> &callback) {
  callback();
}

#ifdef PADDLE_WITH_CUDA
UnsafeFastGPUGarbageCollector::UnsafeFastGPUGarbageCollector(
    const platform::CUDAPlace &place, size_t max_memory_size)
    : GarbageCollector(place, max_memory_size) {}

void UnsafeFastGPUGarbageCollector::ClearCallback(
    const std::function<void()> &callback) {
  callback();
}

DefaultStreamGarbageCollector::DefaultStreamGarbageCollector(
    const platform::CUDAPlace &place, size_t max_memory_size)
    : GarbageCollector(place, max_memory_size) {}

void DefaultStreamGarbageCollector::Wait() const {
  static_cast<platform::CUDADeviceContext *>(this->dev_ctx_)
      ->WaitStreamCallback();
}

void DefaultStreamGarbageCollector::ClearCallback(
    const std::function<void()> &callback) {
  static_cast<platform::CUDADeviceContext *>(this->dev_ctx_)
      ->AddStreamCallback(callback);
}

StreamGarbageCollector::StreamGarbageCollector(const platform::CUDAPlace &place,
                                               size_t max_memory_size)
    : GarbageCollector(place, max_memory_size) {
  platform::CUDADeviceGuard guard(place.device);
  PADDLE_ENFORCE(cudaStreamCreate(&stream_));
  callback_manager_.reset(new platform::StreamCallbackManager(stream_));
}

StreamGarbageCollector::~StreamGarbageCollector() {
  auto place = boost::get<platform::CUDAPlace>(this->dev_ctx_->GetPlace());
  platform::CUDADeviceGuard guard(place.device);
  PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
  PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}

cudaStream_t StreamGarbageCollector::stream() const { return stream_; }

void StreamGarbageCollector::Wait() const { callback_manager_->Wait(); }

void StreamGarbageCollector::ClearCallback(
    const std::function<void()> &callback) {
  callback_manager_->AddCallback(callback);
}
#endif
}  // namespace framework
}  // namespace paddle
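StreamGarbageCollector above frees memory on its own CUDA stream, so deallocation never blocks computation; EagerDeletionOpHandle::ClearGarbages then orders the two streams with an event. The pattern is standard CUDA: record an event on the compute stream after the last use, make the GC stream wait on it, and enqueue the free there. A hedged sketch of just that ordering trick (CUDA runtime API only; the buffer, kernels, and error checks are omitted):

#include <cuda_runtime.h>

int main() {
  cudaStream_t compute_stream, gc_stream;
  cudaStreamCreate(&compute_stream);
  cudaStreamCreate(&gc_stream);

  cudaEvent_t event;
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);

  // ... launch kernels that use some buffer on compute_stream ...

  cudaEventRecord(event, compute_stream);    // mark "last use finished"
  cudaStreamWaitEvent(gc_stream, event, 0);  // GC stream waits for the mark
  // ... enqueue the callback that releases the buffer on gc_stream ...

  cudaStreamSynchronize(gc_stream);
  cudaEventDestroy(event);
  cudaStreamDestroy(compute_stream);
  cudaStreamDestroy(gc_stream);
  return 0;
}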
paddle/fluid/framework/garbage_collector.h

@@ -14,160 +14,83 @@
 #pragma once

 #include <algorithm>
 #include <deque>
 #include <functional>
 #include <memory>
 #include <mutex>  // NOLINT
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/cuda_device_guard.h"
 #endif
 #include "paddle/fluid/platform/device_context.h"

 namespace paddle {
 namespace framework {

-// T should have memory_size() and clear() method
-template <typename T>
 class GarbageCollector {
  public:
-  GarbageCollector(const platform::Place &place, size_t max_memory_size)
-      : max_memory_size_((std::max)(max_memory_size, static_cast<size_t>(1))) {
-    garbages_.reset(new std::deque<T *>());
-    dev_ctx_ = platform::DeviceContextPool::Instance().Get(place);
-  }
+  using GarbageQueue = std::deque<std::shared_ptr<memory::Allocation>>;

-  virtual ~GarbageCollector() {}
+  GarbageCollector(const platform::Place &place, size_t max_memory_size);

-  size_t NumOfGarbages() const {
-    std::lock_guard<std::mutex> guard(mutex_);
-    return garbages_->size();
-  }
+  virtual ~GarbageCollector() = default;

-  void Reset() {
-    std::lock_guard<std::mutex> guard(mutex_);
-    garbages_.reset(new std::deque<T *>());
-    cur_memory_size_ = 0;
-  }
+  virtual void Wait() const {}

   template <typename Container>
-  void Add(const Container &objs) {
-    Add(objs, []() {});
-  }
+  void Add(Container &&objs);

   template <typename Container, typename Callback>
-  void Add(const Container &objs, Callback &&callback) {
-    std::deque<T *> *clear_deque = nullptr;
-    {
-      std::lock_guard<std::mutex> guard(mutex_);
-      for (auto *obj : objs) {
-        garbages_->push_back(obj);
-        cur_memory_size_ += obj->memory_size();
-      }
-      if (cur_memory_size_ >= max_memory_size_) {
-        cur_memory_size_ = 0;
-        clear_deque = garbages_.release();
-        garbages_.reset(new std::deque<T *>());
-      }
-    }
-
-    if (clear_deque != nullptr) {
-      callback();
-      ClearCallback([clear_deque]() {
-        for (auto *obj : *clear_deque) obj->clear();
-        delete clear_deque;
-      });
-    }
-  }
-
-  virtual void Wait() const {}
+  void Add(Container &&objs, Callback &&callback);

  protected:
   virtual void ClearCallback(const std::function<void()> &callback) = 0;

   platform::DeviceContext *dev_ctx_;
-  std::unique_ptr<std::deque<T *>> garbages_;
+  std::unique_ptr<GarbageQueue> garbages_;
   mutable std::mutex mutex_;
   const size_t max_memory_size_;
-  size_t cur_memory_size_ = 0;
+  size_t cur_memory_size_{0};
 };

-template <typename T>
-class CPUGarbageCollector : public GarbageCollector<T> {
+class CPUGarbageCollector : public GarbageCollector {
  public:
-  CPUGarbageCollector(const platform::CPUPlace &place, size_t max_memory_size)
-      : GarbageCollector<T>(place, max_memory_size) {}
+  CPUGarbageCollector(const platform::CPUPlace &place, size_t max_memory_size);

  protected:
-  void ClearCallback(const std::function<void()> &callback) override {
-    callback();
-  }
+  void ClearCallback(const std::function<void()> &callback) override;
 };

 #ifdef PADDLE_WITH_CUDA
-template <typename T>
-class UnsafeFastGPUGarbageCollector : public GarbageCollector<T> {
+class UnsafeFastGPUGarbageCollector : public GarbageCollector {
  public:
-  UnsafeFastGPUGarbageCollector(const platform::CUDAPlace &place,
-                                size_t max_memory_size)
-      : GarbageCollector<T>(place, max_memory_size) {}
+  UnsafeFastGPUGarbageCollector(const platform::CUDAPlace &place,
+                                size_t max_memory_size);

  protected:
-  void ClearCallback(const std::function<void()> &callback) override {
-    callback();
-  }
+  void ClearCallback(const std::function<void()> &callback) override;
 };

-template <typename T>
-class DefaultStreamGarbageCollector : public GarbageCollector<T> {
+class DefaultStreamGarbageCollector : public GarbageCollector {
  public:
-  DefaultStreamGarbageCollector(const platform::CUDAPlace &place,
-                                size_t max_memory_size)
-      : GarbageCollector<T>(place, max_memory_size) {}
-
-  cudaStream_t stream() const {
-    return static_cast<const platform::CUDADeviceContext *>(this->dev_ctx_)
-        ->stream();
-  }
-
-  void Wait() const override {
-    static_cast<platform::CUDADeviceContext *>(this->dev_ctx_)
-        ->WaitStreamCallback();
-  }
+  DefaultStreamGarbageCollector(const platform::CUDAPlace &place,
+                                size_t max_memory_size);
+
+  void Wait() const override;

  protected:
-  void ClearCallback(const std::function<void()> &callback) override {
-    static_cast<platform::CUDADeviceContext *>(this->dev_ctx_)
-        ->AddStreamCallback(callback);
-  }
+  void ClearCallback(const std::function<void()> &callback) override;
 };

-template <typename T>
-class StreamGarbageCollector : public GarbageCollector<T> {
+class StreamGarbageCollector : public GarbageCollector {
  public:
-  StreamGarbageCollector(const platform::CUDAPlace &place,
-                         size_t max_memory_size)
-      : GarbageCollector<T>(place, max_memory_size) {
-    platform::CUDADeviceGuard guard(place.device);
-    PADDLE_ENFORCE(cudaStreamCreate(&stream_));
-    callback_manager_.reset(new platform::StreamCallbackManager(stream_));
-  }
+  StreamGarbageCollector(const platform::CUDAPlace &place,
+                         size_t max_memory_size);

-  ~StreamGarbageCollector() {
-    auto place = boost::get<platform::CUDAPlace>(this->dev_ctx_->GetPlace());
-    platform::CUDADeviceGuard guard(place.device);
-    PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
-    PADDLE_ENFORCE(cudaStreamDestroy(stream_));
-  }
+  ~StreamGarbageCollector();

-  void Wait() const override { callback_manager_->Wait(); }
+  void Wait() const override;

-  cudaStream_t stream() const { return stream_; }
+  cudaStream_t stream() const;

  protected:
-  void ClearCallback(const std::function<void()> &callback) override {
-    callback_manager_->AddCallback(callback);
-  }
+  void ClearCallback(const std::function<void()> &callback) override;

  private:
   cudaStream_t stream_;

@@ -175,5 +98,33 @@ class StreamGarbageCollector : public GarbageCollector<T> {
 };
 #endif

+template <typename Container>
+void GarbageCollector::Add(Container &&objs) {
+  Add(std::forward<Container>(objs), []() {});
+}
+
+template <typename Container, typename Callback>
+void GarbageCollector::Add(Container &&objs, Callback &&callback) {
+  GarbageQueue *garbage_queue = nullptr;
+  {
+    std::lock_guard<std::mutex> guard(mutex_);
+    for (auto &obj : objs) {
+      if (!obj) continue;
+      cur_memory_size_ += obj->size();
+      garbages_->push_back(std::move(obj));
+    }
+    if (cur_memory_size_ >= max_memory_size_) {
+      cur_memory_size_ = 0;
+      garbage_queue = garbages_.release();
+      garbages_.reset(new GarbageQueue());
+    }
+  }
+
+  if (garbage_queue) {
+    callback();
+    ClearCallback([garbage_queue]() { delete garbage_queue; });
+  }
+}
+
 }  // namespace framework
 }  // namespace paddle
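The header above replaces the old template-per-tensor-type design with a single type-erased queue: garbage is a std::shared_ptr<memory::Allocation>, Add moves holders into the queue under a mutex, and once the accumulated size crosses max_memory_size_ the whole queue is handed to ClearCallback, where dropping it releases the last references. A self-contained mini version of that design, with a toy Blob standing in for memory::Allocation (illustrative, not Paddle code):

#include <algorithm>
#include <cstddef>
#include <deque>
#include <iostream>
#include <memory>
#include <mutex>

// Stand-in for paddle::memory::Allocation: anything exposing size().
struct Blob {
  explicit Blob(std::size_t n) : bytes(n) {}
  std::size_t size() const { return bytes; }
  std::size_t bytes;
};

class MiniGarbageCollector {
 public:
  using GarbageQueue = std::deque<std::shared_ptr<Blob>>;

  explicit MiniGarbageCollector(std::size_t max_memory_size)
      : max_memory_size_(std::max<std::size_t>(max_memory_size, 1)),
        garbages_(new GarbageQueue()) {}

  template <typename Container>
  void Add(Container &&objs) {
    GarbageQueue *full_queue = nullptr;
    {
      std::lock_guard<std::mutex> guard(mutex_);
      for (auto &obj : objs) {
        if (!obj) continue;
        cur_memory_size_ += obj->size();
        garbages_->push_back(std::move(obj));
      }
      if (cur_memory_size_ >= max_memory_size_) {
        cur_memory_size_ = 0;
        full_queue = garbages_.release();
        garbages_.reset(new GarbageQueue());
      }
    }
    // Deleting the detached queue drops the last shared_ptr references;
    // that is the step that actually frees the memory.
    if (full_queue) delete full_queue;
  }

 private:
  const std::size_t max_memory_size_;
  std::size_t cur_memory_size_{0};
  std::mutex mutex_;
  std::unique_ptr<GarbageQueue> garbages_;
};

int main() {
  MiniGarbageCollector gc(/*max_memory_size=*/64);
  MiniGarbageCollector::GarbageQueue batch;
  batch.push_back(std::make_shared<Blob>(48));
  batch.push_back(std::make_shared<Blob>(32));  // 80 >= 64: triggers a flush
  gc.Add(std::move(batch));
  std::cout << "batch flushed" << std::endl;
  return 0;
}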
paddle/fluid/framework/parallel_executor.cc

@@ -97,29 +97,31 @@ std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
     if (gcs_.count(place) > 0) {
       continue;
     }
-    GarbageCollector<Tensor> *gc = nullptr;
+    std::unique_ptr<GarbageCollector> gc;
 #ifdef PADDLE_WITH_CUDA
     if (platform::is_gpu_place(place)) {
       if (IsFastEagerDeletionModeEnabled()) {
-        gc = new UnsafeFastGPUGarbageCollector<Tensor>(
-            boost::get<platform::CUDAPlace>(place), max_memory_size);
+        gc.reset(new UnsafeFastGPUGarbageCollector(
+            boost::get<platform::CUDAPlace>(place), max_memory_size));
       } else {
-        gc = new StreamGarbageCollector<Tensor>(
-            boost::get<platform::CUDAPlace>(place), max_memory_size);
+        gc.reset(new StreamGarbageCollector(
+            boost::get<platform::CUDAPlace>(place), max_memory_size));
       }
       VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
-    } else if (platform::is_cpu_place(place)) {
+    } else {
 #endif
-      gc = new CPUGarbageCollector<Tensor>(
-          boost::get<platform::CPUPlace>(place), max_memory_size);
-      VLOG(10) << "Created GarbageCollector at " << place;
+      if (platform::is_cpu_place(place)) {
+        gc.reset(new CPUGarbageCollector(
+            boost::get<platform::CPUPlace>(place), max_memory_size));
+        VLOG(10) << "Created GarbageCollector at " << place;
+      } else {
+        PADDLE_THROW("Unsupported place for garbage collection");
+      }
 #ifdef PADDLE_WITH_CUDA
     }
 #endif

-    if (gc) {
-      gcs_[place] = std::unique_ptr<GarbageCollector<Tensor>>(gc);
-    }
+    gcs_.emplace(place, std::move(gc));
   }

   if (!gcs_.empty()) {

@@ -144,8 +146,6 @@ std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
     eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_);
     graph = eager_deletion_pass->Apply(std::move(graph));
     VLOG(10) << "EagerDeletionPass Applied";
-
-    graph->SetNotOwned(details::kGarbageCollector, &gcs_);
   }

   return graph;
paddle/fluid/framework/scope.cc

@@ -38,7 +38,7 @@ DEFINE_double(
     "Memory size threshold (GB) when the garbage collector clear tensors."
     "Disabled when this value is less than 0");

-DEFINE_bool(fast_eager_deletion_mode, true,
+DEFINE_bool(fast_eager_deletion_mode, false,
             "Fast eager deletion mode. If enabled, memory would release "
             "immediately without waiting GPU kernel ends.");
paddle/fluid/framework/tensor.h

@@ -158,6 +158,10 @@ class Tensor {
   const std::shared_ptr<memory::Allocation> &Holder() const { return holder_; }
   size_t offset() const { return offset_; }

+  std::shared_ptr<memory::Allocation> MoveMemory() {
+    return std::move(holder_);
+  }
+
  private:
   /*! holds the memory block if allocated. */
   std::shared_ptr<memory::Allocation> holder_;
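The new Tensor::MoveMemory() is the hook everything else relies on: it moves the shared_ptr holder out of the tensor, leaving the tensor empty while the garbage collector keeps the allocation alive until the device has finished with it. A tiny illustrative mock of those move semantics (not the real Tensor class):

#include <cassert>
#include <memory>

// Mock of the relevant part of Tensor: a shared_ptr "holder" that
// MoveMemory() transfers out, leaving the tensor without a buffer.
struct MockTensor {
  std::shared_ptr<int> holder_;
  std::shared_ptr<int> MoveMemory() { return std::move(holder_); }
};

int main() {
  MockTensor t{std::make_shared<int>(42)};
  auto garbage = t.MoveMemory();
  assert(t.holder_ == nullptr);       // the tensor no longer owns the buffer
  assert(garbage && *garbage == 42);  // the GC queue keeps it alive for now
  return 0;
}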
python/paddle/fluid/tests/unittests/test_eager_deletion_gru_net.py (new file)

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from test_eager_deletion_lstm_net import TestBase
import paddle.fluid as fluid


def gru_net(data,
            label,
            dict_dim,
            emb_dim=128,
            hid_dim=128,
            hid_dim2=96,
            class_dim=2,
            emb_lr=400.0):
    emb = fluid.layers.embedding(
        input=data,
        size=[dict_dim, emb_dim],
        param_attr=fluid.ParamAttr(learning_rate=emb_lr))
    fc0 = fluid.layers.fc(input=emb, size=hid_dim * 3)
    gru_h = fluid.layers.dynamic_gru(input=fc0, size=hid_dim, is_reverse=False)
    gru_max = fluid.layers.sequence_pool(input=gru_h, pool_type='max')
    gru_max_tanh = fluid.layers.tanh(gru_max)
    fc1 = fluid.layers.fc(input=gru_max_tanh, size=hid_dim2, act='tanh')
    prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
    cost = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    return avg_cost


class GRUTest(TestBase):
    def setUp(self):
        self.net = gru_net


if __name__ == "__main__":
    unittest.main()
python/paddle/fluid/tests/unittests/test_eager_deletion_lstm_net.py (new file)

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
os.environ['CPU_NUM'] = '2'

import six
import unittest
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid


def train(network, use_cuda, use_parallel_executor, batch_size=32, pass_num=2):
    if use_cuda and not core.is_compiled_with_cuda():
        print('Skip use_cuda=True because Paddle is not compiled with cuda')
        return

    word_dict = paddle.dataset.imdb.word_dict()
    train_reader = paddle.batch(
        paddle.dataset.imdb.train(word_dict), batch_size=batch_size)

    data = fluid.layers.data(
        name="words", shape=[1], dtype="int64", lod_level=1)
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")

    cost = network(data, label, len(word_dict))
    optimizer = fluid.optimizer.Adagrad(learning_rate=0.2)
    optimizer.minimize(cost)

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    feeder = fluid.DataFeeder(feed_list=[data, label], place=place)
    reader = feeder.decorate_reader(
        train_reader, multi_devices=use_parallel_executor)

    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    if use_parallel_executor:
        train_exe = fluid.ParallelExecutor(
            use_cuda=use_cuda, loss_name=cost.name)
        fetch_list = [cost.name]
    else:
        train_exe = exe
        fetch_list = [cost]

    for pass_id in six.moves.xrange(pass_num):
        batch_id = 0
        for data in reader():
            train_exe.run(feed=data,
                          fetch_list=fetch_list if batch_id % 4 == 0 else [])
            batch_id += 1
            if batch_id > 16:
                break


def lstm_net(data,
             label,
             dict_dim,
             emb_dim=128,
             hid_dim=128,
             hid_dim2=96,
             class_dim=2,
             emb_lr=30.0):
    emb = fluid.layers.embedding(
        input=data,
        size=[dict_dim, emb_dim],
        param_attr=fluid.ParamAttr(learning_rate=emb_lr))
    fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
    lstm_h, c = fluid.layers.dynamic_lstm(
        input=fc0, size=hid_dim * 4, is_reverse=False)
    lstm_max = fluid.layers.sequence_pool(input=lstm_h, pool_type='max')
    lstm_max_tanh = fluid.layers.tanh(lstm_max)
    fc1 = fluid.layers.fc(input=lstm_max_tanh, size=hid_dim2, act='tanh')
    prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
    cost = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    return avg_cost


class TestBase(unittest.TestCase):
    def setUp(self):
        self.net = lstm_net

    def test_network(self):
        for use_cuda in [True, False]:
            for use_parallel_executor in [False, True]:
                print('network: {}, use_cuda: {}, use_parallel_executor: {}'.
                      format(self.net.__name__, use_cuda,
                             use_parallel_executor))
                with fluid.program_guard(fluid.Program(), fluid.Program()):
                    with fluid.scope_guard(core.Scope()):
                        train(self.net, use_cuda, use_parallel_executor)


if __name__ == "__main__":
    unittest.main()