Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
096673f6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
096673f6
编写于
11月 29, 2018
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor eager deletion
test=develop
上级
400cf19f
变更
22
隐藏空白更改
内联
并排
Showing
22 changed file
with
631 addition
and
441 deletion
+631
-441
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+4
-8
paddle/fluid/framework/details/computation_op_handle.cc
paddle/fluid/framework/details/computation_op_handle.cc
+4
-2
paddle/fluid/framework/details/computation_op_handle.h
paddle/fluid/framework/details/computation_op_handle.h
+5
-1
paddle/fluid/framework/details/eager_deletion_op_handle.cc
paddle/fluid/framework/details/eager_deletion_op_handle.cc
+117
-0
paddle/fluid/framework/details/eager_deletion_op_handle.h
paddle/fluid/framework/details/eager_deletion_op_handle.h
+64
-0
paddle/fluid/framework/details/eager_deletion_pass.cc
paddle/fluid/framework/details/eager_deletion_pass.cc
+96
-0
paddle/fluid/framework/details/eager_deletion_pass.h
paddle/fluid/framework/details/eager_deletion_pass.h
+32
-0
paddle/fluid/framework/details/multi_devices_graph_pass.cc
paddle/fluid/framework/details/multi_devices_graph_pass.cc
+3
-3
paddle/fluid/framework/details/reference_count_op_handle.h
paddle/fluid/framework/details/reference_count_op_handle.h
+0
-138
paddle/fluid/framework/details/reference_count_pass.cc
paddle/fluid/framework/details/reference_count_pass.cc
+62
-151
paddle/fluid/framework/details/reference_count_pass.h
paddle/fluid/framework/details/reference_count_pass.h
+0
-5
paddle/fluid/framework/details/reference_count_pass_helper.h
paddle/fluid/framework/details/reference_count_pass_helper.h
+49
-0
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
...id/framework/details/scope_buffered_ssa_graph_executor.cc
+10
-20
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
...uid/framework/details/scope_buffered_ssa_graph_executor.h
+4
-0
paddle/fluid/framework/garbage_collector.h
paddle/fluid/framework/garbage_collector.h
+7
-5
paddle/fluid/framework/ir/graph.h
paddle/fluid/framework/ir/graph.h
+9
-2
paddle/fluid/framework/ir/pass.h
paddle/fluid/framework/ir/pass.h
+9
-2
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+70
-36
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+1
-23
paddle/fluid/platform/CMakeLists.txt
paddle/fluid/platform/CMakeLists.txt
+8
-1
paddle/fluid/platform/stream_callback_manager.cc
paddle/fluid/platform/stream_callback_manager.cc
+70
-0
paddle/fluid/platform/stream_callback_manager.h
paddle/fluid/platform/stream_callback_manager.h
+7
-44
未找到文件。
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
096673f6
...
@@ -33,10 +33,9 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s
...
@@ -33,10 +33,9 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s
cc_library
(
modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper
)
cc_library
(
modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper
)
if
(
WITH_GPU
)
cc_library
(
eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows op_handle_base
)
cc_library
(
reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
cc_library
(
eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass
)
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass
)
cc_library
(
reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass
)
endif
()
cc_library
(
sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass
)
cc_library
(
sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass
)
cc_library
(
all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass
)
cc_library
(
all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass
)
...
@@ -44,10 +43,7 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
...
@@ -44,10 +43,7 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
cc_library
(
multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
cc_library
(
multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle
)
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle
)
set
(
SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass
)
set
(
SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass
)
if
(
WITH_GPU
)
list
(
APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass
)
endif
()
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS
${
SSA_GRAPH_EXECUTOR_DEPS
}
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS
${
SSA_GRAPH_EXECUTOR_DEPS
}
)
...
...
paddle/fluid/framework/details/computation_op_handle.cc
浏览文件 @
096673f6
...
@@ -20,11 +20,13 @@ namespace paddle {
...
@@ -20,11 +20,13 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
ComputationOpHandle
::
ComputationOpHandle
(
ir
::
Node
*
node
,
Scope
*
scope
,
ComputationOpHandle
::
ComputationOpHandle
(
ir
::
Node
*
node
,
Scope
*
scope
,
platform
::
Place
place
)
platform
::
Place
place
,
size_t
scope_idx
)
:
OpHandleBase
(
node
),
:
OpHandleBase
(
node
),
op_
(
framework
::
OpRegistry
::
CreateOp
(
*
node
->
Op
())),
op_
(
framework
::
OpRegistry
::
CreateOp
(
*
node
->
Op
())),
scope_
(
scope
),
scope_
(
scope
),
place_
(
place
)
{}
place_
(
place
),
scope_idx_
(
scope_idx
)
{}
void
ComputationOpHandle
::
RunImpl
()
{
void
ComputationOpHandle
::
RunImpl
()
{
WaitInputVarGenerated
(
place_
);
WaitInputVarGenerated
(
place_
);
...
...
paddle/fluid/framework/details/computation_op_handle.h
浏览文件 @
096673f6
...
@@ -28,7 +28,8 @@ namespace framework {
...
@@ -28,7 +28,8 @@ namespace framework {
namespace
details
{
namespace
details
{
struct
ComputationOpHandle
:
public
OpHandleBase
{
struct
ComputationOpHandle
:
public
OpHandleBase
{
public:
public:
ComputationOpHandle
(
ir
::
Node
*
node
,
Scope
*
scope
,
platform
::
Place
place
);
ComputationOpHandle
(
ir
::
Node
*
node
,
Scope
*
scope
,
platform
::
Place
place
,
size_t
scope_idx
);
std
::
string
Name
()
const
override
;
std
::
string
Name
()
const
override
;
...
@@ -38,6 +39,8 @@ struct ComputationOpHandle : public OpHandleBase {
...
@@ -38,6 +39,8 @@ struct ComputationOpHandle : public OpHandleBase {
void
SetLockAndRecordEventFree
(
bool
b
)
{
is_lock_and_record_event_free_
=
b
;
}
void
SetLockAndRecordEventFree
(
bool
b
)
{
is_lock_and_record_event_free_
=
b
;
}
size_t
GetScopeIdx
()
const
{
return
scope_idx_
;
}
protected:
protected:
void
RunImpl
()
override
;
void
RunImpl
()
override
;
...
@@ -47,6 +50,7 @@ struct ComputationOpHandle : public OpHandleBase {
...
@@ -47,6 +50,7 @@ struct ComputationOpHandle : public OpHandleBase {
std
::
unique_ptr
<
OperatorBase
>
op_
;
std
::
unique_ptr
<
OperatorBase
>
op_
;
Scope
*
scope_
;
Scope
*
scope_
;
platform
::
Place
place_
;
platform
::
Place
place_
;
size_t
scope_idx_
;
bool
is_lock_and_record_event_free_
{
false
};
bool
is_lock_and_record_event_free_
{
false
};
};
};
}
// namespace details
}
// namespace details
...
...
paddle/fluid/framework/details/eager_deletion_op_handle.cc
0 → 100644
浏览文件 @
096673f6
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
EagerDeletionOpHandle
::
EagerDeletionOpHandle
(
ir
::
Node
*
node
,
const
Scope
*
scope
,
const
platform
::
Place
&
place
,
const
std
::
vector
<
std
::
string
>
&
var_names
,
GarbageCollector
<
Tensor
>
*
gc
,
AtomicReferenceCountMap
*
ref_cnts
)
:
OpHandleBase
(
node
),
scope_
(
scope
),
gc_
(
gc
),
ref_cnts_
(
ref_cnts
)
{
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
is_gpu_place
(
place
))
{
dev_ctx_
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
if
(
dynamic_cast
<
StreamGarbageCollector
<
Tensor
>
*>
(
gc_
))
{
platform
::
SetDeviceId
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place
).
device
);
PADDLE_ENFORCE
(
cudaEventCreateWithFlags
(
&
event_
,
cudaEventDisableTiming
));
}
}
#endif
for
(
auto
&
name
:
var_names
)
AddVar
(
name
);
}
EagerDeletionOpHandle
::~
EagerDeletionOpHandle
()
{
#ifdef PADDLE_WITH_CUDA
if
(
event_
)
{
auto
gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_ctx_
->
GetPlace
());
platform
::
SetDeviceId
(
gpu_place
.
device
);
PADDLE_ENFORCE
(
cudaEventDestroy
(
event_
));
}
#endif
}
std
::
string
EagerDeletionOpHandle
::
Name
()
const
{
return
"eager_deletion"
;
}
void
EagerDeletionOpHandle
::
AddVar
(
const
std
::
string
&
name
)
{
var_names_
.
insert
(
name
);
}
void
EagerDeletionOpHandle
::
RunImpl
()
{
auto
*
exec_scope
=
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
std
::
vector
<
Tensor
*>
tensors
;
for
(
auto
&
name
:
var_names_
)
{
auto
it
=
ref_cnts_
->
find
(
name
);
if
(
it
==
ref_cnts_
->
end
())
{
continue
;
}
auto
*
var
=
exec_scope
->
FindVar
(
name
);
if
(
var
==
nullptr
)
{
continue
;
}
if
(
var
->
IsType
<
LoDTensor
>
())
{
if
(
it
->
second
.
fetch_sub
(
1
)
==
1
)
{
tensors
.
emplace_back
(
var
->
GetMutable
<
LoDTensor
>
());
}
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
if
(
it
->
second
.
fetch_sub
(
1
)
==
1
)
{
tensors
.
emplace_back
(
var
->
GetMutable
<
SelectedRows
>
()
->
mutable_value
());
}
}
else
if
(
var
->
IsType
<
LoDTensorArray
>
())
{
if
(
it
->
second
.
fetch_sub
(
1
)
==
1
)
{
auto
*
tensor_arr
=
var
->
GetMutable
<
LoDTensorArray
>
();
for
(
auto
&
t
:
*
tensor_arr
)
{
tensors
.
emplace_back
(
&
t
);
}
}
}
}
if
(
!
tensors
.
empty
())
{
ClearTensors
(
tensors
);
}
}
void
EagerDeletionOpHandle
::
ClearTensors
(
const
std
::
vector
<
Tensor
*>
&
tensors
)
{
#ifdef PADDLE_WITH_CUDA
if
(
event_
)
{
auto
compute_stream
=
dev_ctx_
->
stream
();
auto
callback_stream
=
static_cast
<
StreamGarbageCollector
<
Tensor
>
*>
(
gc_
)
->
stream
();
auto
callback_func
=
[
=
]()
{
PADDLE_ENFORCE
(
cudaEventRecord
(
event_
,
compute_stream
));
PADDLE_ENFORCE
(
cudaStreamWaitEvent
(
callback_stream
,
event_
,
0
));
};
gc_
->
Add
(
tensors
,
callback_func
);
}
else
{
#endif
gc_
->
Add
(
tensors
);
#ifdef PADDLE_WITH_CUDA
}
#endif
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/eager_deletion_op_handle.h
0 → 100644
浏览文件 @
096673f6
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
namespace
paddle
{
namespace
framework
{
class
Scope
;
namespace
details
{
class
EagerDeletionPass
;
class
EagerDeletionOpHandle
:
public
OpHandleBase
{
public:
EagerDeletionOpHandle
(
ir
::
Node
*
node
,
const
Scope
*
scope
,
const
platform
::
Place
&
place
,
const
std
::
vector
<
std
::
string
>
&
var_names
,
GarbageCollector
<
Tensor
>
*
gc
,
AtomicReferenceCountMap
*
ref_cnts
);
~
EagerDeletionOpHandle
();
std
::
string
Name
()
const
override
;
protected:
void
RunImpl
()
override
;
private:
void
ClearTensors
(
const
std
::
vector
<
Tensor
*>
&
tensors
);
void
AddVar
(
const
std
::
string
&
name
);
const
Scope
*
scope_
;
std
::
unordered_set
<
std
::
string
>
var_names_
;
GarbageCollector
<
Tensor
>
*
gc_
;
// not own
AtomicReferenceCountMap
*
ref_cnts_
;
// not own
#ifdef PADDLE_WITH_CUDA
platform
::
CUDADeviceContext
*
dev_ctx_
{
nullptr
};
cudaEvent_t
event_
{
nullptr
};
#endif
friend
class
EagerDeletionPass
;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/eager_deletion_pass.cc
0 → 100644
浏览文件 @
096673f6
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <queue>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
static
void
AddDependencyBetween
(
OpHandleBase
*
in
,
OpHandleBase
*
out
,
ir
::
Graph
*
graph
)
{
auto
it
=
std
::
find_if
(
in
->
Outputs
().
begin
(),
in
->
Outputs
().
end
(),
[](
VarHandleBase
*
var
)
{
return
dynamic_cast
<
DummyVarHandle
*>
(
var
)
!=
nullptr
;
});
if
(
it
!=
in
->
Outputs
().
end
())
{
out
->
AddInput
(
*
it
);
}
else
{
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
in
->
AddOutput
(
dep_var
);
out
->
AddInput
(
dep_var
);
}
// Add leaf node to eager_deletion_node
if
(
out
->
Outputs
().
empty
())
{
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dummy_leaf
);
out
->
AddOutput
(
dummy_leaf
);
}
}
std
::
unique_ptr
<
ir
::
Graph
>
EagerDeletionPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
);
auto
&
ref_cnts
=
Get
<
std
::
vector
<
AtomicReferenceCountMap
>>
(
kCurReferenceCount
);
auto
&
last_live_ops
=
Get
<
std
::
vector
<
LastLiveOpsOfVars
>>
(
kLastLiveOpsOfVars
);
auto
&
gcs
=
Get
<
GarbageCollectorList
>
(
kGarbageCollector
);
ref_cnts
=
std
::
vector
<
AtomicReferenceCountMap
>
(
vars
.
size
());
std
::
unordered_map
<
ComputationOpHandle
*
,
EagerDeletionOpHandle
*>
op_map
;
for
(
auto
&
var_ops_map
:
last_live_ops
)
{
for
(
auto
&
var_ops_pair
:
var_ops_map
)
{
const
std
::
string
&
var_name
=
var_ops_pair
.
first
;
for
(
ComputationOpHandle
*
op
:
var_ops_pair
.
second
)
{
auto
it
=
op_map
.
find
(
op
);
if
(
it
!=
op_map
.
end
())
{
it
->
second
->
AddVar
(
var_name
);
}
else
{
auto
*
eager_deletion_node
=
graph
->
CreateEmptyNode
(
"eager_deletion"
,
ir
::
Node
::
Type
::
kOperation
);
auto
*
eager_deletion_op
=
new
EagerDeletionOpHandle
(
eager_deletion_node
,
op
->
GetScope
(),
op
->
GetPlace
(),
{
var_name
},
gcs
[
op
->
GetScopeIdx
()].
get
(),
&
(
ref_cnts
[
op
->
GetScopeIdx
()]));
AddDependencyBetween
(
op
,
eager_deletion_op
,
graph
.
get
());
op_map
[
op
]
=
eager_deletion_op
;
}
}
}
}
VLOG
(
10
)
<<
"Create "
<<
op_map
.
size
()
<<
" EagerDeletionOpHandle(s)"
;
return
graph
;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
eager_deletion_pass
,
paddle
::
framework
::
details
::
EagerDeletionPass
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kCurReferenceCount
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLastLiveOpsOfVars
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kGarbageCollector
);
paddle/fluid/framework/details/eager_deletion_pass.h
0 → 100644
浏览文件 @
096673f6
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
class
EagerDeletionPass
:
public
ir
::
Pass
{
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/multi_devices_graph_pass.cc
浏览文件 @
096673f6
...
@@ -562,7 +562,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
...
@@ -562,7 +562,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
int
dev_id
)
const
{
int
dev_id
)
const
{
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
local_scopes_
[
dev_id
],
places_
[
dev_id
]));
local_scopes_
[
dev_id
],
places_
[
dev_id
]
,
dev_id
));
CreateOpHandleIOs
(
result
,
node
,
dev_id
);
CreateOpHandleIOs
(
result
,
node
,
dev_id
);
}
}
...
@@ -685,8 +685,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
...
@@ -685,8 +685,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
auto
p
=
places_
[
scope_idx
];
auto
p
=
places_
[
scope_idx
];
auto
s
=
local_scopes_
[
scope_idx
];
auto
s
=
local_scopes_
[
scope_idx
];
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ComputationOpHandle
(
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
s
,
p
));
result
->
CreateOpNode
(
node
->
Op
()),
s
,
p
,
scope_idx
));
CreateOpHandleIOs
(
result
,
node
,
scope_idx
);
CreateOpHandleIOs
(
result
,
node
,
scope_idx
);
}
}
}
}
...
...
paddle/fluid/framework/details/reference_count_op_handle.h
已删除
100644 → 0
浏览文件 @
400cf19f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
using
ReferenceCountMap
=
std
::
unordered_map
<
std
::
string
,
int
>
;
using
AtomicReferenceCountMap
=
std
::
unordered_map
<
std
::
string
,
std
::
atomic
<
int
>>
;
using
DeviceReferenceCountMap
=
std
::
unordered_map
<
int
,
std
::
unique_ptr
<
ReferenceCountMap
>>
;
using
AtomicDeviceReferenceCountMap
=
std
::
unordered_map
<
int
,
std
::
unique_ptr
<
AtomicReferenceCountMap
>>
;
using
DeviceGarbageCollectorMap
=
std
::
unordered_map
<
int
,
std
::
unique_ptr
<
GarbageCollector
<
framework
::
Tensor
>>>
;
class
ReferenceCountOpHandle
:
public
OpHandleBase
{
public:
ReferenceCountOpHandle
(
ir
::
Node
*
node
,
const
Scope
*
scope
,
const
platform
::
CUDAPlace
&
place
,
const
std
::
vector
<
std
::
string
>
&
var_names
,
GarbageCollector
<
Tensor
>
*
gc
,
AtomicReferenceCountMap
*
ref_cnts
)
:
OpHandleBase
(
node
),
scope_
(
scope
),
gc_
(
gc
),
ref_cnts_
(
ref_cnts
)
{
dev_ctx_
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
if
(
IsStreamGarabageCollector
())
{
platform
::
SetDeviceId
(
place
.
device
);
PADDLE_ENFORCE
(
cudaEventCreateWithFlags
(
&
event_
,
cudaEventDisableTiming
));
}
for
(
auto
&
name
:
var_names
)
AddVar
(
name
);
}
~
ReferenceCountOpHandle
()
{
if
(
IsStreamGarabageCollector
())
{
auto
gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_ctx_
->
GetPlace
());
platform
::
SetDeviceId
(
gpu_place
.
device
);
PADDLE_ENFORCE
(
cudaEventDestroy
(
event_
));
}
}
std
::
string
Name
()
const
override
{
return
"reference_count"
;
}
void
AddVar
(
const
std
::
string
&
name
)
{
auto
it
=
var_names_
.
find
(
name
);
if
(
it
!=
var_names_
.
end
())
++
(
it
->
second
);
else
var_names_
[
name
]
=
1
;
}
protected:
void
RunImpl
()
override
{
auto
*
exec_scope
=
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
std
::
vector
<
Tensor
*>
tensors
;
for
(
auto
&
pair
:
var_names_
)
{
auto
&
name
=
pair
.
first
;
auto
it
=
ref_cnts_
->
find
(
name
);
if
(
it
==
ref_cnts_
->
end
())
continue
;
auto
*
var
=
exec_scope
->
FindVar
(
name
);
if
(
var
==
nullptr
)
continue
;
if
(
var
->
IsType
<
LoDTensor
>
())
{
if
(
it
->
second
.
fetch_sub
(
pair
.
second
)
<=
pair
.
second
)
{
tensors
.
emplace_back
(
var
->
GetMutable
<
LoDTensor
>
());
}
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
if
(
it
->
second
.
fetch_sub
(
pair
.
second
)
<=
pair
.
second
)
{
tensors
.
emplace_back
(
var
->
GetMutable
<
SelectedRows
>
()
->
mutable_value
());
}
}
}
if
(
!
tensors
.
empty
())
{
ClearTensors
(
tensors
);
}
}
private:
void
ClearTensors
(
const
std
::
vector
<
Tensor
*>
&
tensors
)
{
auto
*
gc
=
dynamic_cast
<
StreamGarbageCollector
<
Tensor
>
*>
(
gc_
);
if
(
gc
!=
nullptr
)
{
auto
compute_stream
=
dev_ctx_
->
stream
();
auto
callback_stream
=
gc
->
stream
();
auto
callback_func
=
[
=
]()
{
PADDLE_ENFORCE
(
cudaEventRecord
(
event_
,
compute_stream
));
PADDLE_ENFORCE
(
cudaStreamWaitEvent
(
callback_stream
,
event_
,
0
));
};
gc_
->
Add
(
tensors
,
callback_func
);
}
else
{
gc_
->
Add
(
tensors
);
}
}
bool
IsStreamGarabageCollector
()
const
{
return
dynamic_cast
<
const
StreamGarbageCollector
<
Tensor
>
*>
(
gc_
)
!=
nullptr
;
}
const
Scope
*
scope_
;
platform
::
CUDADeviceContext
*
dev_ctx_
;
std
::
unordered_map
<
std
::
string
,
int
>
var_names_
;
GarbageCollector
<
Tensor
>
*
gc_
;
// not own
AtomicReferenceCountMap
*
ref_cnts_
;
// not own
cudaEvent_t
event_
;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/reference_count_pass.cc
浏览文件 @
096673f6
...
@@ -17,184 +17,96 @@
...
@@ -17,184 +17,96 @@
#include <vector>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
static
ComputationOpHandle
*
FindNextComputationOpHandle
(
VarHandle
*
var_in
)
{
static
ComputationOpHandle
*
FindNextComputationOpHandleOrReturnItself
(
std
::
queue
<
VarHandleBase
*>
queue
;
OpHandleBase
*
op
,
size_t
scope_idx
)
{
queue
.
push
(
var_in
);
std
::
queue
<
OpHandleBase
*>
q
;
std
::
unordered_set
<
OpHandleBase
*>
visited
;
q
.
push
(
op
);
do
{
do
{
auto
*
var
=
queue
.
front
();
auto
*
op
=
q
.
front
();
queue
.
pop
();
q
.
pop
();
for
(
auto
*
op
:
var
->
PendingOps
())
{
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
);
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
);
if
(
compute_op
!=
nullptr
&&
compute_op
->
GetScopeIdx
()
==
scope_idx
)
{
if
(
compute_op
!=
nullptr
&&
compute_op
->
GetPlace
()
==
var_in
->
place_
)
{
return
compute_op
;
return
compute_op
;
}
}
for
(
auto
*
out_var
:
op
->
Outputs
())
{
for
(
auto
*
out_var
:
op
->
Outputs
())
{
for
(
auto
*
pending_op
:
out_var
->
PendingOps
())
{
queue
.
push
(
out_var
);
if
(
visited
.
count
(
pending_op
))
continue
;
visited
.
insert
(
pending_op
);
}
}
}
}
}
while
(
!
q
ueue
.
empty
());
}
while
(
!
q
.
empty
());
return
nullptr
;
return
nullptr
;
}
}
static
void
AddDependencyBetween
(
OpHandleBase
*
in
,
OpHandleBase
*
out
,
ir
::
Graph
*
graph
)
{
auto
it
=
std
::
find_if
(
in
->
Outputs
().
begin
(),
in
->
Outputs
().
end
(),
[](
VarHandleBase
*
var
)
{
return
dynamic_cast
<
DummyVarHandle
*>
(
var
)
!=
nullptr
;
});
if
(
it
!=
in
->
Outputs
().
end
())
{
out
->
AddInput
(
*
it
);
}
else
{
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
in
->
AddOutput
(
dep_var
);
out
->
AddInput
(
dep_var
);
}
}
std
::
unique_ptr
<
ir
::
Graph
>
ReferenceCountPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
ReferenceCountPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
auto
&
ref_cnts
=
Get
<
DeviceReferenceCountMap
>
(
kGlobalReferenceCount
);
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
);
auto
&
cur_ref_cnts
=
Get
<
AtomicDeviceReferenceCountMap
>
(
kCurReferenceCount
);
auto
&
ref_cnts
=
Get
<
std
::
vector
<
ReferenceCountMap
>>
(
kGlobalReferenceCount
);
auto
&
gcs
=
Get
<
DeviceGarbageCollectorMap
>
(
kGarbageCollector
);
auto
&
last_live_ops_of_vars
=
Get
<
std
::
vector
<
LastLiveOpsOfVars
>>
(
kLastLiveOpsOfVars
);
// It is not easy to find the right reference counts of varaibles in graph
// Step 1: Find all variables in computation ops
last_live_ops_of_vars
=
std
::
vector
<
LastLiveOpsOfVars
>
(
vars
.
size
());
// Step 2: Find all variables in non-computation ops which refers to variables
ref_cnts
=
std
::
vector
<
ReferenceCountMap
>
(
vars
.
size
());
// in computation ops
std
::
unordered_set
<
std
::
string
>
names
;
for
(
size_t
i
=
0
;
i
<
vars
.
size
();
++
i
)
{
std
::
unordered_map
<
OpHandleBase
*
,
ReferenceCountOpHandle
*>
for
(
auto
&
name_var_pair
:
vars
[
i
])
{
compute_ref_cnt_map
;
if
(
name_var_pair
.
second
.
empty
())
continue
;
auto
*
last_ver_var
=
name_var_pair
.
second
.
back
();
auto
get_ref_cnts_from_compute_op
=
[
&
](
OpHandleBase
*
op
,
const
std
::
vector
<
VarHandleBase
*>
&
vars
)
{
VarDesc
*
var_desc
=
nullptr
;
std
::
vector
<
std
::
string
>
var_names_in_op
;
std
::
find_if
(
name_var_pair
.
second
.
rbegin
(),
name_var_pair
.
second
.
rend
(),
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
);
[
&
](
VarHandle
*
var_handle
)
->
bool
{
if
(
compute_op
==
nullptr
||
var_desc
=
var_handle
->
Node
()
->
Var
();
!
platform
::
is_gpu_place
(
compute_op
->
GetPlace
()))
return
var_desc
!=
nullptr
;
return
var_names_in_op
;
});
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
compute_op
->
GetPlace
());
for
(
VarHandleBase
*
var_handle_base
:
vars
)
{
if
(
var_desc
==
nullptr
||
var_desc
->
Persistable
())
{
auto
*
var_handle
=
dynamic_cast
<
VarHandle
*>
(
var_handle_base
);
if
(
var_handle
==
nullptr
||
!
var_handle
->
Node
()
->
IsVar
())
continue
;
if
(
!
platform
::
is_gpu_place
(
var_handle
->
place_
)
||
boost
::
get
<
platform
::
CUDAPlace
>
(
var_handle
->
place_
)
!=
place
)
continue
;
continue
;
VarDesc
*
var_desc
=
var_handle
->
Node
()
->
Var
();
auto
var_name
=
var_handle
->
Node
()
->
Name
();
// This is weird but there is really some variables without var_desc
// in computation_op
if
(
var_desc
==
nullptr
)
{
var_desc
=
compute_op
->
Node
()
->
Op
()
->
Block
()
->
FindVar
(
var_name
);
if
(
var_desc
==
nullptr
)
continue
;
}
}
if
(
var_desc
->
Persistable
())
continue
;
auto
var_type
=
var_desc
->
Proto
()
->
type
().
type
();
auto
var_type
=
var_desc
->
Proto
()
->
type
().
type
();
if
(
var_type
!=
proto
::
VarType
::
LOD_TENSOR
&&
if
(
var_type
!=
proto
::
VarType
::
LOD_TENSOR
&&
var_type
!=
proto
::
VarType
::
SELECTED_ROWS
)
{
var_type
!=
proto
::
VarType
::
SELECTED_ROWS
&&
var_type
!=
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
{
continue
;
continue
;
}
}
// compute op only runs in one device
std
::
unordered_set
<
ComputationOpHandle
*>
last_live_op
;
if
(
ref_cnts
[
place
.
device
]
->
count
(
var_name
))
auto
add_last_live_op
=
[
&
](
OpHandleBase
*
op
)
{
++
(
*
ref_cnts
[
place
.
device
])[
var_name
];
auto
*
compute_op
=
FindNextComputationOpHandleOrReturnItself
(
op
,
i
);
else
if
(
compute_op
)
{
(
*
ref_cnts
[
place
.
device
])[
var_name
]
=
1
;
last_live_op
.
insert
(
compute_op
);
}
names
.
insert
(
var_name
);
};
var_names_in_op
.
push_back
(
var_name
);
const
std
::
string
&
var_name
=
name_var_pair
.
first
;
}
auto
&
pending_ops
=
last_ver_var
->
PendingOps
();
return
var_names_in_op
;
if
(
pending_ops
.
empty
())
{
};
auto
*
generated_op
=
last_ver_var
->
GeneratedOp
();
if
(
generated_op
)
{
auto
update_ref_cnts_from_non_compute_op
=
[
&
](
ref_cnts
[
i
].
emplace
(
var_name
,
1
);
OpHandleBase
*
op
,
const
std
::
vector
<
VarHandleBase
*>
&
vars
)
{
add_last_live_op
(
generated_op
);
if
(
dynamic_cast
<
ComputationOpHandle
*>
(
op
)
!=
nullptr
)
return
;
}
for
(
VarHandleBase
*
var_handle_base
:
vars
)
{
}
else
{
auto
*
var_handle
=
dynamic_cast
<
VarHandle
*>
(
var_handle_base
);
ref_cnts
[
i
].
emplace
(
var_name
,
pending_ops
.
size
());
if
(
var_handle
==
nullptr
||
!
var_handle
->
Node
()
->
IsVar
())
continue
;
for
(
auto
*
pending_op
:
pending_ops
)
{
add_last_live_op
(
pending_op
);
auto
var_name
=
var_handle
->
Node
()
->
Name
();
auto
var_place
=
var_handle
->
place_
;
if
(
!
platform
::
is_gpu_place
(
var_place
))
continue
;
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
var_place
);
if
(
names
.
count
(
var_name
)
==
0
)
continue
;
if
(
ref_cnts
.
count
(
place
.
device
)
&&
ref_cnts
[
place
.
device
]
->
count
(
var_name
))
{
++
(
*
ref_cnts
[
place
.
device
])[
var_name
];
auto
*
next_compute_op
=
FindNextComputationOpHandle
(
var_handle
);
if
(
next_compute_op
!=
nullptr
)
{
if
(
compute_ref_cnt_map
.
count
(
next_compute_op
))
{
compute_ref_cnt_map
[
next_compute_op
]
->
AddVar
(
var_name
);
VLOG
(
5
)
<<
"Add reference count of "
<<
var_name
<<
" to Operator "
<<
next_compute_op
->
Name
();
}
else
{
// Create new reference_count_op_handle
ir
::
Node
*
ref_cnt_node
=
graph
->
CreateEmptyNode
(
"reference_count"
,
ir
::
Node
::
Type
::
kOperation
);
auto
*
ref_cnt_handle
=
new
ReferenceCountOpHandle
(
ref_cnt_node
,
next_compute_op
->
GetScope
(),
place
,
{
var_name
},
gcs
[
place
.
device
].
get
(),
cur_ref_cnts
[
place
.
device
].
get
());
AddDependencyBetween
(
next_compute_op
,
ref_cnt_handle
,
graph
.
get
());
compute_ref_cnt_map
[
next_compute_op
]
=
ref_cnt_handle
;
}
}
}
}
}
}
};
auto
all_ops
=
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
*
graph
);
last_live_ops_of_vars
[
i
].
emplace
(
var_name
,
std
::
move
(
last_live_op
));
for
(
auto
&
op
:
all_ops
)
{
auto
in_var_names
=
get_ref_cnts_from_compute_op
(
op
,
op
->
Inputs
());
auto
out_var_names
=
get_ref_cnts_from_compute_op
(
op
,
op
->
Outputs
());
if
(
in_var_names
.
empty
()
&&
out_var_names
.
empty
())
continue
;
in_var_names
.
insert
(
in_var_names
.
end
(),
out_var_names
.
begin
(),
out_var_names
.
end
());
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
);
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
compute_op
->
GetPlace
());
ir
::
Node
*
ref_cnt_node
=
graph
->
CreateEmptyNode
(
"reference_count"
,
ir
::
Node
::
Type
::
kOperation
);
auto
*
ref_cnt_handle
=
new
ReferenceCountOpHandle
(
ref_cnt_node
,
compute_op
->
GetScope
(),
place
,
in_var_names
,
gcs
[
place
.
device
].
get
(),
cur_ref_cnts
[
place
.
device
].
get
());
AddDependencyBetween
(
compute_op
,
ref_cnt_handle
,
graph
.
get
());
compute_ref_cnt_map
[
compute_op
]
=
ref_cnt_handle
;
}
for
(
auto
&
op
:
all_ops
)
{
update_ref_cnts_from_non_compute_op
(
op
,
op
->
Inputs
());
update_ref_cnts_from_non_compute_op
(
op
,
op
->
Outputs
());
}
std
::
vector
<
OpHandleBase
*>
new_all_ops
;
new_all_ops
.
reserve
(
compute_ref_cnt_map
.
size
()
+
all_ops
.
size
());
for
(
auto
&
op
:
all_ops
)
{
new_all_ops
.
emplace_back
(
std
::
move
(
op
));
auto
it
=
compute_ref_cnt_map
.
find
(
new_all_ops
.
back
());
if
(
it
!=
compute_ref_cnt_map
.
end
())
{
// Add LeafNode to ReferenceCountOpHandle
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dummy_leaf
);
it
->
second
->
AddOutput
(
dummy_leaf
);
new_all_ops
.
emplace_back
(
std
::
move
(
it
->
second
));
}
}
}
}
all_ops
.
swap
(
new_all_ops
);
return
graph
;
return
graph
;
}
}
...
@@ -205,5 +117,4 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
...
@@ -205,5 +117,4 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
REGISTER_PASS
(
reference_count_pass
,
REGISTER_PASS
(
reference_count_pass
,
paddle
::
framework
::
details
::
ReferenceCountPass
)
paddle
::
framework
::
details
::
ReferenceCountPass
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kGlobalReferenceCount
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kGlobalReferenceCount
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kCurReferenceCount
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLastLiveOpsOfVars
);
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kGarbageCollector
);
paddle/fluid/framework/details/reference_count_pass.h
浏览文件 @
096673f6
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
#pragma once
#pragma once
#include "paddle/fluid/framework/details/reference_count_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass.h"
...
@@ -22,10 +21,6 @@ namespace paddle {
...
@@ -22,10 +21,6 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
constexpr
char
kGlobalReferenceCount
[]
=
"reference_count"
;
constexpr
char
kCurReferenceCount
[]
=
"current_reference_count"
;
constexpr
char
kGarbageCollector
[]
=
"garbage_collector"
;
class
ReferenceCountPass
:
public
ir
::
Pass
{
class
ReferenceCountPass
:
public
ir
::
Pass
{
protected:
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
...
...
paddle/fluid/framework/details/reference_count_pass_helper.h
0 → 100644
浏览文件 @
096673f6
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/tensor.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
class
ComputationOpHandle
;
using
ReferenceCountMap
=
std
::
unordered_map
<
std
::
string
,
size_t
>
;
using
AtomicReferenceCountMap
=
std
::
unordered_map
<
std
::
string
,
std
::
atomic
<
size_t
>>
;
using
GarbageCollectorList
=
std
::
vector
<
std
::
unique_ptr
<
GarbageCollector
<
Tensor
>>>
;
const
char
kGlobalReferenceCount
[]
=
"reference_count"
;
const
char
kCurReferenceCount
[]
=
"current_reference_count"
;
const
char
kGarbageCollector
[]
=
"garbage_collector"
;
using
LastLiveOpsOfVars
=
std
::
unordered_map
<
std
::
string
,
std
::
unordered_set
<
ComputationOpHandle
*>>
;
const
char
kLastLiveOpsOfVars
[]
=
"last_live_ops_of_var"
;
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
浏览文件 @
096673f6
...
@@ -18,9 +18,6 @@
...
@@ -18,9 +18,6 @@
#include <vector>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/reference_count_op_handle.h"
#endif
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -33,7 +30,11 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
...
@@ -33,7 +30,11 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
underlying_executor_
(
std
::
move
(
underlying_executor
)),
underlying_executor_
(
std
::
move
(
underlying_executor
)),
local_scopes_
(
std
::
move
(
local_scopes
)),
local_scopes_
(
std
::
move
(
local_scopes
)),
var_infos_
(
std
::
move
(
var_infos
)),
var_infos_
(
std
::
move
(
var_infos
)),
places_
(
std
::
move
(
places
))
{}
places_
(
std
::
move
(
places
))
{
if
(
Graph
().
Has
(
details
::
kGarbageCollector
))
{
gc_
=
&
(
Graph
().
Get
<
GarbageCollectorList
>
(
details
::
kGarbageCollector
));
}
}
FeedFetchList
ScopeBufferedSSAGraphExecutor
::
Run
(
FeedFetchList
ScopeBufferedSSAGraphExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
{
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
{
...
@@ -69,27 +70,16 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
...
@@ -69,27 +70,16 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
platform
::
RecordEvent
e
(
"ScopeBufferedSSAGraphExecutorAfterRun"
,
nullptr
);
platform
::
RecordEvent
e
(
"ScopeBufferedSSAGraphExecutorAfterRun"
,
nullptr
);
drop_scope_counter_
+=
1
;
drop_scope_counter_
+=
1
;
#ifdef PADDLE_WITH_CUDA
const
std
::
string
gc_name
=
"garbage_collector"
;
DeviceGarbageCollectorMap
*
gc
=
Graph
().
Has
(
gc_name
)
?
&
(
Graph
().
Get
<
DeviceGarbageCollectorMap
>
(
gc_name
))
:
nullptr
;
#endif
if
(
!
fetch_tensors
.
empty
()
||
if
(
!
fetch_tensors
.
empty
()
||
drop_scope_counter_
==
strategy_
.
num_iteration_per_drop_scope_
)
{
drop_scope_counter_
==
strategy_
.
num_iteration_per_drop_scope_
)
{
drop_scope_counter_
=
0
;
drop_scope_counter_
=
0
;
// Wait All computational streams
// Wait All computational streams
for
(
auto
p
:
places_
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
)
->
Wait
();
platform
::
DeviceContextPool
::
Instance
().
Get
(
places_
[
i
])
->
Wait
();
#ifdef PADDLE_WITH_CUDA
if
(
gc_
)
{
if
(
gc
!=
nullptr
&&
platform
::
is_gpu_place
(
p
))
{
(
*
gc_
)[
i
]
->
Wait
();
auto
gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
p
);
(
*
gc_
)[
i
]
->
Reset
();
auto
&
gc_at_place
=
gc
->
at
(
gpu_place
.
device
);
gc_at_place
->
Wait
();
gc_at_place
->
Reset
();
}
}
#endif
}
}
for
(
auto
&
scope
:
local_scopes_
)
{
for
(
auto
&
scope
:
local_scopes_
)
{
auto
&
local_scope
=
auto
&
local_scope
=
...
...
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
浏览文件 @
096673f6
...
@@ -21,9 +21,11 @@
...
@@ -21,9 +21,11 @@
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
...
@@ -55,6 +57,8 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -55,6 +57,8 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
VariableInfo
>
var_infos_
;
std
::
vector
<
VariableInfo
>
var_infos_
;
std
::
vector
<
platform
::
Place
>
places_
;
std
::
vector
<
platform
::
Place
>
places_
;
GarbageCollectorList
*
gc_
{
nullptr
};
};
};
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/garbage_collector.h
浏览文件 @
096673f6
...
@@ -65,7 +65,7 @@ class GarbageCollector {
...
@@ -65,7 +65,7 @@ class GarbageCollector {
if
(
clear_deque
!=
nullptr
)
{
if
(
clear_deque
!=
nullptr
)
{
callback
();
callback
();
ClearCallback
([
=
]()
{
ClearCallback
([
clear_deque
]()
{
for
(
auto
*
obj
:
*
clear_deque
)
obj
->
clear
();
for
(
auto
*
obj
:
*
clear_deque
)
obj
->
clear
();
});
});
}
}
...
@@ -109,7 +109,6 @@ class DefaultStreamGarbageCollector : public GarbageCollector<T> {
...
@@ -109,7 +109,6 @@ class DefaultStreamGarbageCollector : public GarbageCollector<T> {
}
}
void
Wait
()
const
override
{
void
Wait
()
const
override
{
this
->
dev_ctx_
->
Wait
();
static_cast
<
const
platform
::
CUDADeviceContext
*>
(
this
->
dev_ctx_
)
static_cast
<
const
platform
::
CUDADeviceContext
*>
(
this
->
dev_ctx_
)
->
WaitStreamCallback
();
->
WaitStreamCallback
();
}
}
...
@@ -127,14 +126,14 @@ class StreamGarbageCollector : public GarbageCollector<T> {
...
@@ -127,14 +126,14 @@ class StreamGarbageCollector : public GarbageCollector<T> {
StreamGarbageCollector
(
const
platform
::
CUDAPlace
&
place
,
StreamGarbageCollector
(
const
platform
::
CUDAPlace
&
place
,
size_t
max_memory_size
)
size_t
max_memory_size
)
:
GarbageCollector
<
T
>
(
place
,
max_memory_size
)
{
:
GarbageCollector
<
T
>
(
place
,
max_memory_size
)
{
PADDLE_ENFORCE
(
cudaSetDevice
(
place
.
device
)
);
platform
::
SetDeviceId
(
place
.
device
);
PADDLE_ENFORCE
(
cudaStreamCreate
(
&
stream_
));
PADDLE_ENFORCE
(
cudaStreamCreate
(
&
stream_
));
callback_manager_
.
reset
(
new
platform
::
StreamCallbackManager
(
stream_
));
callback_manager_
.
reset
(
new
platform
::
StreamCallbackManager
(
stream_
));
}
}
~
StreamGarbageCollector
()
{
~
StreamGarbageCollector
()
{
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
this
->
dev_ctx_
->
GetPlace
());
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
this
->
dev_ctx_
->
GetPlace
());
PADDLE_ENFORCE
(
cudaSetDevice
(
place
.
device
)
);
platform
::
SetDeviceId
(
place
.
device
);
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream_
));
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream_
));
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
}
}
...
@@ -148,8 +147,11 @@ class StreamGarbageCollector : public GarbageCollector<T> {
...
@@ -148,8 +147,11 @@ class StreamGarbageCollector : public GarbageCollector<T> {
cudaStream_t
stream
()
const
{
return
stream_
;
}
cudaStream_t
stream
()
const
{
return
stream_
;
}
protected:
protected:
// ClearCallback and Wait()/Reset() cannot be call in multiple threads
// But it is not important, because they would not be called in multiple
// threads
// either in Executor or ParallelExecutor
void
ClearCallback
(
const
std
::
function
<
void
()
>
&
callback
)
override
{
void
ClearCallback
(
const
std
::
function
<
void
()
>
&
callback
)
override
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
this
->
mutex_
);
callback_manager_
->
AddCallback
(
callback
);
callback_manager_
->
AddCallback
(
callback
);
}
}
...
...
paddle/fluid/framework/ir/graph.h
浏览文件 @
096673f6
...
@@ -73,14 +73,21 @@ class Graph {
...
@@ -73,14 +73,21 @@ class Graph {
}
}
bool
Has
(
const
std
::
string
&
attr_name
)
const
{
bool
Has
(
const
std
::
string
&
attr_name
)
const
{
return
attrs_
.
find
(
attr_name
)
!=
attrs_
.
end
()
;
return
attrs_
.
count
(
attr_name
)
>
0
;
}
}
template
<
typename
AttrType
>
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
PADDLE_ENFORCE
(
Has
(
attr_name
),
"%s attr not registered for graph."
,
PADDLE_ENFORCE
(
Has
(
attr_name
),
"%s attr not registered for graph."
,
attr_name
);
attr_name
);
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
try
{
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
}
catch
(
boost
::
bad_any_cast
&
)
{
PADDLE_THROW
(
"Invalid attribute type of %s error, expected: %s, actual: %s"
,
attr_name
,
typeid
(
AttrType
*
).
name
(),
attrs_
.
at
(
attr_name
).
type
().
name
());
}
}
}
template
<
typename
AttrType
>
template
<
typename
AttrType
>
...
...
paddle/fluid/framework/ir/pass.h
浏览文件 @
096673f6
...
@@ -51,11 +51,18 @@ class Pass {
...
@@ -51,11 +51,18 @@ class Pass {
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
PADDLE_ENFORCE
(
attrs_
.
find
(
attr_name
)
!=
attrs_
.
end
(),
PADDLE_ENFORCE
(
attrs_
.
find
(
attr_name
)
!=
attrs_
.
end
(),
"%s attr not registered for pass."
,
attr_name
);
"%s attr not registered for pass."
,
attr_name
);
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
try
{
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
}
catch
(
boost
::
bad_any_cast
&
)
{
PADDLE_THROW
(
"Invalid attribute type of %s error, expected: %s, actual: %s"
,
attr_name
,
typeid
(
AttrType
*
).
name
(),
attrs_
.
at
(
attr_name
).
type
().
name
());
}
}
}
bool
Has
(
const
std
::
string
&
attr_name
)
const
{
bool
Has
(
const
std
::
string
&
attr_name
)
const
{
return
attrs_
.
find
(
attr_name
)
!=
attrs_
.
end
()
;
return
attrs_
.
count
(
attr_name
)
>
0
;
}
}
void
Erase
(
const
std
::
string
&
attr_name
)
{
void
Erase
(
const
std
::
string
&
attr_name
)
{
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
096673f6
...
@@ -26,6 +26,7 @@ limitations under the License. */
...
@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
...
@@ -49,6 +50,15 @@ class ParallelExecutorPrivate {
...
@@ -49,6 +50,15 @@ class ParallelExecutorPrivate {
}
}
}
}
}
}
void
ResetRuntimeReferenceCount
()
{
for
(
size_t
i
=
0
;
i
<
rt_ref_cnts_
.
size
();
++
i
)
{
for
(
auto
&
pair
:
rt_ref_cnts_
[
i
])
{
rt_cur_ref_cnts_
[
i
][
pair
.
first
]
=
pair
.
second
;
}
}
}
std
::
vector
<
platform
::
Place
>
places_
;
std
::
vector
<
platform
::
Place
>
places_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
Scope
*>
local_scopes_
;
Scope
*
global_scope_
;
// not owned
Scope
*
global_scope_
;
// not owned
...
@@ -60,6 +70,13 @@ class ParallelExecutorPrivate {
...
@@ -60,6 +70,13 @@ class ParallelExecutorPrivate {
bool
own_local_scope_
;
bool
own_local_scope_
;
bool
use_cuda_
;
bool
use_cuda_
;
bool
use_all_reduce_
;
bool
use_all_reduce_
;
// rt_ref_cnts_ is only initialized when ParallelExecutor constructs, and then
// keeps unchanged
// Before each iteration, rt_cur_ref_cnts_ is reset to ref_cnts_
std
::
vector
<
details
::
ReferenceCountMap
>
rt_ref_cnts_
;
std
::
vector
<
details
::
AtomicReferenceCountMap
>
rt_cur_ref_cnts_
;
details
::
GarbageCollectorList
gcs_
;
};
};
std
::
vector
<
Scope
*>
&
ParallelExecutor
::
GetLocalScopes
()
{
std
::
vector
<
Scope
*>
&
ParallelExecutor
::
GetLocalScopes
()
{
...
@@ -128,35 +145,56 @@ ParallelExecutor::ParallelExecutor(
...
@@ -128,35 +145,56 @@ ParallelExecutor::ParallelExecutor(
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
build_strategy
.
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
build_strategy
.
Apply
(
main_program
,
member_
->
places_
,
loss_var_name
,
params
,
main_program
,
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
member_
->
use_cuda_
,
member_
->
nccl_ctxs_
.
get
());
member_
->
local_scopes_
,
member_
->
use_cuda_
,
member_
->
nccl_ctxs_
.
get
());
#else
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
build_strategy
.
Apply
(
main_program
,
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
member_
->
use_cuda_
);
#endif
auto
max_memory_size
=
GetEagerDeletionThreshold
();
auto
max_memory_size
=
GetEagerDeletionThreshold
();
if
(
max_memory_size
>=
0
)
{
if
(
max_memory_size
>=
0
)
{
for
(
auto
&
place
:
member_
->
places_
)
{
size_t
place_num
=
member_
->
places_
.
size
();
if
(
!
platform
::
is_gpu_place
(
place
))
continue
;
for
(
size_t
i
=
0
;
i
<
place_num
;
++
i
)
{
auto
gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
place
);
auto
&
place
=
member_
->
places_
[
i
];
if
(
gcs_
[
gpu_place
.
device
]
==
nullptr
)
{
#ifdef PADDLE_WITH_CUDA
ref_cnts_
[
gpu_place
.
device
].
reset
(
new
details
::
ReferenceCountMap
());
if
(
platform
::
is_gpu_place
(
place
))
{
cur_ref_cnts_
[
gpu_place
.
device
].
reset
(
member_
->
gcs_
.
emplace_back
(
new
StreamGarbageCollector
<
Tensor
>
(
new
details
::
AtomicReferenceCountMap
());
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
max_memory_size
));
gcs_
[
gpu_place
.
device
].
reset
(
VLOG
(
10
)
<<
"Created "
<<
i
<<
"-th GarbageCollector at "
<<
place
;
new
StreamGarbageCollector
<
Tensor
>
(
gpu_place
,
max_memory_size
));
}
else
if
(
platform
::
is_cpu_place
(
place
))
{
#endif
member_
->
gcs_
.
emplace_back
(
new
CPUGarbageCollector
<
Tensor
>
(
boost
::
get
<
platform
::
CPUPlace
>
(
place
),
max_memory_size
));
VLOG
(
10
)
<<
"Created "
<<
i
<<
"-th GarbageCollector at "
<<
place
;
#ifdef PADDLE_WITH_CUDA
}
}
}
#endif
if
(
!
gcs_
.
empty
())
{
auto
ref_cnt_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"reference_count_pass"
);
ref_cnt_pass
->
SetNotOwned
(
details
::
kGlobalReferenceCount
,
&
ref_cnts_
);
ref_cnt_pass
->
SetNotOwned
(
details
::
kCurReferenceCount
,
&
cur_ref_cnts_
);
ref_cnt_pass
->
SetNotOwned
(
details
::
kGarbageCollector
,
&
gcs_
);
graph
=
ref_cnt_pass
->
Apply
(
std
::
move
(
graph
));
graph
->
SetNotOwned
(
"garbage_collector"
,
&
gcs_
);
}
}
}
}
#else
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
if
(
!
member_
->
gcs_
.
empty
())
{
build_strategy
.
Apply
(
main_program
,
member_
->
places_
,
loss_var_name
,
std
::
vector
<
details
::
LastLiveOpsOfVars
>
last_live_ops_of_vars
;
params
,
member_
->
local_scopes_
,
member_
->
use_cuda_
);
#endif
auto
ref_cnt_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"reference_count_pass"
);
ref_cnt_pass
->
SetNotOwned
(
details
::
kGlobalReferenceCount
,
&
(
member_
->
rt_ref_cnts_
));
ref_cnt_pass
->
SetNotOwned
(
details
::
kLastLiveOpsOfVars
,
&
last_live_ops_of_vars
);
VLOG
(
10
)
<<
"ReferenceCountPass Applied"
;
graph
=
ref_cnt_pass
->
Apply
(
std
::
move
(
graph
));
auto
eager_deletion_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"eager_deletion_pass"
);
eager_deletion_pass
->
SetNotOwned
(
details
::
kCurReferenceCount
,
&
(
member_
->
rt_cur_ref_cnts_
));
eager_deletion_pass
->
SetNotOwned
(
details
::
kGarbageCollector
,
&
(
member_
->
gcs_
));
eager_deletion_pass
->
SetNotOwned
(
details
::
kLastLiveOpsOfVars
,
&
last_live_ops_of_vars
);
graph
=
eager_deletion_pass
->
Apply
(
std
::
move
(
graph
));
VLOG
(
10
)
<<
"EagerDeletionPass Applied"
;
}
// Step 3. Create vars in each scope. Passes may also create new vars.
// Step 3. Create vars in each scope. Passes may also create new vars.
// skip control vars and empty vars
// skip control vars and empty vars
...
@@ -271,18 +309,16 @@ void ParallelExecutor::BCastParamsToDevices(
...
@@ -271,18 +309,16 @@ void ParallelExecutor::BCastParamsToDevices(
void
ParallelExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
void
ParallelExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
const
std
::
string
&
fetched_var_name
)
{
const
std
::
string
&
fetched_var_name
)
{
platform
::
RecordBlock
b
(
0
);
platform
::
RecordBlock
b
(
0
);
#ifdef PADDLE_WITH_CUDA
if
(
!
member_
->
gcs_
.
empty
())
{
if
(
!
gcs_
.
empty
())
{
member_
->
ResetRuntimeReferenceCount
();
ResetReferenceCount
();
size_t
n
=
member_
->
rt_ref_cnts_
.
size
();
for
(
auto
&
pair
:
cur_ref_cnts_
)
{
for
(
size_t
i
=
0
;
i
<
n
;
++
i
)
{
auto
&
name_map
=
*
(
pair
.
second
);
for
(
auto
&
fetch_name
:
fetch_tensors
)
{
for
(
auto
&
fetch_name
:
fetch_tensors
)
{
name_map
.
erase
(
fetch_name
);
member_
->
rt_cur_ref_cnts_
[
i
]
.
erase
(
fetch_name
);
}
}
name_map
.
erase
(
fetched_var_name
);
member_
->
rt_cur_ref_cnts_
[
i
]
.
erase
(
fetched_var_name
);
}
}
}
}
#endif
auto
fetch_data
=
member_
->
executor_
->
Run
(
fetch_tensors
);
auto
fetch_data
=
member_
->
executor_
->
Run
(
fetch_tensors
);
*
member_
->
global_scope_
->
Var
(
fetched_var_name
)
->
GetMutable
<
FeedFetchList
>
()
=
*
member_
->
global_scope_
->
Var
(
fetched_var_name
)
->
GetMutable
<
FeedFetchList
>
()
=
fetch_data
;
fetch_data
;
...
@@ -326,13 +362,11 @@ ParallelExecutor::~ParallelExecutor() {
...
@@ -326,13 +362,11 @@ ParallelExecutor::~ParallelExecutor() {
for
(
auto
&
p
:
member_
->
places_
)
{
for
(
auto
&
p
:
member_
->
places_
)
{
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
)
->
Wait
();
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
)
->
Wait
();
}
}
// member_ must be destructed before gcs_ since the destructor of
delete
member_
;
// ReferenceCountOpHandle use raw pointers of gcs_ inside.
member_
.
reset
();
}
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
#ifdef PADDLE_WITH_CUDA
USE_PASS
(
reference_count_pass
);
USE_PASS
(
reference_count_pass
);
#endif
USE_PASS
(
eager_deletion_pass
);
paddle/fluid/framework/parallel_executor.h
浏览文件 @
096673f6
...
@@ -14,7 +14,6 @@ limitations under the License. */
...
@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once
#pragma once
#include <atomic>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
...
@@ -29,10 +28,6 @@ limitations under the License. */
...
@@ -29,10 +28,6 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/reference_count_pass.h"
#endif
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -75,24 +70,7 @@ class ParallelExecutor {
...
@@ -75,24 +70,7 @@ class ParallelExecutor {
private:
private:
void
BCastParamsToDevices
(
const
std
::
unordered_set
<
std
::
string
>
&
vars
)
const
;
void
BCastParamsToDevices
(
const
std
::
unordered_set
<
std
::
string
>
&
vars
)
const
;
std
::
unique_ptr
<
ParallelExecutorPrivate
>
member_
;
ParallelExecutorPrivate
*
member_
;
#ifdef PADDLE_WITH_CUDA
// ref_cnts_ is only initialized when ParallelExecutor constructs, and then
// keeps unchanged
// Before each iteration, cur_ref_cnts_ is reset to ref_cnts_
details
::
DeviceReferenceCountMap
ref_cnts_
;
details
::
AtomicDeviceReferenceCountMap
cur_ref_cnts_
;
details
::
DeviceGarbageCollectorMap
gcs_
;
void
ResetReferenceCount
()
{
for
(
auto
&
pair1
:
ref_cnts_
)
{
for
(
auto
&
pair2
:
*
(
pair1
.
second
))
{
(
*
(
cur_ref_cnts_
[
pair1
.
first
]))[
pair2
.
first
]
=
pair2
.
second
;
}
}
}
#endif
};
};
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/platform/CMakeLists.txt
浏览文件 @
096673f6
...
@@ -56,9 +56,16 @@ ELSE()
...
@@ -56,9 +56,16 @@ ELSE()
set
(
MKLDNN_CTX_DEPS
)
set
(
MKLDNN_CTX_DEPS
)
ENDIF
()
ENDIF
()
nv_library
(
stream_callback_manager SRCS stream_callback_manager.cc DEPS simple_threadpool enforce
)
IF
(
WITH_GPU
)
set
(
STREAM_CALLBACK_DEPS stream_callback_manager
)
ELSE
()
set
(
STREAM_CALLBACK_DEPS
)
ENDIF
()
# memcpy depends on device_context, here add deps individually for
# memcpy depends on device_context, here add deps individually for
# avoiding cycle dependencies
# avoiding cycle dependencies
cc_library
(
device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc
cc_library
(
device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc
${
STREAM_CALLBACK_DEPS
}
place eigen3 stringpiece cpu_helper cpu_info framework_proto
${
GPU_CTX_DEPS
}
${
MKLDNN_CTX_DEPS
}
)
place eigen3 stringpiece cpu_helper cpu_info framework_proto
${
GPU_CTX_DEPS
}
${
MKLDNN_CTX_DEPS
}
)
nv_test
(
device_context_test SRCS device_context_test.cu DEPS device_context gpu_info
)
nv_test
(
device_context_test SRCS device_context_test.cu DEPS device_context gpu_info
)
...
...
paddle/fluid/platform/stream_callback_manager.cc
0 → 100644
浏览文件 @
096673f6
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/stream_callback_manager.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
platform
{
struct
StreamCallbackContext
{
inline
StreamCallbackContext
(
const
StreamCallbackManager
*
manager
,
std
::
function
<
void
()
>
callback
)
:
manager_
(
manager
),
callback_
(
std
::
move
(
callback
))
{}
const
StreamCallbackManager
*
manager_
;
// do not own
std
::
function
<
void
()
>
callback_
;
};
StreamCallbackManager
::
StreamCallbackManager
(
const
cudaStream_t
stream
)
:
stream_
(
stream
),
thread_pool_
(
new
::
ThreadPool
(
1
))
{}
void
StreamCallbackManager
::
AddCallback
(
std
::
function
<
void
()
>
callback
)
const
{
auto
*
stream_callback_context
=
new
StreamCallbackContext
(
this
,
std
::
move
(
callback
));
#if CUDA_VERSION >= 10000
PADDLE_ENFORCE
(
cudaLaunchHostFunc
(
stream_
,
StreamCallbackManager
::
StreamCallbackFunc
,
stream_callback_context
));
#else
PADDLE_ENFORCE
(
cudaStreamAddCallback
(
stream_
,
StreamCallbackManager
::
StreamCallbackFunc
,
stream_callback_context
,
0
));
#endif
}
void
StreamCallbackManager
::
Wait
()
const
{
thread_pool_
.
reset
(
new
::
ThreadPool
(
1
));
}
#if CUDA_VERSION >= 10000
void
CUDART_CB
StreamCallbackManager
::
StreamCallbackFunc
(
void
*
user_data
)
#else
void
CUDART_CB
StreamCallbackManager
::
StreamCallbackFunc
(
cudaStream_t
stream
,
cudaError_t
status
,
void
*
user_data
)
#endif
{
auto
*
callback_context_ptr
=
reinterpret_cast
<
StreamCallbackContext
*>
(
user_data
);
callback_context_ptr
->
manager_
->
thread_pool_
->
enqueue
(
[
callback_context_ptr
]()
{
std
::
unique_ptr
<
StreamCallbackContext
>
callback_context
(
callback_context_ptr
);
callback_context
->
callback_
();
});
}
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/stream_callback_manager.h
浏览文件 @
096673f6
...
@@ -19,66 +19,29 @@
...
@@ -19,66 +19,29 @@
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#include <functional>
#include <functional>
#include <memory>
#include <memory>
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
paddle
{
namespace
platform
{
namespace
platform
{
class
StreamCallbackManager
;
// NOTE(zjl): clean StreamCallback to make compilation faster
struct
StreamCallbackContext
{
template
<
typename
Callback
>
inline
StreamCallbackContext
(
const
StreamCallbackManager
*
manager
,
Callback
&&
callback
)
:
manager_
(
manager
),
callback_
(
callback
)
{}
const
StreamCallbackManager
*
manager_
;
// do not own
std
::
function
<
void
()
>
callback_
;
};
class
StreamCallbackManager
{
class
StreamCallbackManager
{
public:
public:
explicit
inline
StreamCallbackManager
(
cudaStream_t
stream
=
nullptr
)
explicit
StreamCallbackManager
(
const
cudaStream_t
stream
);
:
stream_
(
stream
),
thread_pool_
(
new
ThreadPool
(
1
))
{}
template
<
typename
Callback
>
void
AddCallback
(
std
::
function
<
void
()
>
callback
)
const
;
inline
void
AddCallback
(
Callback
&&
callback
)
const
{
auto
*
stream_callback_context
=
new
StreamCallbackContext
(
this
,
std
::
forward
<
Callback
>
(
callback
));
#if CUDA_VERSION >= 10000
PADDLE_ENFORCE
(
cudaLaunchHostFunc
(
stream_
,
StreamCallbackManager
::
StreamCallbackFunc
,
stream_callback_context
));
// NOLINT
#else
PADDLE_ENFORCE
(
cudaStreamAddCallback
(
stream_
,
StreamCallbackManager
::
StreamCallbackFunc
,
stream_callback_context
,
0
));
// NOLINT
#endif
}
void
Wait
()
const
{
thread_pool_
.
reset
(
new
ThreadPool
(
1
));
}
void
Wait
()
const
;
private:
private:
const
cudaStream_t
stream_
;
const
cudaStream_t
stream_
;
mutable
std
::
unique_ptr
<
ThreadPool
>
thread_pool_
;
mutable
std
::
unique_ptr
<
::
ThreadPool
>
thread_pool_
;
// cudaStreamCallback cannot call CUDA API inside, so we have to use
// thread_pool here
#if CUDA_VERSION >= 10000
#if CUDA_VERSION >= 10000
static
void
CUDART_CB
StreamCallbackFunc
(
void
*
user_data
)
static
void
CUDART_CB
StreamCallbackFunc
(
void
*
user_data
)
;
#else
#else
static
void
CUDART_CB
StreamCallbackFunc
(
cudaStream_t
stream
,
static
void
CUDART_CB
StreamCallbackFunc
(
cudaStream_t
stream
,
cudaError_t
status
,
void
*
user_data
)
cudaError_t
status
,
void
*
user_data
)
;
#endif
#endif
{
auto
*
callback_context_ptr
=
reinterpret_cast
<
StreamCallbackContext
*>
(
user_data
);
callback_context_ptr
->
manager_
->
thread_pool_
->
enqueue
([
=
]()
{
std
::
unique_ptr
<
StreamCallbackContext
>
callback_context
(
callback_context_ptr
);
callback_context
->
callback_
();
});
}
};
};
}
// namespace platform
}
// namespace platform
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录