Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
4c489180
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4c489180
编写于
9月 17, 2018
作者:
Z
Zeng Jinle
提交者:
GitHub
9月 17, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #13411 from sneaxiy/feature/eager_delete_tensor
[WIP] Feature/eager delete tensor
上级
7d00db48
612e1a31
变更
20
隐藏空白更改
内联
并排
Showing
20 changed file
with
807 addition
and
6 deletion
+807
-6
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+11
-1
paddle/fluid/framework/details/computation_op_handle.h
paddle/fluid/framework/details/computation_op_handle.h
+4
-0
paddle/fluid/framework/details/reference_count_op_handle.h
paddle/fluid/framework/details/reference_count_op_handle.h
+123
-0
paddle/fluid/framework/details/reference_count_pass.cc
paddle/fluid/framework/details/reference_count_pass.cc
+150
-0
paddle/fluid/framework/details/reference_count_pass.h
paddle/fluid/framework/details/reference_count_pass.h
+37
-0
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
...id/framework/details/scope_buffered_ssa_graph_executor.cc
+19
-0
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+71
-2
paddle/fluid/framework/executor.h
paddle/fluid/framework/executor.h
+43
-0
paddle/fluid/framework/garbage_collector.h
paddle/fluid/framework/garbage_collector.h
+163
-0
paddle/fluid/framework/ir/graph.h
paddle/fluid/framework/ir/graph.h
+8
-0
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+32
-0
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+23
-0
paddle/fluid/framework/scope.cc
paddle/fluid/framework/scope.cc
+12
-0
paddle/fluid/framework/scope.h
paddle/fluid/framework/scope.h
+2
-0
paddle/fluid/framework/tensor.h
paddle/fluid/framework/tensor.h
+2
-0
paddle/fluid/platform/CMakeLists.txt
paddle/fluid/platform/CMakeLists.txt
+1
-1
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+3
-0
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+20
-1
paddle/fluid/platform/stream_callback_manager.h
paddle/fluid/platform/stream_callback_manager.h
+82
-0
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+1
-1
未找到文件。
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
4c489180
...
...
@@ -28,10 +28,20 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
cc_library
(
gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor
)
cc_library
(
fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope
)
if
(
WITH_GPU
)
cc_library
(
reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass
)
endif
()
cc_library
(
multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto
)
if
(
WITH_GPU
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass
)
else
()
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto
)
endif
()
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context
)
...
...
paddle/fluid/framework/details/computation_op_handle.h
浏览文件 @
4c489180
...
...
@@ -32,6 +32,10 @@ struct ComputationOpHandle : public OpHandleBase {
std
::
string
Name
()
const
override
;
const
Scope
*
GetScope
()
const
{
return
scope_
;
}
const
platform
::
Place
&
GetPlace
()
const
{
return
place_
;
}
protected:
void
RunImpl
()
override
;
...
...
paddle/fluid/framework/details/reference_count_op_handle.h
0 → 100644
浏览文件 @
4c489180
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
using
ReferenceCountMap
=
std
::
unordered_map
<
std
::
string
,
int
>
;
using
AtomicReferenceCountMap
=
std
::
unordered_map
<
std
::
string
,
std
::
atomic
<
int
>>
;
using
DeviceReferenceCountMap
=
std
::
unordered_map
<
int
,
std
::
unique_ptr
<
ReferenceCountMap
>>
;
using
AtomicDeviceReferenceCountMap
=
std
::
unordered_map
<
int
,
std
::
unique_ptr
<
AtomicReferenceCountMap
>>
;
using
DeviceGarbageCollectorMap
=
std
::
unordered_map
<
int
,
std
::
unique_ptr
<
GarbageCollector
<
framework
::
Tensor
>>>
;
class
ReferenceCountOpHandle
:
public
OpHandleBase
{
public:
ReferenceCountOpHandle
(
ir
::
Node
*
node
,
const
Scope
*
scope
,
const
platform
::
CUDAPlace
&
place
,
const
std
::
vector
<
std
::
string
>
&
var_names
,
GarbageCollector
<
Tensor
>
*
gc
,
AtomicReferenceCountMap
*
ref_cnts
)
:
OpHandleBase
(
node
),
scope_
(
scope
),
var_names_
(
var_names
),
gc_
(
gc
),
ref_cnts_
(
ref_cnts
)
{
dev_ctx_
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
if
(
IsStreamGarabageCollector
())
{
PADDLE_ENFORCE
(
cudaSetDevice
(
place
.
device
));
PADDLE_ENFORCE
(
cudaEventCreateWithFlags
(
&
event_
,
cudaEventDisableTiming
));
}
}
~
ReferenceCountOpHandle
()
{
if
(
IsStreamGarabageCollector
())
{
auto
gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_ctx_
->
GetPlace
());
PADDLE_ENFORCE
(
cudaSetDevice
(
gpu_place
.
device
));
PADDLE_ENFORCE
(
cudaEventDestroy
(
event_
));
}
}
std
::
string
Name
()
const
override
{
return
"reference_count"
;
}
protected:
void
RunImpl
()
override
{
auto
*
exec_scope
=
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
std
::
vector
<
LoDTensor
*>
tensors
;
for
(
auto
&
name
:
var_names_
)
{
auto
it
=
ref_cnts_
->
find
(
name
);
if
(
it
==
ref_cnts_
->
end
())
continue
;
auto
*
var
=
exec_scope
->
FindVar
(
name
);
if
(
var
==
nullptr
||
!
var
->
IsType
<
LoDTensor
>
())
continue
;
if
(
it
->
second
.
fetch_sub
(
1
)
<=
1
)
{
tensors
.
emplace_back
(
var
->
GetMutable
<
LoDTensor
>
());
}
}
if
(
!
tensors
.
empty
())
{
ClearTensors
(
tensors
);
}
}
private:
void
ClearTensors
(
const
std
::
vector
<
LoDTensor
*>
&
tensors
)
{
auto
*
gc
=
dynamic_cast
<
StreamGarbageCollector
<
Tensor
>
*>
(
gc_
);
if
(
gc
!=
nullptr
)
{
auto
compute_stream
=
dev_ctx_
->
stream
();
auto
callback_stream
=
gc
->
stream
();
auto
callback_func
=
[
=
]()
{
PADDLE_ENFORCE
(
cudaEventRecord
(
event_
,
compute_stream
));
PADDLE_ENFORCE
(
cudaStreamWaitEvent
(
callback_stream
,
event_
,
0
));
};
gc_
->
Add
(
tensors
,
callback_func
);
}
else
{
gc_
->
Add
(
tensors
);
}
}
bool
IsStreamGarabageCollector
()
const
{
return
dynamic_cast
<
const
StreamGarbageCollector
<
Tensor
>
*>
(
gc_
)
!=
nullptr
;
}
const
Scope
*
scope_
;
platform
::
CUDADeviceContext
*
dev_ctx_
;
std
::
vector
<
std
::
string
>
var_names_
;
GarbageCollector
<
Tensor
>
*
gc_
;
// not own
AtomicReferenceCountMap
*
ref_cnts_
;
// not own
cudaEvent_t
event_
;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/reference_count_pass.cc
0 → 100644
浏览文件 @
4c489180
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
std
::
unique_ptr
<
ir
::
Graph
>
ReferenceCountPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
auto
&
ref_cnts
=
Get
<
DeviceReferenceCountMap
>
(
kGlobalReferenceCount
);
auto
&
cur_ref_cnts
=
Get
<
AtomicDeviceReferenceCountMap
>
(
kCurReferenceCount
);
auto
&
gcs
=
Get
<
DeviceGarbageCollectorMap
>
(
kGarbageCollector
);
// It is not easy to find the right reference counts of varaibles in graph
// Step 1: Find all variables in computation ops
// Step 2: Find all variables in non-computation ops which refers to variables
// in computation ops
std
::
unordered_set
<
std
::
string
>
names
;
auto
get_ref_cnts_from_compute_op
=
[
&
](
const
std
::
unique_ptr
<
OpHandleBase
>
&
op
,
const
std
::
vector
<
VarHandleBase
*>
&
vars
)
{
std
::
vector
<
std
::
string
>
var_names_in_op
;
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
.
get
());
if
(
compute_op
==
nullptr
||
!
platform
::
is_gpu_place
(
compute_op
->
GetPlace
()))
return
var_names_in_op
;
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
compute_op
->
GetPlace
());
for
(
VarHandleBase
*
var_handle_base
:
vars
)
{
auto
*
var_handle
=
dynamic_cast
<
VarHandle
*>
(
var_handle_base
);
if
(
var_handle
==
nullptr
||
!
var_handle
->
Node
()
->
IsVar
())
continue
;
if
(
!
platform
::
is_gpu_place
(
var_handle
->
place_
)
||
boost
::
get
<
platform
::
CUDAPlace
>
(
var_handle
->
place_
)
!=
place
)
continue
;
VarDesc
*
var_desc
=
var_handle
->
Node
()
->
Var
();
auto
var_name
=
var_handle
->
Node
()
->
Name
();
// This is wierd but there is really some variables without var_desc
// in computation_op
if
(
var_desc
==
nullptr
)
{
if
(
compute_op
->
Node
()
->
Op
()
->
Block
()
->
FindVar
(
var_name
)
==
nullptr
)
continue
;
}
else
{
if
(
var_desc
->
Persistable
()
||
var_desc
->
Proto
()
->
type
().
type
()
!=
proto
::
VarType
::
LOD_TENSOR
)
continue
;
}
// compute op only runs in one device
if
(
ref_cnts
[
place
.
device
]
->
count
(
var_name
))
++
(
*
ref_cnts
[
place
.
device
])[
var_name
];
else
(
*
ref_cnts
[
place
.
device
])[
var_name
]
=
1
;
names
.
insert
(
var_name
);
var_names_in_op
.
push_back
(
var_name
);
}
return
var_names_in_op
;
};
auto
update_ref_cnts_from_non_compute_op
=
[
&
](
const
std
::
unique_ptr
<
OpHandleBase
>
&
op
,
const
std
::
vector
<
VarHandleBase
*>
&
vars
)
{
if
(
dynamic_cast
<
ComputationOpHandle
*>
(
op
.
get
())
!=
nullptr
)
return
;
for
(
VarHandleBase
*
var_handle_base
:
vars
)
{
auto
*
var_handle
=
dynamic_cast
<
VarHandle
*>
(
var_handle_base
);
if
(
var_handle
==
nullptr
||
!
var_handle
->
Node
()
->
IsVar
())
continue
;
auto
var_name
=
var_handle
->
Node
()
->
Name
();
auto
var_place
=
var_handle
->
place_
;
if
(
!
platform
::
is_gpu_place
(
var_place
))
continue
;
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
var_place
);
if
(
names
.
count
(
var_name
)
==
0
)
continue
;
if
(
ref_cnts
.
count
(
place
.
device
)
&&
ref_cnts
[
place
.
device
]
->
count
(
var_name
))
{
++
(
*
ref_cnts
[
place
.
device
])[
var_name
];
}
}
};
std
::
unordered_map
<
OpHandleBase
*
,
ReferenceCountOpHandle
*>
compute_ref_cnt_map
;
auto
&
all_ops
=
graph
->
Get
<
GraphOps
>
(
kGraphOps
);
for
(
auto
&
op
:
all_ops
)
{
auto
in_var_names
=
get_ref_cnts_from_compute_op
(
op
,
op
->
Inputs
());
auto
out_var_names
=
get_ref_cnts_from_compute_op
(
op
,
op
->
Outputs
());
if
(
in_var_names
.
empty
()
&&
out_var_names
.
empty
())
continue
;
in_var_names
.
insert
(
in_var_names
.
end
(),
out_var_names
.
begin
(),
out_var_names
.
end
());
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
.
get
());
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
compute_op
->
GetPlace
());
ir
::
Node
*
ref_cnt_node
=
graph
->
CreateEmptyNode
(
"reference_count"
,
ir
::
Node
::
Type
::
kOperation
);
auto
*
ref_cnt_handle
=
new
ReferenceCountOpHandle
(
ref_cnt_node
,
compute_op
->
GetScope
(),
place
,
in_var_names
,
gcs
[
place
.
device
].
get
(),
cur_ref_cnts
[
place
.
device
].
get
());
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
compute_op
->
AddOutput
(
dep_var
);
ref_cnt_handle
->
AddInput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
compute_ref_cnt_map
[
compute_op
]
=
ref_cnt_handle
;
}
for
(
auto
&
op
:
all_ops
)
{
update_ref_cnts_from_non_compute_op
(
op
,
op
->
Inputs
());
update_ref_cnts_from_non_compute_op
(
op
,
op
->
Outputs
());
}
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>>
new_all_ops
;
new_all_ops
.
reserve
(
compute_ref_cnt_map
.
size
()
+
all_ops
.
size
());
for
(
auto
&
op
:
all_ops
)
{
new_all_ops
.
emplace_back
(
std
::
move
(
op
));
auto
it
=
compute_ref_cnt_map
.
find
(
new_all_ops
.
back
().
get
());
if
(
it
!=
compute_ref_cnt_map
.
end
())
{
new_all_ops
.
emplace_back
(
it
->
second
);
}
}
all_ops
.
swap
(
new_all_ops
);
return
graph
;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
reference_count_pass
,
paddle
::
framework
::
details
::
ReferenceCountPass
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kGlobalReferenceCount
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kCurReferenceCount
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kGarbageCollector
);
paddle/fluid/framework/details/reference_count_pass.h
0 → 100644
浏览文件 @
4c489180
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/reference_count_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
constexpr
char
kGlobalReferenceCount
[]
=
"reference_count"
;
constexpr
char
kCurReferenceCount
[]
=
"current_reference_count"
;
constexpr
char
kGarbageCollector
[]
=
"garbage_collector"
;
class
ReferenceCountPass
:
public
ir
::
Pass
{
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
浏览文件 @
4c489180
...
...
@@ -18,6 +18,9 @@
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/reference_count_op_handle.h"
#endif
namespace
paddle
{
namespace
framework
{
...
...
@@ -65,12 +68,28 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
platform
::
RecordEvent
e
(
"ScopeBufferedSSAGraphExecutorAfterRun"
,
nullptr
);
drop_scope_counter_
+=
1
;
#ifdef PADDLE_WITH_CUDA
const
std
::
string
gc_name
=
"garbage_collector"
;
DeviceGarbageCollectorMap
*
gc
=
Graph
().
Has
(
gc_name
)
?
&
(
Graph
().
Get
<
DeviceGarbageCollectorMap
>
(
gc_name
))
:
nullptr
;
#endif
if
(
!
fetch_tensors
.
empty
()
||
drop_scope_counter_
==
strategy_
.
num_iteration_per_drop_scope_
)
{
drop_scope_counter_
=
0
;
// Wait All computational streams
for
(
auto
p
:
places_
)
{
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
)
->
Wait
();
#ifdef PADDLE_WITH_CUDA
if
(
gc
!=
nullptr
&&
platform
::
is_gpu_place
(
p
))
{
auto
gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
p
);
auto
&
gc_at_place
=
gc
->
at
(
gpu_place
.
device
);
gc_at_place
->
Wait
();
gc_at_place
->
Reset
();
}
#endif
}
for
(
auto
&
scope
:
local_scopes_
)
{
auto
&
local_scope
=
...
...
paddle/fluid/framework/executor.cc
浏览文件 @
4c489180
...
...
@@ -37,7 +37,11 @@ int kProgramId = -1;
ExecutorPrepareContext
::
ExecutorPrepareContext
(
const
framework
::
ProgramDesc
&
prog
,
size_t
block_id
)
:
prog_
(
prog
),
block_id_
(
block_id
)
{}
:
prog_
(
prog
),
block_id_
(
block_id
)
{
if
(
GetEagerDeletionThreshold
()
>=
0
)
{
ref_cnts_
=
GetNonPersistableReferenceCount
<
int
>
(
prog_
,
block_id_
);
}
}
ExecutorPrepareContext
::~
ExecutorPrepareContext
()
{
VLOG
(
5
)
<<
"destroy ExecutorPrepareContext"
;
...
...
@@ -329,15 +333,80 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
CreateVariables
(
ctx
->
prog_
,
local_scope
,
ctx
->
block_id_
);
}
int64_t
max_memory_size
=
GetEagerDeletionThreshold
();
std
::
unique_ptr
<
GarbageCollector
<
Tensor
>>
gc
;
if
(
max_memory_size
>=
0
)
{
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
is_gpu_place
(
place_
))
{
gc
.
reset
(
new
DefaultStreamGarbageCollector
<
Tensor
>
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place_
),
max_memory_size
));
}
else
{
#endif
gc
.
reset
(
new
CPUGarbageCollector
<
Tensor
>
(
boost
::
get
<
platform
::
CPUPlace
>
(
place_
),
max_memory_size
));
#ifdef PADDLE_WITH_CUDA
}
#endif
}
for
(
auto
&
op
:
ctx
->
ops_
)
{
op
->
Run
(
*
local_scope
,
place_
);
if
(
gc
!=
nullptr
)
{
std
::
vector
<
std
::
string
>
erase_vars
;
for
(
auto
&
input
:
op
->
Inputs
())
{
for
(
auto
&
input_name
:
input
.
second
)
{
auto
it
=
ctx
->
ref_cnts_
.
find
(
input_name
);
if
(
it
==
ctx
->
ref_cnts_
.
end
())
continue
;
if
(
it
->
second
==
1
)
{
// should delete it
erase_vars
.
emplace_back
(
input_name
);
ctx
->
ref_cnts_
.
erase
(
input_name
);
}
else
{
--
(
it
->
second
);
}
}
}
for
(
auto
&
output
:
op
->
Outputs
())
{
for
(
auto
&
output_name
:
output
.
second
)
{
auto
it
=
ctx
->
ref_cnts_
.
find
(
output_name
);
if
(
it
==
ctx
->
ref_cnts_
.
end
())
continue
;
if
(
it
->
second
==
1
)
{
erase_vars
.
emplace_back
(
output_name
);
ctx
->
ref_cnts_
.
erase
(
output_name
);
}
else
{
--
(
it
->
second
);
}
}
}
if
(
!
erase_vars
.
empty
())
{
std
::
vector
<
framework
::
LoDTensor
*>
erase_tensors
;
for
(
auto
&
name
:
erase_vars
)
{
auto
*
var
=
local_scope
->
FindVar
(
name
);
if
(
var
==
nullptr
)
continue
;
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
erase_tensors
.
push_back
(
tensor
);
}
}
if
(
!
erase_tensors
.
empty
())
gc
->
Add
(
erase_tensors
);
}
}
if
(
FLAGS_benchmark
)
{
VLOG
(
2
)
<<
"Memory used after operator "
+
op
->
Type
()
+
" running: "
<<
memory
::
memory_usage
(
place_
);
}
}
platform
::
DeviceContextPool
::
Instance
().
Get
(
place_
)
->
Wait
();
if
(
gc
!=
nullptr
)
{
gc
->
Wait
();
}
else
{
platform
::
DeviceContextPool
::
Instance
().
Get
(
place_
)
->
Wait
();
}
if
(
local_scope
!=
scope
)
{
scope
->
DeleteScope
(
local_scope
);
}
else
{
...
...
paddle/fluid/framework/executor.h
浏览文件 @
4c489180
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
...
...
@@ -27,6 +28,46 @@ namespace paddle {
namespace
framework
{
extern
void
InitializeVariable
(
Variable
*
var
,
proto
::
VarType
::
Type
var_type
);
template
<
typename
T
>
std
::
unordered_map
<
std
::
string
,
T
>
GetNonPersistableReferenceCount
(
const
ProgramDesc
&
prog
,
size_t
block_id
)
{
auto
&
block
=
prog
.
Block
(
block_id
);
std
::
unordered_set
<
std
::
string
>
ignored_vars
;
std
::
unordered_map
<
std
::
string
,
T
>
ref_cnts
;
for
(
auto
var_desc
:
block
.
AllVars
())
{
auto
type
=
var_desc
->
Proto
()
->
type
().
type
();
if
(
type
!=
proto
::
VarType
::
LOD_TENSOR
||
var_desc
->
Persistable
())
{
ignored_vars
.
insert
(
var_desc
->
Name
());
// ignore persistable vars
}
}
for
(
auto
op_desc
:
block
.
AllOps
())
{
for
(
auto
&
input
:
op_desc
->
Inputs
())
{
for
(
auto
&
input_name
:
input
.
second
)
{
if
(
!
ignored_vars
.
count
(
input_name
))
{
if
(
ref_cnts
.
count
(
input_name
))
++
ref_cnts
[
input_name
];
else
ref_cnts
[
input_name
]
=
1
;
}
}
}
for
(
auto
&
output
:
op_desc
->
Outputs
())
{
for
(
auto
output_name
:
output
.
second
)
{
if
(
!
ignored_vars
.
count
(
output_name
))
{
if
(
ref_cnts
.
count
(
output_name
))
++
ref_cnts
[
output_name
];
else
ref_cnts
[
output_name
]
=
1
;
}
}
}
}
return
ref_cnts
;
}
struct
ExecutorPrepareContext
{
ExecutorPrepareContext
(
const
framework
::
ProgramDesc
&
prog
,
size_t
block_id
);
~
ExecutorPrepareContext
();
...
...
@@ -34,6 +75,8 @@ struct ExecutorPrepareContext {
const
framework
::
ProgramDesc
&
prog_
;
size_t
block_id_
;
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>
ops_
;
std
::
unordered_map
<
std
::
string
,
int
>
ref_cnts_
;
};
class
Executor
{
...
...
paddle/fluid/framework/garbage_collector.h
0 → 100644
浏览文件 @
4c489180
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <deque>
#include <functional>
#include <memory>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/device_context.h"
namespace
paddle
{
namespace
framework
{
// T should have memory_size() and clear() method
template
<
typename
T
>
class
GarbageCollector
{
public:
GarbageCollector
(
const
platform
::
Place
&
place
,
size_t
max_memory_size
)
:
max_memory_size_
(
std
::
max
(
max_memory_size
,
static_cast
<
size_t
>
(
1
)))
{
garbages_
.
reset
(
new
std
::
deque
<
T
*>
());
dev_ctx_
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
}
virtual
~
GarbageCollector
()
{}
void
Reset
()
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mutex_
);
garbages_
.
reset
(
new
std
::
deque
<
T
*>
());
cur_memory_size_
=
0
;
}
template
<
typename
Container
>
void
Add
(
const
Container
&
objs
)
{
Add
(
objs
,
[]()
{});
}
template
<
typename
Container
,
typename
Callback
>
void
Add
(
const
Container
&
objs
,
Callback
&&
callback
)
{
std
::
shared_ptr
<
std
::
deque
<
T
*>>
clear_deque
;
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mutex_
);
for
(
auto
*
obj
:
objs
)
{
garbages_
->
push_back
(
obj
);
cur_memory_size_
+=
obj
->
memory_size
();
}
if
(
cur_memory_size_
>=
max_memory_size_
)
{
cur_memory_size_
=
0
;
clear_deque
=
garbages_
;
garbages_
.
reset
(
new
std
::
deque
<
T
*>
());
}
}
if
(
clear_deque
!=
nullptr
)
{
callback
();
ClearCallback
([
=
]()
{
for
(
auto
*
obj
:
*
clear_deque
)
obj
->
clear
();
});
}
}
virtual
void
Wait
()
const
{}
protected:
virtual
void
ClearCallback
(
const
std
::
function
<
void
()
>
&
callback
)
=
0
;
platform
::
DeviceContext
*
dev_ctx_
;
std
::
shared_ptr
<
std
::
deque
<
T
*>>
garbages_
;
mutable
std
::
mutex
mutex_
;
const
size_t
max_memory_size_
;
size_t
cur_memory_size_
=
0
;
};
template
<
typename
T
>
class
CPUGarbageCollector
:
public
GarbageCollector
<
T
>
{
public:
CPUGarbageCollector
(
const
platform
::
CPUPlace
&
place
,
size_t
max_memory_size
)
:
GarbageCollector
<
T
>
(
place
,
max_memory_size
)
{}
protected:
void
ClearCallback
(
const
std
::
function
<
void
()
>
&
callback
)
override
{
callback
();
}
};
#ifdef PADDLE_WITH_CUDA
template
<
typename
T
>
class
DefaultStreamGarbageCollector
:
public
GarbageCollector
<
T
>
{
public:
DefaultStreamGarbageCollector
(
const
platform
::
CUDAPlace
&
place
,
size_t
max_memory_size
)
:
GarbageCollector
<
T
>
(
place
,
max_memory_size
)
{}
cudaStream_t
stream
()
const
{
return
static_cast
<
const
platform
::
CUDADeviceContext
*>
(
this
->
dev_ctx_
)
->
stream
();
}
void
Wait
()
const
override
{
this
->
dev_ctx_
->
Wait
();
static_cast
<
const
platform
::
CUDADeviceContext
*>
(
this
->
dev_ctx_
)
->
WaitStreamCallback
();
}
protected:
void
ClearCallback
(
const
std
::
function
<
void
()
>
&
callback
)
override
{
static_cast
<
platform
::
CUDADeviceContext
*>
(
this
->
dev_ctx_
)
->
AddStreamCallback
(
callback
);
}
};
template
<
typename
T
>
class
StreamGarbageCollector
:
public
GarbageCollector
<
T
>
{
public:
StreamGarbageCollector
(
const
platform
::
CUDAPlace
&
place
,
size_t
max_memory_size
)
:
GarbageCollector
<
T
>
(
place
,
max_memory_size
)
{
PADDLE_ENFORCE
(
cudaSetDevice
(
place
.
device
));
PADDLE_ENFORCE
(
cudaStreamCreate
(
&
stream_
));
callback_manager_
.
reset
(
new
platform
::
StreamCallbackManager
(
stream_
));
}
~
StreamGarbageCollector
()
{
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
this
->
dev_ctx_
->
GetPlace
());
PADDLE_ENFORCE
(
cudaSetDevice
(
place
.
device
));
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream_
));
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
}
void
Wait
()
const
override
{
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream_
));
std
::
lock_guard
<
std
::
mutex
>
guard
(
this
->
mutex_
);
callback_manager_
->
Wait
();
}
cudaStream_t
stream
()
const
{
return
stream_
;
}
protected:
void
ClearCallback
(
const
std
::
function
<
void
()
>
&
callback
)
override
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
this
->
mutex_
);
callback_manager_
->
AddCallback
(
callback
);
}
private:
cudaStream_t
stream_
;
std
::
unique_ptr
<
platform
::
StreamCallbackManager
>
callback_manager_
;
};
#endif
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph.h
浏览文件 @
4c489180
...
...
@@ -94,6 +94,14 @@ class Graph {
};
}
template
<
typename
AttrType
>
void
SetNotOwned
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
,
"%s already set in the graph"
,
attr_name
);
attrs_
[
attr_name
]
=
attr
;
attr_dels_
[
attr_name
]
=
[]()
{};
}
const
std
::
unordered_set
<
ir
::
Node
*>
&
Nodes
()
const
{
return
node_set_
;
}
// Create a normal variable with non-null VarDesc.
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
4c489180
...
...
@@ -188,6 +188,30 @@ ParallelExecutor::ParallelExecutor(
main_program
,
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
member_
->
use_cuda_
,
build_strategy
,
member_
->
nccl_ctxs_
.
get
());
auto
max_memory_size
=
GetEagerDeletionThreshold
();
if
(
max_memory_size
>=
0
)
{
for
(
auto
&
place
:
member_
->
places_
)
{
if
(
!
platform
::
is_gpu_place
(
place
))
continue
;
auto
gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
place
);
if
(
gcs_
[
gpu_place
.
device
]
==
nullptr
)
{
ref_cnts_
[
gpu_place
.
device
].
reset
(
new
details
::
ReferenceCountMap
());
cur_ref_cnts_
[
gpu_place
.
device
].
reset
(
new
details
::
AtomicReferenceCountMap
());
gcs_
[
gpu_place
.
device
].
reset
(
new
StreamGarbageCollector
<
Tensor
>
(
gpu_place
,
max_memory_size
));
}
}
if
(
!
gcs_
.
empty
())
{
auto
ref_cnt_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"reference_count_pass"
);
ref_cnt_pass
->
SetNotOwned
(
details
::
kGlobalReferenceCount
,
&
ref_cnts_
);
ref_cnt_pass
->
SetNotOwned
(
details
::
kCurReferenceCount
,
&
cur_ref_cnts_
);
ref_cnt_pass
->
SetNotOwned
(
details
::
kGarbageCollector
,
&
gcs_
);
graph
=
ref_cnt_pass
->
Apply
(
std
::
move
(
graph
));
graph
->
SetNotOwned
(
"garbage_collector"
,
&
gcs_
);
}
}
#else
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
ApplyParallelExecutorPass
(
main_program
,
member_
->
places_
,
loss_var_name
,
params
,
...
...
@@ -310,6 +334,11 @@ void ParallelExecutor::BCastParamsToDevices(
void
ParallelExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
const
std
::
string
&
fetched_var_name
)
{
platform
::
RecordBlock
b
(
0
);
#ifdef PADDLE_WITH_CUDA
if
(
!
gcs_
.
empty
())
{
ResetReferenceCount
();
}
#endif
auto
fetch_data
=
member_
->
executor_
->
Run
(
fetch_tensors
);
*
member_
->
global_scope_
->
Var
(
fetched_var_name
)
->
GetMutable
<
FeedFetchList
>
()
=
fetch_data
;
...
...
@@ -367,3 +396,6 @@ USE_PASS(graph_viz_pass);
USE_PASS
(
multi_devices_pass
);
USE_PASS
(
multi_devices_check_pass
);
USE_PASS
(
multi_devices_print_pass
);
#ifdef PADDLE_WITH_CUDA
USE_PASS
(
reference_count_pass
);
#endif
paddle/fluid/framework/parallel_executor.h
浏览文件 @
4c489180
...
...
@@ -15,7 +15,9 @@ limitations under the License. */
#pragma once
#include <paddle/fluid/framework/details/build_strategy.h>
#include <atomic>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/execution_strategy.h"
...
...
@@ -27,6 +29,10 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/reference_count_pass.h"
#endif
namespace
paddle
{
namespace
framework
{
...
...
@@ -70,6 +76,23 @@ class ParallelExecutor {
private:
ParallelExecutorPrivate
*
member_
;
#ifdef PADDLE_WITH_CUDA
// ref_cnts_ is only initialized when ParallelExecutor constructs, and then
// keeps unchanged
// Before each iteration, cur_ref_cnts_ is reset to ref_cnts_
details
::
DeviceReferenceCountMap
ref_cnts_
;
details
::
AtomicDeviceReferenceCountMap
cur_ref_cnts_
;
details
::
DeviceGarbageCollectorMap
gcs_
;
void
ResetReferenceCount
()
{
for
(
auto
&
pair1
:
ref_cnts_
)
{
for
(
auto
&
pair2
:
*
(
pair1
.
second
))
{
(
*
(
cur_ref_cnts_
[
pair1
.
first
]))[
pair2
.
first
]
=
pair2
.
second
;
}
}
}
#endif
};
}
// namespace framework
...
...
paddle/fluid/framework/scope.cc
浏览文件 @
4c489180
...
...
@@ -31,9 +31,21 @@ DEFINE_bool(
"Delete local scope eagerly. It will reduce GPU memory usage but "
"slow down the destruction of variables.(around 1% performance harm)"
);
DEFINE_double
(
eager_delete_tensor_gb
,
-
1.0
,
"Memory size threshold (GB) when the garbage collector clear tensors."
"Disabled when this value is less than 0"
);
namespace
paddle
{
namespace
framework
{
int64_t
GetEagerDeletionThreshold
()
{
return
FLAGS_eager_delete_tensor_gb
<
0
?
-
1
:
static_cast
<
int64_t
>
(
FLAGS_eager_delete_tensor_gb
*
(
static_cast
<
int64_t
>
(
1
)
<<
30
));
}
Scope
::~
Scope
()
{
DropKids
();
}
Scope
&
Scope
::
NewScope
()
const
{
...
...
paddle/fluid/framework/scope.h
浏览文件 @
4c489180
...
...
@@ -26,6 +26,8 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
int64_t
GetEagerDeletionThreshold
();
class
Scope
;
/**
...
...
paddle/fluid/framework/tensor.h
浏览文件 @
4c489180
...
...
@@ -151,6 +151,8 @@ class Tensor {
void
set_layout
(
const
DataLayout
layout
)
{
layout_
=
layout
;
}
void
clear
()
{
holder_
=
nullptr
;
}
private:
/**
* @note Placeholder hides type T, so it doesn't appear as a template
...
...
paddle/fluid/platform/CMakeLists.txt
浏览文件 @
4c489180
...
...
@@ -51,7 +51,7 @@ ENDIF()
# memcpy depends on device_context, here add deps individually for
# avoiding cycle dependencies
cc_library
(
device_context SRCS device_context.cc init.cc DEPS malloc
cc_library
(
device_context SRCS device_context.cc init.cc DEPS
simple_threadpool
malloc
place eigen3 stringpiece cpu_helper cpu_info framework_proto
${
GPU_CTX_DEPS
}
${
MKLDNN_CTX_DEPS
}
)
nv_test
(
device_context_test SRCS device_context_test.cu DEPS device_context gpu_info
)
...
...
paddle/fluid/platform/device_context.cc
浏览文件 @
4c489180
...
...
@@ -210,11 +210,14 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
if
(
dynload
::
HasCUDNN
())
{
cudnn_holder_
.
reset
(
new
CudnnHolder
(
&
stream_
,
place
));
}
callback_manager_
.
reset
(
new
StreamCallbackManager
(
stream_
));
}
CUDADeviceContext
::~
CUDADeviceContext
()
{
SetDeviceId
(
place_
.
device
);
Wait
();
WaitStreamCallback
();
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
cublas_handle_
));
eigen_stream_
.
reset
();
eigen_device_
.
reset
();
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
4c489180
...
...
@@ -31,6 +31,9 @@ limitations under the License. */
#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/stream_callback_manager.h"
#endif
#include "unsupported/Eigen/CXX11/Tensor"
namespace
paddle
{
...
...
@@ -112,6 +115,17 @@ class CUDADeviceContext : public DeviceContext {
PADDLE_ENFORCE
(
cudaEventRecord
(
ev
,
stream_
));
}
template
<
typename
Callback
>
void
AddStreamCallback
(
Callback
&&
callback
)
const
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
callback_mtx_
);
callback_manager_
->
AddCallback
(
callback
);
}
void
WaitStreamCallback
()
const
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
callback_mtx_
);
callback_manager_
->
Wait
();
}
private:
CUDAPlace
place_
;
...
...
@@ -125,7 +139,12 @@ class CUDADeviceContext : public DeviceContext {
int
multi_process
;
int
max_threads_per_mp
;
std
::
mutex
mtx_
;
mutable
std
::
mutex
mtx_
;
// This lock is only used by callback
// If we use mtx_ for StreamCallbackManager, deadlock may occur sometimes
mutable
std
::
mutex
callback_mtx_
;
std
::
unique_ptr
<
StreamCallbackManager
>
callback_manager_
;
};
template
<
>
...
...
paddle/fluid/platform/stream_callback_manager.h
0 → 100644
浏览文件 @
4c489180
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <functional>
#include <memory>
#include "ThreadPool.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
platform
{
using
StreamCallback
=
std
::
function
<
void
(
cudaStream_t
,
cudaError_t
)
>
;
class
StreamCallbackManager
;
struct
StreamCallbackContext
{
template
<
typename
Callback
>
inline
StreamCallbackContext
(
const
StreamCallbackManager
*
manager
,
Callback
&&
callback
)
:
manager_
(
manager
),
callback_
(
callback
)
{}
const
StreamCallbackManager
*
manager_
;
// do not own
StreamCallback
callback_
;
};
class
StreamCallbackManager
{
public:
explicit
inline
StreamCallbackManager
(
cudaStream_t
stream
=
nullptr
)
:
stream_
(
stream
),
thread_pool_
(
new
ThreadPool
(
1
))
{}
template
<
typename
Callback
>
inline
void
AddCallback
(
Callback
&&
callback
)
const
{
AddCallbackWithStreamAndErrorInfo
(
[
=
](
cudaStream_t
,
cudaError_t
)
{
callback
();
});
}
template
<
typename
Callback
>
inline
void
AddCallbackWithStreamAndErrorInfo
(
Callback
&&
callback
)
const
{
auto
*
stream_callback_context
=
new
StreamCallbackContext
(
this
,
callback
);
PADDLE_ENFORCE
(
cudaStreamAddCallback
(
stream_
,
StreamCallbackManager
::
StreamCallbackFunc
,
stream_callback_context
,
0
));
}
void
Wait
()
const
{
thread_pool_
.
reset
(
new
ThreadPool
(
1
));
}
private:
const
cudaStream_t
stream_
;
mutable
std
::
unique_ptr
<
ThreadPool
>
thread_pool_
;
// cudaStreamCallback cannot call CUDA API inside, so we have to use
// thread_pool here
static
void
CUDART_CB
StreamCallbackFunc
(
cudaStream_t
stream
,
cudaError_t
status
,
void
*
user_data
)
{
auto
*
callback_context_ptr
=
reinterpret_cast
<
StreamCallbackContext
*>
(
user_data
);
callback_context_ptr
->
manager_
->
thread_pool_
->
enqueue
([
=
]()
{
std
::
unique_ptr
<
StreamCallbackContext
>
callback_context
(
callback_context_ptr
);
callback_context
->
callback_
(
stream
,
status
);
});
}
};
}
// namespace platform
}
// namespace paddle
python/paddle/fluid/__init__.py
浏览文件 @
4c489180
...
...
@@ -122,7 +122,7 @@ def __bootstrap__():
'use_pinned_memory'
,
'check_nan_inf'
,
'benchmark'
,
'warpctc_dir'
,
'eager_delete_scope'
,
'use_mkldnn'
,
'initial_cpu_memory_in_mb'
,
'init_allocated_mem'
,
'free_idle_memory'
,
'paddle_num_threads'
,
"dist_threadpool_size"
,
'cpu_deterministic'
"dist_threadpool_size"
,
'cpu_deterministic'
,
'eager_delete_tensor_gb'
]
if
core
.
is_compiled_with_dist
():
read_env_flags
.
append
(
'rpc_deadline'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录