Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c47c451a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c47c451a
编写于
12月 03, 2018
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug
上级
096673f6
变更
24
隐藏空白更改
内联
并排
Showing
24 changed file
with
458 addition
and
228 deletion
+458
-228
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/computation_op_handle.cc
paddle/fluid/framework/details/computation_op_handle.cc
+2
-0
paddle/fluid/framework/details/eager_deletion_op_handle.cc
paddle/fluid/framework/details/eager_deletion_op_handle.cc
+12
-11
paddle/fluid/framework/details/eager_deletion_op_handle.h
paddle/fluid/framework/details/eager_deletion_op_handle.h
+1
-7
paddle/fluid/framework/details/eager_deletion_pass.cc
paddle/fluid/framework/details/eager_deletion_pass.cc
+40
-41
paddle/fluid/framework/details/op_graph_view.h
paddle/fluid/framework/details/op_graph_view.h
+28
-1
paddle/fluid/framework/details/reference_count_pass.cc
paddle/fluid/framework/details/reference_count_pass.cc
+116
-9
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
...id/framework/details/scope_buffered_ssa_graph_executor.cc
+15
-6
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
...uid/framework/details/scope_buffered_ssa_graph_executor.h
+2
-0
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+78
-26
paddle/fluid/framework/executor.h
paddle/fluid/framework/executor.h
+12
-39
paddle/fluid/framework/garbage_collector.h
paddle/fluid/framework/garbage_collector.h
+29
-15
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+2
-0
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+10
-3
paddle/fluid/framework/scope.cc
paddle/fluid/framework/scope.cc
+6
-0
paddle/fluid/framework/scope.h
paddle/fluid/framework/scope.h
+1
-0
paddle/fluid/framework/tensor.h
paddle/fluid/framework/tensor.h
+1
-1
paddle/fluid/operators/controlflow/while_op.cc
paddle/fluid/operators/controlflow/while_op.cc
+43
-1
paddle/fluid/operators/reader/ctr_reader.h
paddle/fluid/operators/reader/ctr_reader.h
+6
-6
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+2
-8
paddle/fluid/platform/stream_callback_manager.cc
paddle/fluid/platform/stream_callback_manager.cc
+31
-36
paddle/fluid/platform/stream_callback_manager.h
paddle/fluid/platform/stream_callback_manager.h
+11
-9
paddle/fluid/pybind/tensor_py.h
paddle/fluid/pybind/tensor_py.h
+6
-6
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+3
-2
未找到文件。
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
c47c451a
...
@@ -35,7 +35,7 @@ cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_e
...
@@ -35,7 +35,7 @@ cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_e
cc_library
(
eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows op_handle_base
)
cc_library
(
eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows op_handle_base
)
cc_library
(
eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass
)
cc_library
(
eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass
)
cc_library
(
reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass
)
cc_library
(
reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass
op_graph_view
)
cc_library
(
sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass
)
cc_library
(
sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass
)
cc_library
(
all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass
)
cc_library
(
all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass
)
...
...
paddle/fluid/framework/details/computation_op_handle.cc
浏览文件 @
c47c451a
...
@@ -31,6 +31,8 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
...
@@ -31,6 +31,8 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
void
ComputationOpHandle
::
RunImpl
()
{
void
ComputationOpHandle
::
RunImpl
()
{
WaitInputVarGenerated
(
place_
);
WaitInputVarGenerated
(
place_
);
VLOG
(
10
)
<<
"Run Op"
<<
Name
();
auto
run_func
=
[
this
]()
{
auto
run_func
=
[
this
]()
{
op_
->
Run
(
*
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
(),
place_
);
op_
->
Run
(
*
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
(),
place_
);
};
};
...
...
paddle/fluid/framework/details/eager_deletion_op_handle.cc
浏览文件 @
c47c451a
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -23,28 +24,32 @@ namespace details {
...
@@ -23,28 +24,32 @@ namespace details {
EagerDeletionOpHandle
::
EagerDeletionOpHandle
(
EagerDeletionOpHandle
::
EagerDeletionOpHandle
(
ir
::
Node
*
node
,
const
Scope
*
scope
,
const
platform
::
Place
&
place
,
ir
::
Node
*
node
,
const
Scope
*
scope
,
const
platform
::
Place
&
place
,
const
std
::
vector
<
std
::
string
>
&
var_names
,
GarbageCollector
<
Tensor
>
*
gc
,
const
std
::
unordered_set
<
std
::
string
>
&
var_names
,
AtomicReferenceCountMap
*
ref_cnts
)
GarbageCollector
<
Tensor
>
*
gc
,
AtomicReferenceCountMap
*
ref_cnts
)
:
OpHandleBase
(
node
),
scope_
(
scope
),
gc_
(
gc
),
ref_cnts_
(
ref_cnts
)
{
:
OpHandleBase
(
node
),
scope_
(
scope
),
var_names_
(
var_names
),
gc_
(
gc
),
ref_cnts_
(
ref_cnts
)
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
is_gpu_place
(
place
))
{
if
(
platform
::
is_gpu_place
(
place
))
{
dev_ctx_
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
dev_ctx_
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
if
(
dynamic_cast
<
StreamGarbageCollector
<
Tensor
>
*>
(
gc_
))
{
if
(
dynamic_cast
<
StreamGarbageCollector
<
Tensor
>
*>
(
gc_
))
{
platform
::
SetDeviceId
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place
).
device
);
platform
::
CUDADeviceGuard
guard
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place
).
device
);
PADDLE_ENFORCE
(
cudaEventCreateWithFlags
(
&
event_
,
cudaEventDisableTiming
));
PADDLE_ENFORCE
(
cudaEventCreateWithFlags
(
&
event_
,
cudaEventDisableTiming
));
PADDLE_ENFORCE_NOT_NULL
(
event_
);
}
}
}
}
#endif
#endif
for
(
auto
&
name
:
var_names
)
AddVar
(
name
);
}
}
EagerDeletionOpHandle
::~
EagerDeletionOpHandle
()
{
EagerDeletionOpHandle
::~
EagerDeletionOpHandle
()
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
event_
)
{
if
(
event_
)
{
auto
gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_ctx_
->
GetPlace
());
auto
gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_ctx_
->
GetPlace
());
platform
::
SetDeviceI
d
(
gpu_place
.
device
);
platform
::
CUDADeviceGuard
guar
d
(
gpu_place
.
device
);
PADDLE_ENFORCE
(
cudaEventDestroy
(
event_
));
PADDLE_ENFORCE
(
cudaEventDestroy
(
event_
));
}
}
#endif
#endif
...
@@ -52,10 +57,6 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() {
...
@@ -52,10 +57,6 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() {
std
::
string
EagerDeletionOpHandle
::
Name
()
const
{
return
"eager_deletion"
;
}
std
::
string
EagerDeletionOpHandle
::
Name
()
const
{
return
"eager_deletion"
;
}
void
EagerDeletionOpHandle
::
AddVar
(
const
std
::
string
&
name
)
{
var_names_
.
insert
(
name
);
}
void
EagerDeletionOpHandle
::
RunImpl
()
{
void
EagerDeletionOpHandle
::
RunImpl
()
{
auto
*
exec_scope
=
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
auto
*
exec_scope
=
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
std
::
vector
<
Tensor
*>
tensors
;
std
::
vector
<
Tensor
*>
tensors
;
...
...
paddle/fluid/framework/details/eager_deletion_op_handle.h
浏览文件 @
c47c451a
...
@@ -25,13 +25,11 @@ class Scope;
...
@@ -25,13 +25,11 @@ class Scope;
namespace
details
{
namespace
details
{
class
EagerDeletionPass
;
class
EagerDeletionOpHandle
:
public
OpHandleBase
{
class
EagerDeletionOpHandle
:
public
OpHandleBase
{
public:
public:
EagerDeletionOpHandle
(
ir
::
Node
*
node
,
const
Scope
*
scope
,
EagerDeletionOpHandle
(
ir
::
Node
*
node
,
const
Scope
*
scope
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
const
std
::
vector
<
std
::
string
>
&
var_names
,
const
std
::
unordered_set
<
std
::
string
>
&
var_names
,
GarbageCollector
<
Tensor
>
*
gc
,
GarbageCollector
<
Tensor
>
*
gc
,
AtomicReferenceCountMap
*
ref_cnts
);
AtomicReferenceCountMap
*
ref_cnts
);
...
@@ -45,8 +43,6 @@ class EagerDeletionOpHandle : public OpHandleBase {
...
@@ -45,8 +43,6 @@ class EagerDeletionOpHandle : public OpHandleBase {
private:
private:
void
ClearTensors
(
const
std
::
vector
<
Tensor
*>
&
tensors
);
void
ClearTensors
(
const
std
::
vector
<
Tensor
*>
&
tensors
);
void
AddVar
(
const
std
::
string
&
name
);
const
Scope
*
scope_
;
const
Scope
*
scope_
;
std
::
unordered_set
<
std
::
string
>
var_names_
;
std
::
unordered_set
<
std
::
string
>
var_names_
;
GarbageCollector
<
Tensor
>
*
gc_
;
// not own
GarbageCollector
<
Tensor
>
*
gc_
;
// not own
...
@@ -55,8 +51,6 @@ class EagerDeletionOpHandle : public OpHandleBase {
...
@@ -55,8 +51,6 @@ class EagerDeletionOpHandle : public OpHandleBase {
platform
::
CUDADeviceContext
*
dev_ctx_
{
nullptr
};
platform
::
CUDADeviceContext
*
dev_ctx_
{
nullptr
};
cudaEvent_t
event_
{
nullptr
};
cudaEvent_t
event_
{
nullptr
};
#endif
#endif
friend
class
EagerDeletionPass
;
};
};
}
// namespace details
}
// namespace details
...
...
paddle/fluid/framework/details/eager_deletion_pass.cc
浏览文件 @
c47c451a
...
@@ -26,62 +26,61 @@ namespace paddle {
...
@@ -26,62 +26,61 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
static
void
AddDependencyBetween
(
OpHandleBase
*
in
,
OpHandleBase
*
out
,
ir
::
Graph
*
graph
)
{
auto
it
=
std
::
find_if
(
in
->
Outputs
().
begin
(),
in
->
Outputs
().
end
(),
[](
VarHandleBase
*
var
)
{
return
dynamic_cast
<
DummyVarHandle
*>
(
var
)
!=
nullptr
;
});
if
(
it
!=
in
->
Outputs
().
end
())
{
out
->
AddInput
(
*
it
);
}
else
{
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
in
->
AddOutput
(
dep_var
);
out
->
AddInput
(
dep_var
);
}
// Add leaf node to eager_deletion_node
if
(
out
->
Outputs
().
empty
())
{
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dummy_leaf
);
out
->
AddOutput
(
dummy_leaf
);
}
}
std
::
unique_ptr
<
ir
::
Graph
>
EagerDeletionPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
EagerDeletionPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
);
const
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
);
auto
&
ref_cnts
=
auto
&
ref_cnts
=
Get
<
std
::
vector
<
AtomicReferenceCountMap
>>
(
kCurReferenceCount
);
Get
<
std
::
vector
<
AtomicReferenceCountMap
>>
(
kCurReferenceCount
);
auto
&
last_live_ops
=
Get
<
std
::
vector
<
LastLiveOpsOfVars
>>
(
kLastLiveOpsOfVars
);
const
auto
&
last_live_ops
=
Get
<
std
::
vector
<
LastLiveOpsOfVars
>>
(
kLastLiveOpsOfVars
);
auto
&
gcs
=
Get
<
GarbageCollectorList
>
(
kGarbageCollector
);
auto
&
gcs
=
Get
<
GarbageCollectorList
>
(
kGarbageCollector
);
ref_cnts
=
std
::
vector
<
AtomicReferenceCountMap
>
(
vars
.
size
());
ref_cnts
=
std
::
vector
<
AtomicReferenceCountMap
>
(
vars
.
size
());
std
::
unordered_map
<
ComputationOpHandle
*
,
EagerDeletionOpHandle
*>
op_map
;
std
::
unordered_map
<
ComputationOpHandle
*
,
std
::
unordered_set
<
std
::
string
>>
op_vars_map
;
for
(
auto
&
var_ops_map
:
last_live_ops
)
{
for
(
auto
&
var_ops_map
:
last_live_ops
)
{
for
(
auto
&
var_ops_pair
:
var_ops_map
)
{
for
(
auto
&
var_ops_pair
:
var_ops_map
)
{
const
std
::
string
&
var_name
=
var_ops_pair
.
first
;
const
std
::
string
&
var_name
=
var_ops_pair
.
first
;
for
(
ComputationOpHandle
*
op
:
var_ops_pair
.
second
)
{
for
(
auto
*
op
:
var_ops_pair
.
second
)
{
auto
it
=
op_map
.
find
(
op
);
op_vars_map
[
op
].
insert
(
var_name
);
if
(
it
!=
op_map
.
end
())
{
it
->
second
->
AddVar
(
var_name
);
}
else
{
auto
*
eager_deletion_node
=
graph
->
CreateEmptyNode
(
"eager_deletion"
,
ir
::
Node
::
Type
::
kOperation
);
auto
*
eager_deletion_op
=
new
EagerDeletionOpHandle
(
eager_deletion_node
,
op
->
GetScope
(),
op
->
GetPlace
(),
{
var_name
},
gcs
[
op
->
GetScopeIdx
()].
get
(),
&
(
ref_cnts
[
op
->
GetScopeIdx
()]));
AddDependencyBetween
(
op
,
eager_deletion_op
,
graph
.
get
());
op_map
[
op
]
=
eager_deletion_op
;
}
}
}
}
}
}
}
VLOG
(
10
)
<<
"Create "
<<
op_map
.
size
()
<<
" EagerDeletionOpHandle(s)"
;
for
(
auto
&
pair
:
op_vars_map
)
{
auto
*
op
=
pair
.
first
;
auto
&
var_names
=
pair
.
second
;
auto
*
eager_deletion_node
=
graph
->
CreateEmptyNode
(
"eager_deletion"
,
ir
::
Node
::
Type
::
kOperation
);
auto
*
eager_deletion_op
=
new
EagerDeletionOpHandle
(
eager_deletion_node
,
op
->
GetScope
(),
op
->
GetPlace
(),
std
::
move
(
var_names
),
gcs
[
op
->
GetScopeIdx
()].
get
(),
&
(
ref_cnts
[
op
->
GetScopeIdx
()]));
auto
it
=
std
::
find_if
(
op
->
Outputs
().
begin
(),
op
->
Outputs
().
end
(),
[](
VarHandleBase
*
var
)
{
return
dynamic_cast
<
DummyVarHandle
*>
(
var
)
!=
nullptr
;
});
if
(
it
!=
op
->
Outputs
().
end
())
{
eager_deletion_op
->
AddInput
(
*
it
);
}
else
{
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
op
->
AddOutput
(
dep_var
);
eager_deletion_op
->
AddInput
(
dep_var
);
}
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dummy_leaf
);
eager_deletion_op
->
AddOutput
(
dummy_leaf
);
}
VLOG
(
10
)
<<
"Create "
<<
op_vars_map
.
size
()
<<
" EagerDeletionOpHandle(s)"
;
return
graph
;
return
graph
;
}
}
...
...
paddle/fluid/framework/details/op_graph_view.h
浏览文件 @
c47c451a
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
#pragma once
#pragma once
#include <
memory
>
#include <
queue
>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include <vector>
#include <vector>
...
@@ -34,6 +34,11 @@ class OpGraphView {
...
@@ -34,6 +34,11 @@ class OpGraphView {
bool
HasOp
(
OpHandleBase
*
op
)
const
;
bool
HasOp
(
OpHandleBase
*
op
)
const
;
// Use a visitor to visit all pending ops of op
// Stop when callback returns false
template
<
typename
Callback
>
bool
VisitAllPendingOps
(
OpHandleBase
*
op
,
Callback
&&
callback
)
const
;
private:
private:
void
Build
(
const
std
::
vector
<
OpHandleBase
*>
&
ops
);
void
Build
(
const
std
::
vector
<
OpHandleBase
*>
&
ops
);
void
EnforceHasOp
(
OpHandleBase
*
op
)
const
;
void
EnforceHasOp
(
OpHandleBase
*
op
)
const
;
...
@@ -44,6 +49,28 @@ class OpGraphView {
...
@@ -44,6 +49,28 @@ class OpGraphView {
pending_ops_
;
pending_ops_
;
};
};
template
<
typename
Callback
>
bool
OpGraphView
::
VisitAllPendingOps
(
OpHandleBase
*
op
,
Callback
&&
callback
)
const
{
EnforceHasOp
(
op
);
std
::
unordered_set
<
OpHandleBase
*>
visited
;
std
::
queue
<
OpHandleBase
*>
q
;
q
.
push
(
op
);
do
{
op
=
q
.
front
();
q
.
pop
();
for
(
auto
&
pending_op
:
pending_ops_
.
at
(
op
))
{
if
(
visited
.
count
(
pending_op
)
==
0
)
{
visited
.
insert
(
pending_op
);
if
(
!
callback
(
pending_op
))
{
return
false
;
}
}
}
}
while
(
!
q
.
empty
());
return
true
;
}
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/details/reference_count_pass.cc
浏览文件 @
c47c451a
...
@@ -14,11 +14,13 @@
...
@@ -14,11 +14,13 @@
#include <queue>
#include <queue>
#include <string>
#include <string>
#include <type_traits>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
...
@@ -27,6 +29,89 @@ namespace paddle {
...
@@ -27,6 +29,89 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
struct
OpConnectionDetector
{
public:
enum
RelationShip
{
kSame
=
0
,
kNoDeps
=
1
,
kBefore
=
2
,
kAfter
=
3
};
explicit
OpConnectionDetector
(
const
std
::
vector
<
OpHandleBase
*>
&
all_ops
)
:
graph_
(
all_ops
)
{}
template
<
typename
OpSet
>
std
::
unordered_set
<
typename
OpSet
::
key_type
>
MaxNoDepOps
(
const
OpSet
&
op_set
)
{
using
KeyType
=
typename
OpSet
::
key_type
;
static_assert
(
std
::
is_base_of
<
OpHandleBase
,
typename
std
::
remove_pointer
<
KeyType
>::
type
>::
value
,
"Key type of OpSet must be or derived of OpHandleBase"
);
std
::
vector
<
OpHandleBase
*>
ops
(
op_set
.
begin
(),
op_set
.
end
());
std
::
unordered_set
<
KeyType
>
ret
;
auto
rels
=
GetRelations
(
ops
);
auto
not_before
=
[](
RelationShip
r
)
{
return
r
!=
kBefore
;
};
for
(
size_t
i
=
0
;
i
<
rels
.
size
();
++
i
)
{
if
(
std
::
all_of
(
rels
[
i
].
begin
(),
rels
[
i
].
end
(),
not_before
))
{
ret
.
insert
(
static_cast
<
KeyType
>
(
ops
[
i
]));
}
}
return
ret
;
}
private:
std
::
vector
<
std
::
vector
<
RelationShip
>>
GetRelations
(
const
std
::
vector
<
OpHandleBase
*>
ops
)
{
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
op_to_idx
;
for
(
size_t
i
=
0
;
i
<
ops
.
size
();
++
i
)
{
PADDLE_ENFORCE
(
graph_
.
HasOp
(
ops
[
i
]),
"Op does not exist in graph"
);
op_to_idx
[
ops
[
i
]]
=
i
;
}
PADDLE_ENFORCE
(
op_to_idx
.
size
()
==
ops
.
size
(),
"Duplicate ops"
);
std
::
vector
<
std
::
vector
<
RelationShip
>>
ret
(
ops
.
size
());
for
(
auto
&
e
:
ret
)
{
e
.
assign
(
ops
.
size
(),
kSame
);
}
size_t
found_num
=
ops
.
size
();
size_t
total_num
=
ops
.
size
()
*
ops
.
size
();
auto
visitor
=
[
&
](
OpHandleBase
*
op
,
size_t
i
)
{
auto
it
=
op_to_idx
.
find
(
op
);
if
(
it
!=
op_to_idx
.
end
())
{
size_t
j
=
it
->
second
;
if
(
ret
[
i
][
j
]
!=
kSame
)
{
ret
[
i
][
j
]
=
kBefore
;
ret
[
j
][
i
]
=
kAfter
;
found_num
+=
2
;
if
(
found_num
==
total_num
)
{
return
false
;
}
}
}
return
true
;
};
for
(
size_t
i
=
0
;
i
<
ops
.
size
();
++
i
)
{
auto
sub_visitor
=
[
&
,
i
](
OpHandleBase
*
op
)
{
return
visitor
(
op
,
i
);
};
if
(
!
graph_
.
VisitAllPendingOps
(
ops
[
i
],
sub_visitor
))
{
break
;
}
}
for
(
size_t
i
=
0
;
i
<
ops
.
size
();
++
i
)
{
for
(
size_t
j
=
i
+
1
;
j
<
ops
.
size
();
++
j
)
{
if
(
ret
[
i
][
j
]
!=
kSame
)
continue
;
ret
[
i
][
j
]
=
kNoDeps
;
ret
[
j
][
i
]
=
kNoDeps
;
}
}
return
ret
;
}
const
OpGraphView
graph_
;
};
static
ComputationOpHandle
*
FindNextComputationOpHandleOrReturnItself
(
static
ComputationOpHandle
*
FindNextComputationOpHandleOrReturnItself
(
OpHandleBase
*
op
,
size_t
scope_idx
)
{
OpHandleBase
*
op
,
size_t
scope_idx
)
{
std
::
queue
<
OpHandleBase
*>
q
;
std
::
queue
<
OpHandleBase
*>
q
;
...
@@ -59,9 +144,15 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
...
@@ -59,9 +144,15 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
last_live_ops_of_vars
=
std
::
vector
<
LastLiveOpsOfVars
>
(
vars
.
size
());
last_live_ops_of_vars
=
std
::
vector
<
LastLiveOpsOfVars
>
(
vars
.
size
());
ref_cnts
=
std
::
vector
<
ReferenceCountMap
>
(
vars
.
size
());
ref_cnts
=
std
::
vector
<
ReferenceCountMap
>
(
vars
.
size
());
OpConnectionDetector
detector
(
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
*
graph
));
for
(
size_t
i
=
0
;
i
<
vars
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
vars
.
size
();
++
i
)
{
for
(
auto
&
name_var_pair
:
vars
[
i
])
{
for
(
auto
&
name_var_pair
:
vars
[
i
])
{
if
(
name_var_pair
.
second
.
empty
())
continue
;
if
(
name_var_pair
.
second
.
empty
())
{
continue
;
}
const
std
::
string
&
var_name
=
name_var_pair
.
first
;
auto
*
last_ver_var
=
name_var_pair
.
second
.
back
();
auto
*
last_ver_var
=
name_var_pair
.
second
.
back
();
VarDesc
*
var_desc
=
nullptr
;
VarDesc
*
var_desc
=
nullptr
;
...
@@ -83,30 +174,46 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
...
@@ -83,30 +174,46 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
}
}
std
::
unordered_set
<
ComputationOpHandle
*>
last_live_op
;
std
::
unordered_set
<
ComputationOpHandle
*>
last_live_op
;
auto
add_last_live_op
=
[
&
](
OpHandleBase
*
op
)
{
auto
add_last_live_op
=
[
&
](
OpHandleBase
*
op
)
->
bool
{
auto
*
compute_op
=
FindNextComputationOpHandleOrReturnItself
(
op
,
i
);
auto
*
compute_op
=
FindNextComputationOpHandleOrReturnItself
(
op
,
i
);
if
(
compute_op
)
{
if
(
compute_op
)
{
last_live_op
.
insert
(
compute_op
);
last_live_op
.
insert
(
compute_op
);
return
true
;
}
else
{
return
false
;
}
}
};
};
const
std
::
string
&
var_name
=
name_var_pair
.
first
;
bool
can_delete
=
false
;
auto
&
pending_ops
=
last_ver_var
->
PendingOps
();
auto
&
pending_ops
=
last_ver_var
->
PendingOps
();
if
(
pending_ops
.
empty
())
{
if
(
pending_ops
.
empty
())
{
auto
*
generated_op
=
last_ver_var
->
GeneratedOp
();
auto
*
generated_op
=
last_ver_var
->
GeneratedOp
();
if
(
generated_op
)
{
if
(
generated_op
&&
add_last_live_op
(
generated_op
))
{
ref_cnts
[
i
].
emplace
(
var_name
,
1
);
can_delete
=
true
;
add_last_live_op
(
generated_op
);
}
}
}
else
{
}
else
{
ref_cnts
[
i
].
emplace
(
var_name
,
pending_ops
.
size
())
;
can_delete
=
true
;
for
(
auto
*
pending_op
:
pending_ops
)
{
for
(
auto
*
pending_op
:
pending_ops
)
{
add_last_live_op
(
pending_op
);
if
(
!
add_last_live_op
(
pending_op
))
{
can_delete
=
false
;
break
;
}
}
}
}
}
last_live_ops_of_vars
[
i
].
emplace
(
var_name
,
std
::
move
(
last_live_op
));
if
(
can_delete
)
{
size_t
original_size
=
last_live_op
.
size
();
last_live_op
=
detector
.
MaxNoDepOps
(
last_live_op
);
if
(
last_live_op
.
size
()
!=
original_size
)
{
VLOG
(
10
)
<<
"Shrink last living op number of "
<<
var_name
<<
" from "
<<
original_size
<<
" to "
<<
last_live_op
.
size
();
}
ref_cnts
[
i
].
emplace
(
var_name
,
last_live_op
.
size
());
last_live_ops_of_vars
[
i
].
emplace
(
var_name
,
std
::
move
(
last_live_op
));
}
}
}
}
}
return
graph
;
return
graph
;
}
}
...
...
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
浏览文件 @
c47c451a
...
@@ -36,6 +36,15 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
...
@@ -36,6 +36,15 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
}
}
}
}
void
ScopeBufferedSSAGraphExecutor
::
WaitAllGarbageCollectors
()
{
if
(
gc_
)
{
for
(
auto
&
gc
:
*
gc_
)
{
gc
->
Wait
();
gc
->
Reset
();
}
}
}
FeedFetchList
ScopeBufferedSSAGraphExecutor
::
Run
(
FeedFetchList
ScopeBufferedSSAGraphExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
{
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
{
if
(
drop_scope_counter_
==
0
)
{
if
(
drop_scope_counter_
==
0
)
{
...
@@ -74,19 +83,19 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
...
@@ -74,19 +83,19 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
drop_scope_counter_
==
strategy_
.
num_iteration_per_drop_scope_
)
{
drop_scope_counter_
==
strategy_
.
num_iteration_per_drop_scope_
)
{
drop_scope_counter_
=
0
;
drop_scope_counter_
=
0
;
// Wait All computational streams
// Wait All computational streams
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
auto
&
p
:
places_
)
{
platform
::
DeviceContextPool
::
Instance
().
Get
(
places_
[
i
])
->
Wait
();
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
)
->
Wait
();
if
(
gc_
)
{
(
*
gc_
)[
i
]
->
Wait
();
(
*
gc_
)[
i
]
->
Reset
();
}
}
}
WaitAllGarbageCollectors
();
for
(
auto
&
scope
:
local_scopes_
)
{
for
(
auto
&
scope
:
local_scopes_
)
{
auto
&
local_scope
=
auto
&
local_scope
=
*
scope
->
Var
(
details
::
kLocalExecScopeName
)
->
GetMutable
<
Scope
*>
();
*
scope
->
Var
(
details
::
kLocalExecScopeName
)
->
GetMutable
<
Scope
*>
();
scope
->
DeleteScope
(
local_scope
);
scope
->
DeleteScope
(
local_scope
);
}
}
}
else
{
WaitAllGarbageCollectors
();
}
}
if
(
eptr
)
{
if
(
eptr
)
{
std
::
rethrow_exception
(
eptr
);
std
::
rethrow_exception
(
eptr
);
}
else
{
}
else
{
...
...
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
浏览文件 @
c47c451a
...
@@ -50,6 +50,8 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -50,6 +50,8 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>&
fetch_tensors
)
override
;
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>&
fetch_tensors
)
override
;
private:
private:
void
WaitAllGarbageCollectors
();
size_t
drop_scope_counter_
{
0
};
size_t
drop_scope_counter_
{
0
};
ExecutionStrategy
strategy_
;
ExecutionStrategy
strategy_
;
...
...
paddle/fluid/framework/executor.cc
浏览文件 @
c47c451a
...
@@ -37,11 +37,49 @@ namespace {
...
@@ -37,11 +37,49 @@ namespace {
int
kProgramId
=
-
1
;
int
kProgramId
=
-
1
;
}
// namespace
}
// namespace
static
std
::
unordered_map
<
std
::
string
,
size_t
>
GetNonPersistableReferenceCounts
(
const
BlockDesc
&
block
,
const
std
::
vector
<
std
::
string
>&
skip_var_list
)
{
std
::
unordered_map
<
std
::
string
,
size_t
>
ref_cnts
;
std
::
unordered_set
<
std
::
string
>
skip_vars
(
skip_var_list
.
begin
(),
skip_var_list
.
end
());
auto
update_ref_cnts
=
[
&
](
OpDesc
*
op_desc
,
const
VariableNameMap
&
name_map
)
{
for
(
auto
&
name_pair
:
name_map
)
{
for
(
auto
&
name
:
name_pair
.
second
)
{
if
(
skip_vars
.
count
(
name
))
continue
;
auto
*
var_desc
=
block
.
FindVar
(
name
);
if
(
var_desc
==
nullptr
||
var_desc
->
Persistable
())
continue
;
auto
type
=
var_desc
->
Proto
()
->
type
().
type
();
if
(
type
!=
proto
::
VarType
::
LOD_TENSOR
&&
type
!=
proto
::
VarType
::
SELECTED_ROWS
&&
type
!=
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
{
continue
;
}
auto
it
=
ref_cnts
.
find
(
name
);
if
(
it
!=
ref_cnts
.
end
())
{
++
it
->
second
;
}
else
{
ref_cnts
[
name
]
=
1
;
}
}
}
};
for
(
auto
op_desc
:
block
.
AllOps
())
{
update_ref_cnts
(
op_desc
,
op_desc
->
Inputs
());
update_ref_cnts
(
op_desc
,
op_desc
->
Outputs
());
}
return
ref_cnts
;
}
ExecutorPrepareContext
::
ExecutorPrepareContext
(
ExecutorPrepareContext
::
ExecutorPrepareContext
(
const
framework
::
ProgramDesc
&
prog
,
size_t
block_id
)
const
framework
::
ProgramDesc
&
prog
,
size_t
block_id
,
const
std
::
vector
<
std
::
string
>&
skip_ref_cnt_vars
)
:
prog_
(
prog
),
block_id_
(
block_id
)
{
:
prog_
(
prog
),
block_id_
(
block_id
)
{
if
(
GetEagerDeletionThreshold
()
>=
0
)
{
if
(
GetEagerDeletionThreshold
()
>=
0
)
{
ref_cnts_
=
GetNonPersistableReferenceCount
<
int
>
(
prog_
,
block_id_
);
ref_cnts_
=
GetNonPersistableReferenceCounts
(
prog
.
Block
(
block_id
),
skip_ref_cnt_vars
);
}
}
}
}
...
@@ -49,10 +87,9 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
...
@@ -49,10 +87,9 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
VLOG
(
5
)
<<
"destroy ExecutorPrepareContext"
;
VLOG
(
5
)
<<
"destroy ExecutorPrepareContext"
;
}
}
template
<
typename
RefCntMap
>
static
void
DeleteUnusedTensors
(
static
void
DeleteUnusedTensors
(
const
Scope
&
scope
,
const
OperatorBase
*
op
,
const
Scope
&
scope
,
const
OperatorBase
*
op
,
GarbageCollector
<
Tensor
>*
gc
,
GarbageCollector
<
Tensor
>*
gc
,
std
::
unordered_map
<
std
::
string
,
size_t
>*
ref_cnts
)
{
RefCntMap
*
ref_cnts
)
{
std
::
unordered_set
<
Tensor
*>
erase_tensors
;
std
::
unordered_set
<
Tensor
*>
erase_tensors
;
auto
handler
=
[
&
](
const
VariableNameMap
&
name_map
)
{
auto
handler
=
[
&
](
const
VariableNameMap
&
name_map
)
{
...
@@ -60,7 +97,7 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
...
@@ -60,7 +97,7 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
for
(
auto
&
name
:
name_pair
.
second
)
{
for
(
auto
&
name
:
name_pair
.
second
)
{
auto
it
=
ref_cnts
->
find
(
name
);
auto
it
=
ref_cnts
->
find
(
name
);
if
(
it
==
ref_cnts
->
end
())
continue
;
if
(
it
==
ref_cnts
->
end
())
continue
;
if
(
(
it
->
second
)
--
==
1
)
{
if
(
--
(
it
->
second
)
==
0
)
{
auto
*
var
=
scope
.
FindVar
(
name
);
auto
*
var
=
scope
.
FindVar
(
name
);
if
(
var
!=
nullptr
)
{
if
(
var
!=
nullptr
)
{
VLOG
(
10
)
<<
"Erase tensor
\'
"
<<
name
<<
"
\'
"
;
VLOG
(
10
)
<<
"Erase tensor
\'
"
<<
name
<<
"
\'
"
;
...
@@ -69,6 +106,11 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
...
@@ -69,6 +106,11 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
erase_tensors
.
insert
(
erase_tensors
.
insert
(
var
->
GetMutable
<
SelectedRows
>
()
->
mutable_value
());
var
->
GetMutable
<
SelectedRows
>
()
->
mutable_value
());
}
else
if
(
var
->
IsType
<
LoDTensorArray
>
())
{
auto
*
lod_tensor_arr
=
var
->
GetMutable
<
LoDTensorArray
>
();
for
(
auto
&
t
:
*
lod_tensor_arr
)
{
erase_tensors
.
insert
(
&
t
);
}
}
}
}
}
}
}
...
@@ -351,9 +393,10 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
...
@@ -351,9 +393,10 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
}
std
::
unique_ptr
<
ExecutorPrepareContext
>
Executor
::
Prepare
(
std
::
unique_ptr
<
ExecutorPrepareContext
>
Executor
::
Prepare
(
const
ProgramDesc
&
program
,
int
block_id
)
{
const
ProgramDesc
&
program
,
int
block_id
,
const
std
::
vector
<
std
::
string
>&
skip_ref_cnt_vars
)
{
std
::
unique_ptr
<
ExecutorPrepareContext
>
ctx
(
std
::
unique_ptr
<
ExecutorPrepareContext
>
ctx
(
new
ExecutorPrepareContext
(
program
,
block_id
));
new
ExecutorPrepareContext
(
program
,
block_id
,
skip_ref_cnt_vars
));
PADDLE_ENFORCE_LT
(
static_cast
<
size_t
>
(
block_id
),
program
.
Size
());
PADDLE_ENFORCE_LT
(
static_cast
<
size_t
>
(
block_id
),
program
.
Size
());
auto
&
block
=
program
.
Block
(
block_id
);
auto
&
block
=
program
.
Block
(
block_id
);
for
(
auto
&
op_desc
:
block
.
AllOps
())
{
for
(
auto
&
op_desc
:
block
.
AllOps
())
{
...
@@ -364,16 +407,28 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
...
@@ -364,16 +407,28 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
}
}
std
::
vector
<
std
::
shared_ptr
<
ExecutorPrepareContext
>>
Executor
::
Prepare
(
std
::
vector
<
std
::
shared_ptr
<
ExecutorPrepareContext
>>
Executor
::
Prepare
(
const
ProgramDesc
&
program
,
const
std
::
vector
<
int
>&
block_ids
)
{
const
ProgramDesc
&
program
,
const
std
::
vector
<
int
>&
block_ids
,
const
std
::
vector
<
std
::
vector
<
std
::
string
>>&
skip_ref_cnt_vars
)
{
PADDLE_ENFORCE
(
skip_ref_cnt_vars
.
empty
()
||
skip_ref_cnt_vars
.
size
()
==
block_ids
.
size
(),
"skip_ref_cnt_vars should be either empty or equals to block number %d"
,
block_ids
.
size
());
std
::
vector
<
std
::
shared_ptr
<
ExecutorPrepareContext
>>
result
;
std
::
vector
<
std
::
shared_ptr
<
ExecutorPrepareContext
>>
result
;
size_t
idx
=
0
;
for
(
auto
&
bid
:
block_ids
)
{
for
(
auto
&
bid
:
block_ids
)
{
auto
*
ctx
=
new
ExecutorPrepareContext
(
program
,
bid
);
ExecutorPrepareContext
*
ctx
;
if
(
skip_ref_cnt_vars
.
empty
())
{
ctx
=
new
ExecutorPrepareContext
(
program
,
bid
);
}
else
{
ctx
=
new
ExecutorPrepareContext
(
program
,
bid
,
skip_ref_cnt_vars
[
idx
]);
}
PADDLE_ENFORCE_LT
(
static_cast
<
size_t
>
(
bid
),
program
.
Size
());
PADDLE_ENFORCE_LT
(
static_cast
<
size_t
>
(
bid
),
program
.
Size
());
auto
&
block
=
program
.
Block
(
bid
);
auto
&
block
=
program
.
Block
(
bid
);
for
(
auto
&
op_desc
:
block
.
AllOps
())
{
for
(
auto
&
op_desc
:
block
.
AllOps
())
{
ctx
->
ops_
.
push_back
(
OpRegistry
::
CreateOp
(
*
op_desc
));
ctx
->
ops_
.
push_back
(
OpRegistry
::
CreateOp
(
*
op_desc
));
}
}
result
.
push_back
(
std
::
shared_ptr
<
ExecutorPrepareContext
>
(
ctx
));
result
.
push_back
(
std
::
shared_ptr
<
ExecutorPrepareContext
>
(
ctx
));
++
idx
;
}
}
return
result
;
return
result
;
}
}
...
@@ -392,18 +447,18 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
...
@@ -392,18 +447,18 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
int64_t
max_memory_size
=
GetEagerDeletionThreshold
();
int64_t
max_memory_size
=
GetEagerDeletionThreshold
();
std
::
unique_ptr
<
GarbageCollector
<
Tensor
>>
gc
;
std
::
unique_ptr
<
GarbageCollector
<
Tensor
>>
gc
;
// WhileOp would set keep_kids to true,
if
(
max_memory_size
>=
0
)
{
// because WhileGradOp needs the scopes created in WhileOp.
// Perhaps, we should not perform eager deletion in WhileOp
// The scopes and variables created by WhileOp would be deleted
// in WhileGradOp.
if
(
max_memory_size
>=
0
&&
!
keep_kids
)
{
ctx
->
ResetReferenceCount
();
ctx
->
ResetReferenceCount
();
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
is_gpu_place
(
place_
))
{
if
(
platform
::
is_gpu_place
(
place_
))
{
gc
.
reset
(
new
DefaultStreamGarbageCollector
<
Tensor
>
(
if
(
IsFastEagerDeletionModeEnabled
())
{
boost
::
get
<
platform
::
CUDAPlace
>
(
place_
),
max_memory_size
));
gc
.
reset
(
new
UnsafeFastGPUGarbageCollector
<
Tensor
>
(
}
else
{
boost
::
get
<
platform
::
CUDAPlace
>
(
place_
),
max_memory_size
));
}
else
{
gc
.
reset
(
new
DefaultStreamGarbageCollector
<
Tensor
>
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place_
),
max_memory_size
));
}
}
else
if
(
platform
::
is_cpu_place
(
place_
))
{
#endif
#endif
gc
.
reset
(
new
CPUGarbageCollector
<
Tensor
>
(
gc
.
reset
(
new
CPUGarbageCollector
<
Tensor
>
(
boost
::
get
<
platform
::
CPUPlace
>
(
place_
),
max_memory_size
));
boost
::
get
<
platform
::
CPUPlace
>
(
place_
),
max_memory_size
));
...
@@ -415,17 +470,14 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
...
@@ -415,17 +470,14 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
for
(
auto
&
op
:
ctx
->
ops_
)
{
for
(
auto
&
op
:
ctx
->
ops_
)
{
op
->
Run
(
*
local_scope
,
place_
);
op
->
Run
(
*
local_scope
,
place_
);
if
(
gc
!=
nullptr
)
{
if
(
gc
)
{
DeleteUnusedTensors
(
*
local_scope
,
op
.
get
(),
gc
.
get
(),
DeleteUnusedTensors
(
*
local_scope
,
op
.
get
(),
gc
.
get
(),
&
(
ctx
->
cur_ref_cnts_
));
&
(
ctx
->
cur_ref_cnts_
));
}
}
}
}
if
(
gc
!=
nullptr
)
{
platform
::
DeviceContextPool
::
Instance
().
Get
(
place_
)
->
Wait
();
gc
->
Wait
();
if
(
gc
)
gc
->
Wait
();
}
else
{
platform
::
DeviceContextPool
::
Instance
().
Get
(
place_
)
->
Wait
();
}
if
(
local_scope
!=
scope
)
{
if
(
local_scope
!=
scope
)
{
scope
->
DeleteScope
(
local_scope
);
scope
->
DeleteScope
(
local_scope
);
...
...
paddle/fluid/framework/executor.h
浏览文件 @
c47c451a
...
@@ -28,42 +28,11 @@ namespace paddle {
...
@@ -28,42 +28,11 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
extern
void
InitializeVariable
(
Variable
*
var
,
proto
::
VarType
::
Type
var_type
);
extern
void
InitializeVariable
(
Variable
*
var
,
proto
::
VarType
::
Type
var_type
);
template
<
typename
T
>
std
::
unordered_map
<
std
::
string
,
T
>
GetNonPersistableReferenceCount
(
const
ProgramDesc
&
prog
,
size_t
block_id
)
{
auto
&
block
=
prog
.
Block
(
block_id
);
std
::
unordered_map
<
std
::
string
,
T
>
ref_cnts
;
auto
update_ref_cnts
=
[
&
](
OpDesc
*
op_desc
,
const
VariableNameMap
&
name_map
)
{
for
(
auto
&
name_pair
:
name_map
)
{
for
(
auto
&
name
:
name_pair
.
second
)
{
auto
*
var_desc
=
block
.
FindVar
(
name
);
if
(
var_desc
==
nullptr
||
var_desc
->
Persistable
())
continue
;
auto
type
=
var_desc
->
Proto
()
->
type
().
type
();
if
(
type
!=
proto
::
VarType
::
LOD_TENSOR
&&
type
!=
proto
::
VarType
::
SELECTED_ROWS
)
{
continue
;
}
auto
it
=
ref_cnts
.
find
(
name
);
if
(
it
!=
ref_cnts
.
end
())
{
++
it
->
second
;
}
else
{
ref_cnts
[
name
]
=
1
;
}
}
}
};
for
(
auto
op_desc
:
block
.
AllOps
())
{
update_ref_cnts
(
op_desc
,
op_desc
->
Inputs
());
update_ref_cnts
(
op_desc
,
op_desc
->
Outputs
());
}
return
ref_cnts
;
}
struct
ExecutorPrepareContext
{
struct
ExecutorPrepareContext
{
ExecutorPrepareContext
(
const
framework
::
ProgramDesc
&
prog
,
size_t
block_id
);
ExecutorPrepareContext
(
const
framework
::
ProgramDesc
&
prog
,
size_t
block_id
,
const
std
::
vector
<
std
::
string
>&
skip_ref_cnt_vars
=
std
::
vector
<
std
::
string
>
());
~
ExecutorPrepareContext
();
~
ExecutorPrepareContext
();
void
ResetReferenceCount
()
{
cur_ref_cnts_
=
ref_cnts_
;
}
void
ResetReferenceCount
()
{
cur_ref_cnts_
=
ref_cnts_
;
}
...
@@ -72,8 +41,8 @@ struct ExecutorPrepareContext {
...
@@ -72,8 +41,8 @@ struct ExecutorPrepareContext {
size_t
block_id_
;
size_t
block_id_
;
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>
ops_
;
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>
ops_
;
std
::
unordered_map
<
std
::
string
,
in
t
>
ref_cnts_
;
std
::
unordered_map
<
std
::
string
,
size_
t
>
ref_cnts_
;
std
::
unordered_map
<
std
::
string
,
in
t
>
cur_ref_cnts_
;
std
::
unordered_map
<
std
::
string
,
size_
t
>
cur_ref_cnts_
;
};
};
class
Executor
{
class
Executor
{
...
@@ -109,10 +78,14 @@ class Executor {
...
@@ -109,10 +78,14 @@ class Executor {
const
std
::
string
&
fetch_holder_name
=
"fetch"
);
const
std
::
string
&
fetch_holder_name
=
"fetch"
);
static
std
::
unique_ptr
<
ExecutorPrepareContext
>
Prepare
(
static
std
::
unique_ptr
<
ExecutorPrepareContext
>
Prepare
(
const
ProgramDesc
&
program
,
int
block_id
);
const
ProgramDesc
&
program
,
int
block_id
,
const
std
::
vector
<
std
::
string
>&
skip_ref_cnt_vars
=
std
::
vector
<
std
::
string
>
());
static
std
::
vector
<
std
::
shared_ptr
<
ExecutorPrepareContext
>>
Prepare
(
static
std
::
vector
<
std
::
shared_ptr
<
ExecutorPrepareContext
>>
Prepare
(
const
ProgramDesc
&
program
,
const
std
::
vector
<
int
>&
block_ids
);
const
ProgramDesc
&
program
,
const
std
::
vector
<
int
>&
block_ids
,
const
std
::
vector
<
std
::
vector
<
std
::
string
>>&
skip_ref_cnt_vars
=
std
::
vector
<
std
::
vector
<
std
::
string
>>
());
void
CreateVariables
(
const
ProgramDesc
&
pdesc
,
Scope
*
scope
,
int
block_id
);
void
CreateVariables
(
const
ProgramDesc
&
pdesc
,
Scope
*
scope
,
int
block_id
);
...
...
paddle/fluid/framework/garbage_collector.h
浏览文件 @
c47c451a
...
@@ -19,6 +19,9 @@
...
@@ -19,6 +19,9 @@
#include <functional>
#include <functional>
#include <memory>
#include <memory>
#include <mutex> // NOLINT
#include <mutex> // NOLINT
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -36,6 +39,11 @@ class GarbageCollector {
...
@@ -36,6 +39,11 @@ class GarbageCollector {
virtual
~
GarbageCollector
()
{}
virtual
~
GarbageCollector
()
{}
size_t
NumOfGarbages
()
const
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mutex_
);
return
garbages_
->
size
();
}
void
Reset
()
{
void
Reset
()
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
guard
(
mutex_
);
garbages_
.
reset
(
new
std
::
deque
<
T
*>
());
garbages_
.
reset
(
new
std
::
deque
<
T
*>
());
...
@@ -49,7 +57,7 @@ class GarbageCollector {
...
@@ -49,7 +57,7 @@ class GarbageCollector {
template
<
typename
Container
,
typename
Callback
>
template
<
typename
Container
,
typename
Callback
>
void
Add
(
const
Container
&
objs
,
Callback
&&
callback
)
{
void
Add
(
const
Container
&
objs
,
Callback
&&
callback
)
{
std
::
shared_ptr
<
std
::
deque
<
T
*>>
clear_deque
;
std
::
deque
<
T
*>
*
clear_deque
=
nullptr
;
{
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
guard
(
mutex_
);
for
(
auto
*
obj
:
objs
)
{
for
(
auto
*
obj
:
objs
)
{
...
@@ -58,7 +66,7 @@ class GarbageCollector {
...
@@ -58,7 +66,7 @@ class GarbageCollector {
}
}
if
(
cur_memory_size_
>=
max_memory_size_
)
{
if
(
cur_memory_size_
>=
max_memory_size_
)
{
cur_memory_size_
=
0
;
cur_memory_size_
=
0
;
clear_deque
=
garbages_
;
clear_deque
=
garbages_
.
release
()
;
garbages_
.
reset
(
new
std
::
deque
<
T
*>
());
garbages_
.
reset
(
new
std
::
deque
<
T
*>
());
}
}
}
}
...
@@ -67,6 +75,7 @@ class GarbageCollector {
...
@@ -67,6 +75,7 @@ class GarbageCollector {
callback
();
callback
();
ClearCallback
([
clear_deque
]()
{
ClearCallback
([
clear_deque
]()
{
for
(
auto
*
obj
:
*
clear_deque
)
obj
->
clear
();
for
(
auto
*
obj
:
*
clear_deque
)
obj
->
clear
();
delete
clear_deque
;
});
});
}
}
}
}
...
@@ -77,7 +86,7 @@ class GarbageCollector {
...
@@ -77,7 +86,7 @@ class GarbageCollector {
virtual
void
ClearCallback
(
const
std
::
function
<
void
()
>
&
callback
)
=
0
;
virtual
void
ClearCallback
(
const
std
::
function
<
void
()
>
&
callback
)
=
0
;
platform
::
DeviceContext
*
dev_ctx_
;
platform
::
DeviceContext
*
dev_ctx_
;
std
::
shared
_ptr
<
std
::
deque
<
T
*>>
garbages_
;
std
::
unique
_ptr
<
std
::
deque
<
T
*>>
garbages_
;
mutable
std
::
mutex
mutex_
;
mutable
std
::
mutex
mutex_
;
const
size_t
max_memory_size_
;
const
size_t
max_memory_size_
;
size_t
cur_memory_size_
=
0
;
size_t
cur_memory_size_
=
0
;
...
@@ -96,6 +105,19 @@ class CPUGarbageCollector : public GarbageCollector<T> {
...
@@ -96,6 +105,19 @@ class CPUGarbageCollector : public GarbageCollector<T> {
};
};
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
template
<
typename
T
>
class
UnsafeFastGPUGarbageCollector
:
public
GarbageCollector
<
T
>
{
public:
UnsafeFastGPUGarbageCollector
(
const
platform
::
CUDAPlace
&
place
,
size_t
max_memory_size
)
:
GarbageCollector
<
T
>
(
place
,
max_memory_size
)
{}
protected:
void
ClearCallback
(
const
std
::
function
<
void
()
>
&
callback
)
override
{
callback
();
}
};
template
<
typename
T
>
template
<
typename
T
>
class
DefaultStreamGarbageCollector
:
public
GarbageCollector
<
T
>
{
class
DefaultStreamGarbageCollector
:
public
GarbageCollector
<
T
>
{
public:
public:
...
@@ -109,7 +131,7 @@ class DefaultStreamGarbageCollector : public GarbageCollector<T> {
...
@@ -109,7 +131,7 @@ class DefaultStreamGarbageCollector : public GarbageCollector<T> {
}
}
void
Wait
()
const
override
{
void
Wait
()
const
override
{
static_cast
<
const
platform
::
CUDADeviceContext
*>
(
this
->
dev_ctx_
)
static_cast
<
platform
::
CUDADeviceContext
*>
(
this
->
dev_ctx_
)
->
WaitStreamCallback
();
->
WaitStreamCallback
();
}
}
...
@@ -126,31 +148,23 @@ class StreamGarbageCollector : public GarbageCollector<T> {
...
@@ -126,31 +148,23 @@ class StreamGarbageCollector : public GarbageCollector<T> {
StreamGarbageCollector
(
const
platform
::
CUDAPlace
&
place
,
StreamGarbageCollector
(
const
platform
::
CUDAPlace
&
place
,
size_t
max_memory_size
)
size_t
max_memory_size
)
:
GarbageCollector
<
T
>
(
place
,
max_memory_size
)
{
:
GarbageCollector
<
T
>
(
place
,
max_memory_size
)
{
platform
::
SetDeviceI
d
(
place
.
device
);
platform
::
CUDADeviceGuard
guar
d
(
place
.
device
);
PADDLE_ENFORCE
(
cudaStreamCreate
(
&
stream_
));
PADDLE_ENFORCE
(
cudaStreamCreate
(
&
stream_
));
callback_manager_
.
reset
(
new
platform
::
StreamCallbackManager
(
stream_
));
callback_manager_
.
reset
(
new
platform
::
StreamCallbackManager
(
stream_
));
}
}
~
StreamGarbageCollector
()
{
~
StreamGarbageCollector
()
{
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
this
->
dev_ctx_
->
GetPlace
());
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
this
->
dev_ctx_
->
GetPlace
());
platform
::
SetDeviceI
d
(
place
.
device
);
platform
::
CUDADeviceGuard
guar
d
(
place
.
device
);
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream_
));
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream_
));
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
}
}
void
Wait
()
const
override
{
void
Wait
()
const
override
{
callback_manager_
->
Wait
();
}
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream_
));
std
::
lock_guard
<
std
::
mutex
>
guard
(
this
->
mutex_
);
callback_manager_
->
Wait
();
}
cudaStream_t
stream
()
const
{
return
stream_
;
}
cudaStream_t
stream
()
const
{
return
stream_
;
}
protected:
protected:
// ClearCallback and Wait()/Reset() cannot be call in multiple threads
// But it is not important, because they would not be called in multiple
// threads
// either in Executor or ParallelExecutor
void
ClearCallback
(
const
std
::
function
<
void
()
>
&
callback
)
override
{
void
ClearCallback
(
const
std
::
function
<
void
()
>
&
callback
)
override
{
callback_manager_
->
AddCallback
(
callback
);
callback_manager_
->
AddCallback
(
callback
);
}
}
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
c47c451a
...
@@ -873,6 +873,8 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
...
@@ -873,6 +873,8 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
t
=
&
(
var
->
Get
<
SelectedRows
>
().
value
());
t
=
&
(
var
->
Get
<
SelectedRows
>
().
value
());
}
}
if
(
t
!=
nullptr
)
{
if
(
t
!=
nullptr
)
{
PADDLE_ENFORCE
(
t
->
IsInitialized
(),
"Input %s is not initialized: %s"
,
ipt_name
,
DebugString
());
int
tmp
=
static_cast
<
int
>
(
ToDataType
(
t
->
type
()));
int
tmp
=
static_cast
<
int
>
(
ToDataType
(
t
->
type
()));
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
tmp
==
data_type
||
data_type
==
-
1
,
tmp
==
data_type
||
data_type
==
-
1
,
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
c47c451a
...
@@ -158,8 +158,13 @@ ParallelExecutor::ParallelExecutor(
...
@@ -158,8 +158,13 @@ ParallelExecutor::ParallelExecutor(
auto
&
place
=
member_
->
places_
[
i
];
auto
&
place
=
member_
->
places_
[
i
];
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
is_gpu_place
(
place
))
{
if
(
platform
::
is_gpu_place
(
place
))
{
member_
->
gcs_
.
emplace_back
(
new
StreamGarbageCollector
<
Tensor
>
(
if
(
IsFastEagerDeletionModeEnabled
())
{
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
max_memory_size
));
member_
->
gcs_
.
emplace_back
(
new
UnsafeFastGPUGarbageCollector
<
Tensor
>
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
max_memory_size
));
}
else
{
member_
->
gcs_
.
emplace_back
(
new
StreamGarbageCollector
<
Tensor
>
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
max_memory_size
));
}
VLOG
(
10
)
<<
"Created "
<<
i
<<
"-th GarbageCollector at "
<<
place
;
VLOG
(
10
)
<<
"Created "
<<
i
<<
"-th GarbageCollector at "
<<
place
;
}
else
if
(
platform
::
is_cpu_place
(
place
))
{
}
else
if
(
platform
::
is_cpu_place
(
place
))
{
#endif
#endif
...
@@ -181,8 +186,8 @@ ParallelExecutor::ParallelExecutor(
...
@@ -181,8 +186,8 @@ ParallelExecutor::ParallelExecutor(
&
(
member_
->
rt_ref_cnts_
));
&
(
member_
->
rt_ref_cnts_
));
ref_cnt_pass
->
SetNotOwned
(
details
::
kLastLiveOpsOfVars
,
ref_cnt_pass
->
SetNotOwned
(
details
::
kLastLiveOpsOfVars
,
&
last_live_ops_of_vars
);
&
last_live_ops_of_vars
);
VLOG
(
10
)
<<
"ReferenceCountPass Applied"
;
graph
=
ref_cnt_pass
->
Apply
(
std
::
move
(
graph
));
graph
=
ref_cnt_pass
->
Apply
(
std
::
move
(
graph
));
VLOG
(
10
)
<<
"ReferenceCountPass Applied"
;
auto
eager_deletion_pass
=
auto
eager_deletion_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"eager_deletion_pass"
);
ir
::
PassRegistry
::
Instance
().
Get
(
"eager_deletion_pass"
);
...
@@ -194,6 +199,8 @@ ParallelExecutor::ParallelExecutor(
...
@@ -194,6 +199,8 @@ ParallelExecutor::ParallelExecutor(
&
last_live_ops_of_vars
);
&
last_live_ops_of_vars
);
graph
=
eager_deletion_pass
->
Apply
(
std
::
move
(
graph
));
graph
=
eager_deletion_pass
->
Apply
(
std
::
move
(
graph
));
VLOG
(
10
)
<<
"EagerDeletionPass Applied"
;
VLOG
(
10
)
<<
"EagerDeletionPass Applied"
;
graph
->
SetNotOwned
(
details
::
kGarbageCollector
,
&
(
member_
->
gcs_
));
}
}
// Step 3. Create vars in each scope. Passes may also create new vars.
// Step 3. Create vars in each scope. Passes may also create new vars.
...
...
paddle/fluid/framework/scope.cc
浏览文件 @
c47c451a
...
@@ -38,6 +38,10 @@ DEFINE_double(
...
@@ -38,6 +38,10 @@ DEFINE_double(
"Memory size threshold (GB) when the garbage collector clear tensors."
"Memory size threshold (GB) when the garbage collector clear tensors."
"Disabled when this value is less than 0"
);
"Disabled when this value is less than 0"
);
DEFINE_bool
(
fast_eager_deletion_mode
,
true
,
"Fast eager deletion mode. If enabled, memory would release "
"immediately without waiting GPU kernel ends."
);
// When in inference scenario, the scopes will not be written by two threads in
// When in inference scenario, the scopes will not be written by two threads in
// a mean time, but a scope may be read by multiple threads concurrently, and
// a mean time, but a scope may be read by multiple threads concurrently, and
// the mutex will cause serious performance issue.
// the mutex will cause serious performance issue.
...
@@ -58,6 +62,8 @@ int64_t GetEagerDeletionThreshold() {
...
@@ -58,6 +62,8 @@ int64_t GetEagerDeletionThreshold() {
(
static_cast
<
int64_t
>
(
1
)
<<
30
));
(
static_cast
<
int64_t
>
(
1
)
<<
30
));
}
}
bool
IsFastEagerDeletionModeEnabled
()
{
return
FLAGS_fast_eager_deletion_mode
;
}
Scope
::~
Scope
()
{
DropKids
();
}
Scope
::~
Scope
()
{
DropKids
();
}
Scope
&
Scope
::
NewScope
()
const
{
Scope
&
Scope
::
NewScope
()
const
{
...
...
paddle/fluid/framework/scope.h
浏览文件 @
c47c451a
...
@@ -27,6 +27,7 @@ namespace paddle {
...
@@ -27,6 +27,7 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
int64_t
GetEagerDeletionThreshold
();
int64_t
GetEagerDeletionThreshold
();
bool
IsFastEagerDeletionModeEnabled
();
class
Scope
;
class
Scope
;
...
...
paddle/fluid/framework/tensor.h
浏览文件 @
c47c451a
...
@@ -153,7 +153,7 @@ class Tensor {
...
@@ -153,7 +153,7 @@ class Tensor {
void
set_layout
(
const
DataLayout
layout
)
{
layout_
=
layout
;
}
void
set_layout
(
const
DataLayout
layout
)
{
layout_
=
layout
;
}
void
clear
()
{
holder_
=
nullptr
;
}
void
clear
()
{
holder_
.
reset
()
;
}
const
std
::
shared_ptr
<
memory
::
Allocation
>&
Holder
()
const
{
return
holder_
;
}
const
std
::
shared_ptr
<
memory
::
Allocation
>&
Holder
()
const
{
return
holder_
;
}
size_t
offset
()
const
{
return
offset_
;
}
size_t
offset
()
const
{
return
offset_
;
}
...
...
paddle/fluid/operators/controlflow/while_op.cc
浏览文件 @
c47c451a
...
@@ -59,7 +59,21 @@ class WhileOp : public framework::OperatorBase {
...
@@ -59,7 +59,21 @@ class WhileOp : public framework::OperatorBase {
"Condition of while op must in CPU memory."
);
"Condition of while op must in CPU memory."
);
bool
is_test
=
Attr
<
bool
>
(
"is_test"
);
bool
is_test
=
Attr
<
bool
>
(
"is_test"
);
auto
ctx
=
executor
.
Prepare
(
*
program
,
block
->
ID
());
auto
&
skip_eager_deletion_vars
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"skip_eager_deletion_vars"
);
if
(
framework
::
GetEagerDeletionThreshold
()
>=
0
&&
VLOG_IS_ON
(
10
))
{
std
::
string
debug_string
=
"Skip "
+
std
::
to_string
(
skip_eager_deletion_vars
.
size
())
+
" vars in eager deletion mode: "
;
for
(
auto
&
var
:
skip_eager_deletion_vars
)
{
debug_string
.
append
(
var
);
debug_string
.
push_back
(
' '
);
}
VLOG
(
10
)
<<
debug_string
;
}
auto
ctx
=
executor
.
Prepare
(
*
program
,
block
->
ID
(),
skip_eager_deletion_vars
);
while
(
cond
.
data
<
bool
>
()[
0
])
{
while
(
cond
.
data
<
bool
>
()[
0
])
{
auto
&
current_scope
=
scope
.
NewScope
();
auto
&
current_scope
=
scope
.
NewScope
();
step_scopes
->
push_back
(
&
current_scope
);
step_scopes
->
push_back
(
&
current_scope
);
...
@@ -96,6 +110,10 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -96,6 +110,10 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Set to true for inference only, false "
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true."
)
"for training. Some layers may run faster when this is true."
)
.
SetDefault
(
false
);
.
SetDefault
(
false
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"skip_eager_deletion_vars"
,
"Vars that would skip eager deletion."
"Users should not set this manually."
)
.
SetDefault
(
std
::
vector
<
std
::
string
>
());
AddComment
(
R"DOC(
AddComment
(
R"DOC(
)DOC"
);
)DOC"
);
}
}
...
@@ -341,6 +359,30 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
...
@@ -341,6 +359,30 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
// while operator could be renamed.
// while operator could be renamed.
while_grad
->
SetAttr
(
"original_output_grad"
,
output_grads_list
);
while_grad
->
SetAttr
(
"original_output_grad"
,
output_grads_list
);
/* The following codes are used in eager deletion mode */
if
(
framework
::
GetEagerDeletionThreshold
()
>=
0
)
{
std
::
unordered_set
<
std
::
string
>
skip_vars
;
for
(
auto
*
op_desc
:
grad_block
->
AllOps
())
{
for
(
auto
&
in_arg_name
:
op_desc
->
InputArgumentNames
())
{
// If input var of ops inside grad_block is not from grad_block,
// it cannot be deleted when forward while_op runs
if
(
in_arg_name
!=
framework
::
kEmptyVarName
&&
!
grad_block
->
HasVar
(
in_arg_name
))
{
skip_vars
.
insert
(
in_arg_name
);
}
}
}
if
(
!
skip_vars
.
empty
())
{
// FIXME(zjl): ugly const_cast here, maybe we should find a better way
// to modify forward while_op
auto
&
fwd_while_op
=
const_cast
<
framework
::
OpDesc
&>
(
ForwardOp
());
fwd_while_op
.
SetAttr
(
"skip_eager_deletion_vars"
,
std
::
vector
<
std
::
string
>
(
skip_vars
.
begin
(),
skip_vars
.
end
()));
}
}
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
while_grad
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
while_grad
);
}
}
};
};
...
...
paddle/fluid/operators/reader/ctr_reader.h
浏览文件 @
c47c451a
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include <sys/time.h>
#include <sys/time.h>
#include <algorithm>
#include <chrono> // NOLINT
#include <chrono> // NOLINT
#include <cstdlib>
#include <cstdlib>
#include <fstream>
#include <fstream>
...
@@ -55,8 +56,7 @@ class CTRReader : public framework::FileReader {
...
@@ -55,8 +56,7 @@ class CTRReader : public framework::FileReader {
PADDLE_ENFORCE_GT
(
thread_num
,
0
,
"thread num should be larger then 0!"
);
PADDLE_ENFORCE_GT
(
thread_num
,
0
,
"thread num should be larger then 0!"
);
PADDLE_ENFORCE
(
queue
!=
nullptr
,
"LoDTensorBlockingQueue must not be null"
);
PADDLE_ENFORCE
(
queue
!=
nullptr
,
"LoDTensorBlockingQueue must not be null"
);
PADDLE_ENFORCE_GT
(
file_list
.
size
(),
0
,
"file list should not be empty"
);
PADDLE_ENFORCE_GT
(
file_list
.
size
(),
0
,
"file list should not be empty"
);
thread_num_
=
thread_num_
=
std
::
min
<
size_t
>
(
file_list_
.
size
(),
thread_num
);
file_list_
.
size
()
>
thread_num
?
thread_num
:
file_list_
.
size
();
queue_
=
queue
;
queue_
=
queue
;
SplitFiles
();
SplitFiles
();
for
(
size_t
i
=
0
;
i
<
thread_num_
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
thread_num_
;
++
i
)
{
...
@@ -95,10 +95,10 @@ class CTRReader : public framework::FileReader {
...
@@ -95,10 +95,10 @@ class CTRReader : public framework::FileReader {
queue_
->
ReOpen
();
queue_
->
ReOpen
();
VLOG
(
3
)
<<
"reopen success"
;
VLOG
(
3
)
<<
"reopen success"
;
VLOG
(
3
)
<<
"thread_num "
<<
thread_num_
;
VLOG
(
3
)
<<
"thread_num "
<<
thread_num_
;
for
(
in
t
thread_id
=
0
;
thread_id
<
thread_num_
;
thread_id
++
)
{
for
(
size_
t
thread_id
=
0
;
thread_id
<
thread_num_
;
thread_id
++
)
{
read_threads_
.
emplace_back
(
new
std
::
thread
(
read_threads_
.
emplace_back
(
new
std
::
thread
(
std
::
bind
(
std
::
bind
(
&
ReadThread
,
file_groups_
[
thread_id
],
slots_
,
batch_size_
,
&
ReadThread
,
file_groups_
[
thread_id
],
slots_
,
batch_size_
,
thread_id
,
&
read_thread_status_
,
queue_
)));
static_cast
<
int
>
(
thread_id
)
,
&
read_thread_status_
,
queue_
)));
}
}
monitor_thread_
.
reset
(
new
std
::
thread
(
monitor_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
MonitorThread
,
&
read_thread_status_
,
queue_
)));
std
::
bind
(
&
MonitorThread
,
&
read_thread_status_
,
queue_
)));
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
c47c451a
...
@@ -223,14 +223,10 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -223,14 +223,10 @@ class CUDADeviceContext : public DeviceContext {
template
<
typename
Callback
>
template
<
typename
Callback
>
void
AddStreamCallback
(
Callback
&&
callback
)
const
{
void
AddStreamCallback
(
Callback
&&
callback
)
const
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
callback_mtx_
);
callback_manager_
->
AddCallback
(
callback
);
callback_manager_
->
AddCallback
(
callback
);
}
}
void
WaitStreamCallback
()
const
{
void
WaitStreamCallback
()
const
{
callback_manager_
->
Wait
();
}
std
::
lock_guard
<
std
::
mutex
>
guard
(
callback_mtx_
);
callback_manager_
->
Wait
();
}
#if CUDA_VERSION >= 9000
#if CUDA_VERSION >= 9000
/*! \brief CublasCall may need to change cublas's config,
/*! \brief CublasCall may need to change cublas's config,
...
@@ -261,9 +257,7 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -261,9 +257,7 @@ class CUDADeviceContext : public DeviceContext {
mutable
std
::
mutex
mtx_
;
mutable
std
::
mutex
mtx_
;
// This lock is only used by callback
// StreamCallbackManager is thread-safe
// If we use mtx_ for StreamCallbackManager, deadlock may occur sometimes
mutable
std
::
mutex
callback_mtx_
;
std
::
unique_ptr
<
StreamCallbackManager
>
callback_manager_
;
std
::
unique_ptr
<
StreamCallbackManager
>
callback_manager_
;
mutable
std
::
mutex
cublas_mtx_
;
mutable
std
::
mutex
cublas_mtx_
;
...
...
paddle/fluid/platform/stream_callback_manager.cc
浏览文件 @
c47c451a
...
@@ -18,52 +18,47 @@
...
@@ -18,52 +18,47 @@
namespace
paddle
{
namespace
paddle
{
namespace
platform
{
namespace
platform
{
struct
StreamCallbackContext
{
#if CUDA_VERSION >= 10000
inline
StreamCallbackContext
(
const
StreamCallbackManager
*
manager
,
static
void
CUDART_CB
StreamCallbackFunc
(
void
*
user_data
);
std
::
function
<
void
()
>
callback
)
#else
:
manager_
(
manager
),
callback_
(
std
::
move
(
callback
))
{}
static
void
CUDART_CB
StreamCallbackFunc
(
cudaStream_t
stream
,
cudaError_t
status
,
void
*
user_data
)
const
StreamCallbackManager
*
manager_
;
// do not own
#endif
std
::
function
<
void
()
>
callback_
;
{
};
std
::
unique_ptr
<
std
::
function
<
void
()
>>
func
(
reinterpret_cast
<
std
::
function
<
void
()
>
*>
(
user_data
));
(
*
func
)();
}
StreamCallbackManager
::
StreamCallbackManager
(
const
cudaStream_t
stream
)
StreamCallbackManager
::
StreamCallbackManager
(
const
cudaStream_t
stream
)
:
stream_
(
stream
),
thread_pool_
(
new
::
ThreadPool
(
1
)
)
{}
:
stream_
(
stream
),
thread_pool_
(
1
)
{}
void
StreamCallbackManager
::
AddCallback
(
std
::
function
<
void
()
>
callback
)
const
{
void
StreamCallbackManager
::
AddCallback
(
std
::
function
<
void
()
>
callback
)
const
{
auto
*
stream_callback_context
=
auto
*
callback_func
=
new
std
::
function
<
void
()
>
(
std
::
move
(
callback
));
new
StreamCallbackContext
(
this
,
std
::
move
(
callback
));
auto
*
func
=
new
std
::
function
<
void
()
>
([
this
,
callback_func
]
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mtx_
);
last_future_
=
thread_pool_
.
enqueue
([
callback_func
]
{
std
::
unique_ptr
<
std
::
function
<
void
()
>>
releaser
(
callback_func
);
(
*
callback_func
)();
});
});
#if CUDA_VERSION >= 10000
#if CUDA_VERSION >= 10000
PADDLE_ENFORCE
(
cudaLaunchHostFunc
(
stream_
,
PADDLE_ENFORCE
(
cudaLaunchHostFunc
(
stream_
,
StreamCallbackFunc
,
func
));
StreamCallbackManager
::
StreamCallbackFunc
,
stream_callback_context
));
#else
#else
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
cudaStreamAddCallback
(
stream_
,
StreamCallbackFunc
,
func
,
0
));
cudaStreamAddCallback
(
stream_
,
StreamCallbackManager
::
StreamCallbackFunc
,
stream_callback_context
,
0
));
#endif
#endif
}
}
void
StreamCallbackManager
::
Wait
()
const
{
StreamCallbackManager
::~
StreamCallbackManager
()
{
Wait
();
}
thread_pool_
.
reset
(
new
::
ThreadPool
(
1
));
}
#if CUDA_VERSION >= 10000
void
StreamCallbackManager
::
Wait
()
const
{
void
CUDART_CB
StreamCallbackManager
::
StreamCallbackFunc
(
void
*
user_data
)
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream_
));
#else
{
void
CUDART_CB
StreamCallbackManager
::
StreamCallbackFunc
(
cudaStream_t
stream
,
std
::
lock_guard
<
std
::
mutex
>
lock
(
mtx_
);
cudaError_t
status
,
if
(
last_future_
.
valid
())
{
void
*
user_data
)
last_future_
.
wait
();
#endif
}
{
}
auto
*
callback_context_ptr
=
reinterpret_cast
<
StreamCallbackContext
*>
(
user_data
);
callback_context_ptr
->
manager_
->
thread_pool_
->
enqueue
(
[
callback_context_ptr
]()
{
std
::
unique_ptr
<
StreamCallbackContext
>
callback_context
(
callback_context_ptr
);
callback_context
->
callback_
();
});
}
}
}
// namespace platform
}
// namespace platform
...
...
paddle/fluid/platform/stream_callback_manager.h
浏览文件 @
c47c451a
...
@@ -18,30 +18,32 @@
...
@@ -18,30 +18,32 @@
#include <cuda.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#include <functional>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <memory>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
paddle
{
namespace
platform
{
namespace
platform
{
// NOTE(zjl): clean StreamCallback to make compilation faster
// NOTE(zjl): clean StreamCallbackManager to make compilation faster
// Make StreamCallbackManager thread-safe
class
StreamCallbackManager
{
class
StreamCallbackManager
{
public:
public:
explicit
StreamCallbackManager
(
const
cudaStream_t
stream
);
explicit
StreamCallbackManager
(
const
cudaStream_t
stream
);
~
StreamCallbackManager
();
void
AddCallback
(
std
::
function
<
void
()
>
callback
)
const
;
void
AddCallback
(
std
::
function
<
void
()
>
callback
)
const
;
void
Wait
()
const
;
void
Wait
()
const
;
private:
private:
const
cudaStream_t
stream_
;
const
cudaStream_t
stream_
;
mutable
std
::
unique_ptr
<::
ThreadPool
>
thread_pool_
;
mutable
::
ThreadPool
thread_pool_
;
mutable
std
::
mutex
mtx_
;
#if CUDA_VERSION >= 10000
mutable
std
::
future
<
void
>
last_future_
;
static
void
CUDART_CB
StreamCallbackFunc
(
void
*
user_data
);
#else
static
void
CUDART_CB
StreamCallbackFunc
(
cudaStream_t
stream
,
cudaError_t
status
,
void
*
user_data
);
#endif
};
};
}
// namespace platform
}
// namespace platform
...
...
paddle/fluid/pybind/tensor_py.h
浏览文件 @
c47c451a
...
@@ -162,7 +162,7 @@ void PyCPUTensorSetFromArray(
...
@@ -162,7 +162,7 @@ void PyCPUTensorSetFromArray(
paddle
::
platform
::
CPUPlace
place
)
{
paddle
::
platform
::
CPUPlace
place
)
{
std
::
vector
<
int64_t
>
dims
;
std
::
vector
<
int64_t
>
dims
;
dims
.
reserve
(
array
.
ndim
());
dims
.
reserve
(
array
.
ndim
());
for
(
size_t
i
=
0
;
i
<
array
.
ndim
();
++
i
)
{
for
(
decltype
(
array
.
ndim
())
i
=
0
;
i
<
array
.
ndim
();
++
i
)
{
dims
.
push_back
(
static_cast
<
int
>
(
array
.
shape
()[
i
]));
dims
.
push_back
(
static_cast
<
int
>
(
array
.
shape
()[
i
]));
}
}
...
@@ -182,7 +182,7 @@ inline void PyCPUTensorSetFromArray(
...
@@ -182,7 +182,7 @@ inline void PyCPUTensorSetFromArray(
paddle
::
platform
::
CPUPlace
place
)
{
paddle
::
platform
::
CPUPlace
place
)
{
std
::
vector
<
int64_t
>
dims
;
std
::
vector
<
int64_t
>
dims
;
dims
.
reserve
(
array
.
ndim
());
dims
.
reserve
(
array
.
ndim
());
for
(
size_t
i
=
0
;
i
<
array
.
ndim
();
++
i
)
{
for
(
decltype
(
array
.
ndim
())
i
=
0
;
i
<
array
.
ndim
();
++
i
)
{
dims
.
push_back
(
static_cast
<
int
>
(
array
.
shape
()[
i
]));
dims
.
push_back
(
static_cast
<
int
>
(
array
.
shape
()[
i
]));
}
}
...
@@ -200,7 +200,7 @@ void PyCUDATensorSetFromArray(
...
@@ -200,7 +200,7 @@ void PyCUDATensorSetFromArray(
paddle
::
platform
::
CUDAPlace
place
)
{
paddle
::
platform
::
CUDAPlace
place
)
{
std
::
vector
<
int64_t
>
dims
;
std
::
vector
<
int64_t
>
dims
;
dims
.
reserve
(
array
.
ndim
());
dims
.
reserve
(
array
.
ndim
());
for
(
size_t
i
=
0
;
i
<
array
.
ndim
();
++
i
)
{
for
(
decltype
(
array
.
ndim
())
i
=
0
;
i
<
array
.
ndim
();
++
i
)
{
dims
.
push_back
(
static_cast
<
int
>
(
array
.
shape
()[
i
]));
dims
.
push_back
(
static_cast
<
int
>
(
array
.
shape
()[
i
]));
}
}
...
@@ -221,7 +221,7 @@ inline void PyCUDATensorSetFromArray(
...
@@ -221,7 +221,7 @@ inline void PyCUDATensorSetFromArray(
paddle
::
platform
::
CUDAPlace
place
)
{
paddle
::
platform
::
CUDAPlace
place
)
{
std
::
vector
<
int64_t
>
dims
;
std
::
vector
<
int64_t
>
dims
;
dims
.
reserve
(
array
.
ndim
());
dims
.
reserve
(
array
.
ndim
());
for
(
size_t
i
=
0
;
i
<
array
.
ndim
();
++
i
)
{
for
(
decltype
(
array
.
ndim
())
i
=
0
;
i
<
array
.
ndim
();
++
i
)
{
dims
.
push_back
(
static_cast
<
int
>
(
array
.
shape
()[
i
]));
dims
.
push_back
(
static_cast
<
int
>
(
array
.
shape
()[
i
]));
}
}
...
@@ -240,7 +240,7 @@ void PyCUDAPinnedTensorSetFromArray(
...
@@ -240,7 +240,7 @@ void PyCUDAPinnedTensorSetFromArray(
const
paddle
::
platform
::
CUDAPinnedPlace
&
place
)
{
const
paddle
::
platform
::
CUDAPinnedPlace
&
place
)
{
std
::
vector
<
int64_t
>
dims
;
std
::
vector
<
int64_t
>
dims
;
dims
.
reserve
(
array
.
ndim
());
dims
.
reserve
(
array
.
ndim
());
for
(
size_t
i
=
0
;
i
<
array
.
ndim
();
++
i
)
{
for
(
decltype
(
array
.
ndim
())
i
=
0
;
i
<
array
.
ndim
();
++
i
)
{
dims
.
push_back
(
static_cast
<
int
>
(
array
.
shape
()[
i
]));
dims
.
push_back
(
static_cast
<
int
>
(
array
.
shape
()[
i
]));
}
}
...
@@ -260,7 +260,7 @@ inline void PyCUDAPinnedTensorSetFromArray(
...
@@ -260,7 +260,7 @@ inline void PyCUDAPinnedTensorSetFromArray(
const
paddle
::
platform
::
CUDAPinnedPlace
&
place
)
{
const
paddle
::
platform
::
CUDAPinnedPlace
&
place
)
{
std
::
vector
<
int64_t
>
dims
;
std
::
vector
<
int64_t
>
dims
;
dims
.
reserve
(
array
.
ndim
());
dims
.
reserve
(
array
.
ndim
());
for
(
size_t
i
=
0
;
i
<
array
.
ndim
();
++
i
)
{
for
(
decltype
(
array
.
ndim
())
i
=
0
;
i
<
array
.
ndim
();
++
i
)
{
dims
.
push_back
(
static_cast
<
int
>
(
array
.
shape
()[
i
]));
dims
.
push_back
(
static_cast
<
int
>
(
array
.
shape
()[
i
]));
}
}
...
...
python/paddle/fluid/__init__.py
浏览文件 @
c47c451a
...
@@ -116,8 +116,9 @@ def __bootstrap__():
...
@@ -116,8 +116,9 @@ def __bootstrap__():
'check_nan_inf'
,
'benchmark'
,
'eager_delete_scope'
,
'use_mkldnn'
,
'check_nan_inf'
,
'benchmark'
,
'eager_delete_scope'
,
'use_mkldnn'
,
'use_ngraph'
,
'initial_cpu_memory_in_mb'
,
'init_allocated_mem'
,
'use_ngraph'
,
'initial_cpu_memory_in_mb'
,
'init_allocated_mem'
,
'free_idle_memory'
,
'paddle_num_threads'
,
"dist_threadpool_size"
,
'free_idle_memory'
,
'paddle_num_threads'
,
"dist_threadpool_size"
,
'eager_delete_tensor_gb'
,
'allocator_strategy'
,
'eager_delete_tensor_gb'
,
'fast_eager_deletion_mode'
,
'reader_queue_speed_test_mode'
,
'print_sub_graph_dir'
'allocator_strategy'
,
'reader_queue_speed_test_mode'
,
'print_sub_graph_dir'
]
]
if
'Darwin'
not
in
sysstr
:
if
'Darwin'
not
in
sysstr
:
read_env_flags
.
append
(
'use_pinned_memory'
)
read_env_flags
.
append
(
'use_pinned_memory'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录