提交 24ea39c4 编写于 作者: S sneaxiy

feature/eager_delete_tensor

上级 f86198e6
...@@ -29,13 +29,20 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_ ...@@ -29,13 +29,20 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle if(WITH_GPU)
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
endif()
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle)
cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) if(WITH_GPU)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass)
else()
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
endif()
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context) simple_threadpool device_context)
......
...@@ -23,6 +23,8 @@ ...@@ -23,6 +23,8 @@
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/framework/details/reference_count_op_handle.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
...@@ -33,6 +35,10 @@ struct ComputationOpHandle : public OpHandleBase { ...@@ -33,6 +35,10 @@ struct ComputationOpHandle : public OpHandleBase {
std::string Name() const override; std::string Name() const override;
const Scope *GetScope() const { return scope_; }
const platform::Place &GetPlace() const { return place_; }
protected: protected:
void RunImpl() override; void RunImpl() override;
......
...@@ -82,6 +82,13 @@ class OpHandleBase { ...@@ -82,6 +82,13 @@ class OpHandleBase {
size_t NoDummyInputSize() const; size_t NoDummyInputSize() const;
ir::Node *Node() { return node_; }
const std::map<platform::Place, platform::DeviceContext *>
&GetDeviceContexts() const {
return dev_ctxes_;
}
protected: protected:
void RunAndRecordEvent(const std::function<void()> &callback); void RunAndRecordEvent(const std::function<void()> &callback);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace framework {
namespace details {
using ReferenceCountMap = std::unordered_map<std::string, int>;
using AtomicReferenceCountMap =
std::unordered_map<std::string, std::atomic<int>>;
using DeviceReferenceCountMap =
std::unordered_map<int, std::unique_ptr<ReferenceCountMap>>;
using AtomicDeviceReferenceCountMap =
std::unordered_map<int, std::unique_ptr<AtomicReferenceCountMap>>;
using DeviceGarbageCollectorMap =
std::unordered_map<int,
std::unique_ptr<GarbageCollector<framework::Tensor>>>;
class ReferenceCountOpHandle : public OpHandleBase {
public:
ReferenceCountOpHandle(ir::Node *node, const Scope *scope,
const platform::CUDAPlace &place,
const std::vector<std::string> &var_names,
GarbageCollector<Tensor> *gc,
AtomicReferenceCountMap *ref_cnts)
: OpHandleBase(node),
scope_(scope),
var_names_(var_names),
gc_(gc),
ref_cnts_(ref_cnts) {
dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
if (IsStreamGarabageCollector()) {
PADDLE_ENFORCE(cudaSetDevice(place.device));
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
}
}
~ReferenceCountOpHandle() {
if (IsStreamGarabageCollector()) {
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
PADDLE_ENFORCE(cudaSetDevice(gpu_place.device));
PADDLE_ENFORCE(cudaEventDestroy(event_));
}
}
std::string Name() const override { return "reference_count"; }
// protected:
void RunImpl() override {
auto *exec_scope_ = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
std::vector<LoDTensor *> tensors;
for (auto &name : var_names_) {
auto it = ref_cnts_->find(name);
if (it == ref_cnts_->end()) continue;
auto *var = exec_scope_->FindVar(name);
if (var == nullptr || !var->IsType<LoDTensor>()) continue;
if (it->second.fetch_sub(1) <= 1) {
tensors.emplace_back(var->GetMutable<LoDTensor>());
}
}
if (!tensors.empty()) {
ClearTensors(tensors);
}
}
private:
void ClearTensors(const std::vector<LoDTensor *> &tensors) const {
auto *gc = dynamic_cast<const StreamGarbageCollector<Tensor> *>(gc_);
if (gc != nullptr) {
auto compute_stream = dev_ctx_->stream();
auto callback_stream = gc->stream();
auto callback_func = [=]() {
PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
};
gc_->Add(tensors, callback_func);
} else {
gc_->Add(tensors);
}
}
bool IsStreamGarabageCollector() const {
return dynamic_cast<const StreamGarbageCollector<Tensor> *>(gc_) != nullptr;
}
const Scope *scope_;
platform::CUDADeviceContext *dev_ctx_;
std::vector<std::string> var_names_;
GarbageCollector<Tensor> *gc_; // not own
AtomicReferenceCountMap *ref_cnts_; // not own
cudaEvent_t event_;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
namespace paddle {
namespace framework {
namespace details {
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount);
auto &cur_ref_cnts = Get<AtomicDeviceReferenceCountMap>(kCurReferenceCount);
auto &gcs = Get<DeviceGarbageCollectorMap>(kGarbageCollector);
// It is not easy to find the right reference counts of varaibles in graph
// Step 1: Find all variables in computation ops
// Step 2: Find all variables in non-computation ops which refers to variables
// in computation ops
std::unordered_set<std::string> names;
auto get_ref_cnts_from_compute_op = [&](
const std::unique_ptr<OpHandleBase> &op,
const std::vector<VarHandleBase *> &vars) {
std::vector<std::string> var_names_in_op;
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
if (compute_op == nullptr ||
!platform::is_gpu_place(compute_op->GetPlace()))
return var_names_in_op;
auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
for (VarHandleBase *var_handle_base : vars) {
auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
if (!platform::is_gpu_place(var_handle->place_) ||
boost::get<platform::CUDAPlace>(var_handle->place_) != place)
continue;
VarDesc *var_desc = var_handle->Node()->Var();
auto var_name = var_handle->Node()->Name();
// This is wierd but there is really some variables without var_desc
// in computation_op
if (var_desc == nullptr) {
if (compute_op->Node()->Op()->Block()->FindVar(var_name) == nullptr)
continue;
} else {
if (var_desc->Persistable() ||
var_desc->Proto()->type().type() != proto::VarType::LOD_TENSOR)
continue;
}
// compute op only runs in one device
if (ref_cnts[place.device]->count(var_name))
++(*ref_cnts[place.device])[var_name];
else
(*ref_cnts[place.device])[var_name] = 1;
names.insert(var_name);
var_names_in_op.push_back(var_name);
}
return var_names_in_op;
};
auto update_ref_cnts_from_non_compute_op = [&](
const std::unique_ptr<OpHandleBase> &op,
const std::vector<VarHandleBase *> &vars) {
if (dynamic_cast<ComputationOpHandle *>(op.get()) != nullptr) return;
for (VarHandleBase *var_handle_base : vars) {
auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
auto var_name = var_handle->Node()->Name();
auto var_place = var_handle->place_;
if (!platform::is_gpu_place(var_place)) continue;
auto place = boost::get<platform::CUDAPlace>(var_place);
if (names.count(var_name) == 0) continue;
if (ref_cnts.count(place.device) &&
ref_cnts[place.device]->count(var_name)) {
++(*ref_cnts[place.device])[var_name];
}
}
};
std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
compute_ref_cnt_map;
auto &all_ops = graph->Get<GraphOps>(kGraphOps);
for (auto &op : all_ops) {
auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs());
if (in_var_names.empty() && out_var_names.empty()) continue;
in_var_names.insert(in_var_names.end(), out_var_names.begin(),
out_var_names.end());
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
ir::Node *ref_cnt_node =
graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation);
auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, compute_op->GetScope(), place, in_var_names,
gcs[place.device].get(), cur_ref_cnts[place.device].get());
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
compute_op->AddOutput(dep_var);
ref_cnt_handle->AddInput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
compute_ref_cnt_map[compute_op] = ref_cnt_handle;
}
for (auto &op : all_ops) {
update_ref_cnts_from_non_compute_op(op, op->Inputs());
update_ref_cnts_from_non_compute_op(op, op->Outputs());
}
std::vector<std::unique_ptr<OpHandleBase>> new_all_ops;
new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size());
for (auto &op : all_ops) {
auto it = compute_ref_cnt_map.find(op.get());
if (it != compute_ref_cnt_map.end()) {
new_all_ops.emplace_back(std::move(op));
new_all_ops.emplace_back(std::unique_ptr<OpHandleBase>(it->second));
} else {
new_all_ops.emplace_back(std::move(op));
}
}
all_ops.swap(new_all_ops);
return graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(reference_count_pass,
paddle::framework::details::ReferenceCountPass)
.RequirePassAttr(paddle::framework::details::kGlobalReferenceCount)
.RequirePassAttr(paddle::framework::details::kCurReferenceCount)
.RequirePassAttr(paddle::framework::details::kGarbageCollector);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/reference_count_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
constexpr char kGlobalReferenceCount[] = "reference_count";
constexpr char kCurReferenceCount[] = "current_reference_count";
constexpr char kGarbageCollector[] = "garbage_collector";
class ReferenceCountPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -16,6 +16,10 @@ ...@@ -16,6 +16,10 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/reference_count_op_handle.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -56,12 +60,28 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( ...@@ -56,12 +60,28 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
auto fetch_data = underlying_executor_->Run(fetch_tensors); auto fetch_data = underlying_executor_->Run(fetch_tensors);
drop_scope_counter_ += 1; drop_scope_counter_ += 1;
#ifdef PADDLE_WITH_CUDA
const std::string gc_name = "garbage_collector";
DeviceGarbageCollectorMap *gc =
Graph().Has(gc_name) ? &(Graph().Get<DeviceGarbageCollectorMap>(gc_name))
: nullptr;
#endif
if (!fetch_tensors.empty() || if (!fetch_tensors.empty() ||
drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) { drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
drop_scope_counter_ = 0; drop_scope_counter_ = 0;
// Wait All computational streams // Wait All computational streams
for (auto p : places_) { for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait(); platform::DeviceContextPool::Instance().Get(p)->Wait();
#ifdef PADDLE_WITH_CUDA
if (gc != nullptr && platform::is_gpu_place(p)) {
auto gpu_place = boost::get<platform::CUDAPlace>(p);
auto &gc_at_place = gc->at(gpu_place.device);
gc_at_place->Wait();
gc_at_place->Reset();
}
#endif
} }
for (auto &scope : local_scopes_) { for (auto &scope : local_scopes_) {
auto &local_scope = auto &local_scope =
......
...@@ -37,7 +37,9 @@ int kProgramId = -1; ...@@ -37,7 +37,9 @@ int kProgramId = -1;
ExecutorPrepareContext::ExecutorPrepareContext( ExecutorPrepareContext::ExecutorPrepareContext(
const framework::ProgramDesc& prog, size_t block_id) const framework::ProgramDesc& prog, size_t block_id)
: prog_(prog), block_id_(block_id) {} : prog_(prog),
block_id_(block_id),
ref_cnts_(GetNonPersistableReferenceCount<int>(prog, block_id)) {}
ExecutorPrepareContext::~ExecutorPrepareContext() { ExecutorPrepareContext::~ExecutorPrepareContext() {
VLOG(5) << "destroy ExecutorPrepareContext"; VLOG(5) << "destroy ExecutorPrepareContext";
...@@ -335,20 +337,84 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -335,20 +337,84 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
CreateVariables(ctx->prog_, local_scope, ctx->block_id_); CreateVariables(ctx->prog_, local_scope, ctx->block_id_);
} }
std::shared_ptr<std::vector<framework::LoDTensor*>> erase_tensors(
new std::vector<framework::LoDTensor*>());
int64_t max_memory_size = GetEagerDeletionThreshold();
std::unique_ptr<GarbageCollector<Tensor>> gc;
if (max_memory_size >= 0) {
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) {
gc.reset(new DefaultStreamGarbageCollector<Tensor>(
boost::get<platform::CUDAPlace>(place_), max_memory_size));
} else {
#endif
gc.reset(new CPUGarbageCollector<Tensor>(
boost::get<platform::CPUPlace>(place_), max_memory_size));
#ifdef PADDLE_WITH_CUDA
}
#endif
}
for (auto& op : ctx->ops_) { for (auto& op : ctx->ops_) {
VLOG(4) << place_ << " " << op->DebugStringEx(local_scope); VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_); op->Run(*local_scope, place_);
// NOTE! Please do not delete this line, it's usefull because the debug
// string before and after op.run are different, after run the output #ifdef PADDLE_WITH_CUDA
// will have right shape which is usefull for debug. if (gc != nullptr) {
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope); std::vector<std::string> erase_vars;
for (auto& input : op->Inputs()) {
for (auto& input_name : input.second) {
auto it = ctx->ref_cnts_.find(input_name);
if (it == ctx->ref_cnts_.end()) continue;
if (it->second == 1) { // should delete it
erase_vars.emplace_back(input_name);
ctx->ref_cnts_.erase(input_name);
} else {
--(it->second);
}
}
}
for (auto& output : op->Outputs()) {
for (auto& output_name : output.second) {
auto it = ctx->ref_cnts_.find(output_name);
if (it == ctx->ref_cnts_.end()) continue;
if (it->second == 1) {
erase_vars.emplace_back(output_name);
ctx->ref_cnts_.erase(output_name);
} else {
--(it->second);
}
}
}
if (!erase_vars.empty()) {
std::vector<framework::LoDTensor*> erase_tensors;
for (auto& name : erase_vars) {
auto* var = local_scope->FindVar(name);
if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) {
auto* tensor = var->GetMutable<framework::LoDTensor>();
erase_tensors.push_back(tensor);
}
}
if (!erase_tensors.empty()) gc->Add(erase_tensors);
}
}
#endif
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: " VLOG(2) << "Memory used after operator " + op->Type() + " running: "
<< memory::memory_usage(place_); << memory::memory_usage(place_);
} }
} }
platform::DeviceContextPool::Instance().Get(place_)->Wait();
if (gc != nullptr)
gc->Wait();
else
platform::DeviceContextPool::Instance().Get(place_)->Wait();
if (local_scope != scope) { if (local_scope != scope) {
scope->DeleteScope(local_scope); scope->DeleteScope(local_scope);
} else { } else {
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -27,6 +28,48 @@ namespace paddle { ...@@ -27,6 +28,48 @@ namespace paddle {
namespace framework { namespace framework {
extern void InitializeVariable(Variable* var, proto::VarType::Type var_type); extern void InitializeVariable(Variable* var, proto::VarType::Type var_type);
int64_t GetEagerDeletionThreshold();
template <typename T>
std::unordered_map<std::string, T> GetNonPersistableReferenceCount(
const ProgramDesc& prog, size_t block_id) {
auto& block = prog.Block(block_id);
std::unordered_set<std::string> ignored_vars;
std::unordered_map<std::string, T> ref_cnts;
for (auto var_desc : block.AllVars()) {
auto type = var_desc->Proto()->type().type();
if (type != proto::VarType::LOD_TENSOR || var_desc->Persistable()) {
ignored_vars.insert(var_desc->Name()); // ignore persistable vars
}
}
for (auto op_desc : block.AllOps()) {
for (auto& input : op_desc->Inputs()) {
for (auto& input_name : input.second) {
if (!ignored_vars.count(input_name)) {
if (ref_cnts.count(input_name))
++ref_cnts[input_name];
else
ref_cnts[input_name] = 1;
}
}
}
for (auto& output : op_desc->Outputs()) {
for (auto output_name : output.second) {
if (!ignored_vars.count(output_name)) {
if (ref_cnts.count(output_name))
++ref_cnts[output_name];
else
ref_cnts[output_name] = 1;
}
}
}
}
return ref_cnts;
}
struct ExecutorPrepareContext { struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id); ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id);
~ExecutorPrepareContext(); ~ExecutorPrepareContext();
...@@ -34,6 +77,8 @@ struct ExecutorPrepareContext { ...@@ -34,6 +77,8 @@ struct ExecutorPrepareContext {
const framework::ProgramDesc& prog_; const framework::ProgramDesc& prog_;
size_t block_id_; size_t block_id_;
std::vector<std::unique_ptr<OperatorBase>> ops_; std::vector<std::unique_ptr<OperatorBase>> ops_;
std::unordered_map<std::string, int> ref_cnts_;
}; };
class Executor { class Executor {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <deque>
#include <functional>
#include <memory>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
// T should have memory_size() and clear() method
template <typename T>
class GarbageCollector {
public:
GarbageCollector(const platform::Place &place, size_t max_memory_size)
: max_memory_size_(std::max(max_memory_size, static_cast<size_t>(1))) {
garbages_.reset(new std::deque<T *>());
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place);
}
virtual ~GarbageCollector() {}
void Reset() {
std::lock_guard<std::mutex> guard(mutex_);
garbages_.reset(new std::deque<T *>());
cur_memory_size_ = 0;
}
template <typename Container>
void Add(const Container &objs) {
Add(objs, []() {});
}
template <typename Container, typename Callback>
void Add(const Container &objs, Callback &&callback) {
std::shared_ptr<std::deque<T *>> clear_deque;
{
std::lock_guard<std::mutex> guard(mutex_);
for (auto *obj : objs) {
garbages_->push_back(obj);
cur_memory_size_ += obj->memory_size();
}
if (cur_memory_size_ >= max_memory_size_) {
cur_memory_size_ = 0;
clear_deque = garbages_;
garbages_.reset(new std::deque<T *>());
}
}
if (clear_deque != nullptr) {
callback();
ClearCallback([=]() {
for (auto *obj : *clear_deque) obj->clear();
});
}
}
virtual void Wait() const {}
protected:
virtual void ClearCallback(const std::function<void()> &callback) = 0;
platform::DeviceContext *dev_ctx_;
std::shared_ptr<std::deque<T *>> garbages_;
mutable std::mutex mutex_;
const size_t max_memory_size_;
size_t cur_memory_size_ = 0;
};
template <typename T>
class CPUGarbageCollector : public GarbageCollector<T> {
public:
CPUGarbageCollector(const platform::CPUPlace &place, size_t max_memory_size)
: GarbageCollector<T>(place, max_memory_size) {}
protected:
void ClearCallback(const std::function<void()> &callback) override {
callback();
}
};
#ifdef PADDLE_WITH_CUDA
template <typename T>
class DefaultStreamGarbageCollector : public GarbageCollector<T> {
public:
DefaultStreamGarbageCollector(const platform::CUDAPlace &place,
size_t max_memory_size)
: GarbageCollector<T>(place, max_memory_size) {}
cudaStream_t stream() const {
return static_cast<const platform::CUDADeviceContext *>(this->dev_ctx_)
->stream();
}
void Wait() const override {
this->dev_ctx_->Wait();
static_cast<const platform::CUDADeviceContext *>(this->dev_ctx_)
->WaitStreamCallback();
}
protected:
void ClearCallback(const std::function<void()> &callback) override {
static_cast<platform::CUDADeviceContext *>(this->dev_ctx_)
->AddStreamCallback(callback);
}
};
template <typename T>
class StreamGarbageCollector : public GarbageCollector<T> {
public:
StreamGarbageCollector(const platform::CUDAPlace &place,
size_t max_memory_size)
: GarbageCollector<T>(place, max_memory_size) {
PADDLE_ENFORCE(cudaSetDevice(place.device));
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
callback_manager_.reset(new platform::StreamCallbackManager(stream_));
}
~StreamGarbageCollector() {
auto place = boost::get<platform::CUDAPlace>(this->dev_ctx_->GetPlace());
PADDLE_ENFORCE(cudaSetDevice(place.device));
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}
void Wait() const override {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
std::lock_guard<std::mutex> guard(this->mutex_);
callback_manager_->Wait();
}
cudaStream_t stream() const { return stream_; }
protected:
void ClearCallback(const std::function<void()> &callback) override {
std::lock_guard<std::mutex> guard(this->mutex_);
callback_manager_->AddCallback(callback);
}
private:
cudaStream_t stream_;
std::unique_ptr<platform::StreamCallbackManager> callback_manager_;
};
#endif
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* The graph is a Directed Acyclic Single Static Assignment Graph.
*
* In more detail, the following properties must hold:
*
* The graph shouldn't contain cycle. Each node is a black-box to the graph
* so the node itself could be a loop operator.
*
* Each Variable-type node has only one input (thus single static assignment).
*
* The output/input of operator is variable and the output/input of variable
* is operator.
*
* The following data harzards in Program are addressed in the Graph:
*
* Write-After-Read
* a = op1(x)
* x = op2(b)
* A control-dependency connection is created bettwen op1 and op2 such that
* op1->op2, so as to ensure correct order.
*
* Write-After-Write
* x = op1(a)
* x = op2(b)
* A control-dependency connection is created between op1 and op2 such that
* op1->op2, so as to ensure correct order.
*
* Other properties currently hold, but is not enforced yet:
*
* Variable-type node (not control dep) with the same variable name share
* the same underlying VarDesc.
*/
class Graph {
public:
explicit Graph(const ProgramDesc &program);
virtual ~Graph() {
for (auto &attr : attrs_) {
attr_dels_[attr.first]();
}
attrs_.clear();
attr_dels_.clear();
}
bool Has(const std::string &attr_name) const {
return attrs_.find(attr_name) != attrs_.end();
}
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const {
PADDLE_ENFORCE(Has(attr_name), "%s attr not registered for graph.",
attr_name);
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
}
template <typename AttrType>
void Set(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the graph",
attr_name);
attrs_[attr_name] = attr;
attr_dels_[attr_name] = [attr, attr_name]() {
VLOG(3) << "deleting " << attr_name;
delete attr;
};
}
template <typename AttrType>
void SetNotOwned(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the graph",
attr_name);
attrs_[attr_name] = attr;
attr_dels_[attr_name] = []() {};
}
const std::unordered_set<ir::Node *> &Nodes() const { return node_set_; }
// Create a normal variable with non-null VarDesc.
ir::Node *CreateVarNode(VarDesc *var_desc) {
PADDLE_ENFORCE(var_desc);
return AddNode(new ir::Node(var_desc));
}
// Create a normal runnable operator with OpDesc.
ir::Node *CreateOpNode(OpDesc *op_desc) {
PADDLE_ENFORCE(op_desc);
return AddNode(new ir::Node(op_desc));
}
// Create a control dependency var that connects 2 operations. The
// var doesn't hold any data. Other than that, it's no different from
// other var, considering dependency analysis.
ir::Node *CreateControlDepVar() {
// TODO(panyx0718): control var name should be really unique.
const std::string name = string::Sprintf(
"%s@%llu", ir::Node::kControlDepVarName, node_set_.size());
return AddNode(new ir::Node(name, ir::Node::Type::kVariable));
}
// A more free style way of creating a graph node. Mostly use for test
// or "copy" from another node. Avoid using it if possible.
ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
return AddNode(new ir::Node(name, type));
}
// Clear all node information of the graph and return the ownership of the
// nodes.
std::vector<std::unique_ptr<ir::Node>> ReleaseNodes() {
std::vector<std::unique_ptr<ir::Node>> ret;
for (auto &n : nodes_) {
ret.emplace_back(n.second.release());
}
nodes_.clear();
node_set_.clear();
return ret;
}
void RemoveNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
node_set_.erase(node);
nodes_.erase(node);
}
// NOTE low performance, but simple and secure.
Node *RetriveNode(int id) {
for (auto &node : nodes_) {
if (node.second->id() == id) {
return node.second.get();
}
}
return nullptr;
}
private:
// This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
nodes_[node].reset(node);
node_set_.insert(node);
return node;
}
// NOTE: program_ shouldn't be exposed to user.
const ProgramDesc program_;
std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_;
std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
std::unordered_set<ir::Node *> node_set_;
};
bool IsControlDepVar(const ir::Node &var);
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -19,9 +19,15 @@ limitations under the License. */ ...@@ -19,9 +19,15 @@ limitations under the License. */
#include <vector> #include <vector>
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
#endif #endif
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" #include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
...@@ -115,17 +121,39 @@ ParallelExecutor::ParallelExecutor( ...@@ -115,17 +121,39 @@ ParallelExecutor::ParallelExecutor(
build_strategy); build_strategy);
if (member_->use_cuda_) { if (member_->use_cuda_) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
builder_factory.SetNCCLContextMap(member_->nccl_ctxs_.get()); std::unique_ptr<ir::Graph> graph = ApplyParallelExecutorPass(
main_program, member_->places_, loss_var_name, params,
member_->local_scopes_, member_->use_cuda_, build_strategy,
member_->nccl_ctxs_.get());
auto max_memory_size = GetEagerDeletionThreshold();
if (max_memory_size >= 0) {
for (auto &place : member_->places_) {
if (!platform::is_gpu_place(place)) continue;
auto gpu_place = boost::get<platform::CUDAPlace>(place);
if (gcs_[gpu_place.device] == nullptr) {
ref_cnts_[gpu_place.device].reset(new details::ReferenceCountMap());
cur_ref_cnts_[gpu_place.device].reset(
new details::AtomicReferenceCountMap());
gcs_[gpu_place.device].reset(
new StreamGarbageCollector<Tensor>(gpu_place, max_memory_size));
}
}
if (!gcs_.empty()) {
auto ref_cnt_pass =
ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, &ref_cnts_);
ref_cnt_pass->SetNotOwned(details::kCurReferenceCount, &cur_ref_cnts_);
ref_cnt_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
graph = ref_cnt_pass->Apply(std::move(graph));
graph->SetNotOwned("garbage_collector", &gcs_);
}
}
#else #else
PADDLE_THROW("Not compiled with CUDA"); PADDLE_THROW("Not compiled with CUDA");
#endif #endif
} }
builder_ = builder_factory.Create();
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places,
builder_->Build(main_program)));
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos), exec_strategy, member_->local_scopes_, std::move(var_infos),
member_->places_, std::move(member_->executor_))); member_->places_, std::move(member_->executor_)));
...@@ -216,6 +244,11 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -216,6 +244,11 @@ void ParallelExecutor::BCastParamsToGPUs(
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name) { const std::string &fetched_var_name) {
platform::RecordBlock b(0); platform::RecordBlock b(0);
#ifdef PADDLE_WITH_CUDA
if (!gcs_.empty()) {
ResetReferenceCount();
}
#endif
auto fetch_data = member_->executor_->Run(fetch_tensors); auto fetch_data = member_->executor_->Run(fetch_tensors);
*member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() = *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
fetch_data; fetch_data;
...@@ -265,3 +298,11 @@ ParallelExecutor::~ParallelExecutor() { ...@@ -265,3 +298,11 @@ ParallelExecutor::~ParallelExecutor() {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
USE_PASS(graph_viz_pass);
USE_PASS(multi_devices_pass);
USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass);
#ifdef PADDLE_WITH_CUDA
USE_PASS(reference_count_pass);
#endif
...@@ -15,7 +15,9 @@ limitations under the License. */ ...@@ -15,7 +15,9 @@ limitations under the License. */
#pragma once #pragma once
#include <paddle/fluid/framework/details/build_strategy.h> #include <paddle/fluid/framework/details/build_strategy.h>
#include <atomic>
#include <string> #include <string>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/execution_strategy.h"
...@@ -70,7 +72,23 @@ class ParallelExecutor { ...@@ -70,7 +72,23 @@ class ParallelExecutor {
private: private:
ParallelExecutorPrivate *member_; ParallelExecutorPrivate *member_;
std::unique_ptr<details::SSAGraphBuilder> builder_;
#ifdef PADDLE_WITH_CUDA
// ref_cnts_ is only initialized when ParallelExecutor constructs, and then
// keeps unchanged
// Before each iteration, cur_ref_cnts_ is reset to ref_cnts_
details::DeviceReferenceCountMap ref_cnts_;
details::AtomicDeviceReferenceCountMap cur_ref_cnts_;
details::DeviceGarbageCollectorMap gcs_;
void ResetReferenceCount() {
for (auto &pair1 : ref_cnts_) {
for (auto &pair2 : *(pair1.second)) {
(*(cur_ref_cnts_[pair1.first]))[pair2.first] = pair2.second;
}
}
}
#endif
}; };
} // namespace framework } // namespace framework
......
...@@ -31,9 +31,21 @@ DEFINE_bool( ...@@ -31,9 +31,21 @@ DEFINE_bool(
"Delete local scope eagerly. It will reduce GPU memory usage but " "Delete local scope eagerly. It will reduce GPU memory usage but "
"slow down the destruction of variables.(around 1% performance harm)"); "slow down the destruction of variables.(around 1% performance harm)");
DEFINE_double(
eager_delete_tensor_GB, -1.0,
"Memory size threshold (GB) when the garbage collector clear tensors."
"Disabled when this value is less than 0");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
int64_t GetEagerDeletionThreshold() {
return FLAGS_eager_delete_tensor_GB < 0
? -1
: static_cast<int64_t>(FLAGS_eager_delete_tensor_GB *
(static_cast<int64_t>(1) << 30));
}
Scope::~Scope() { DropKids(); } Scope::~Scope() { DropKids(); }
Scope& Scope::NewScope() const { Scope& Scope::NewScope() const {
......
...@@ -26,6 +26,8 @@ limitations under the License. */ ...@@ -26,6 +26,8 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
int64_t GetEagerDeletionThreshold();
class Scope; class Scope;
/** /**
......
...@@ -149,6 +149,8 @@ class Tensor { ...@@ -149,6 +149,8 @@ class Tensor {
void set_layout(const DataLayout layout) { layout_ = layout; } void set_layout(const DataLayout layout) { layout_ = layout; }
void clear() { holder_ = nullptr; }
private: private:
/** /**
* @note Placeholder hides type T, so it doesn't appear as a template * @note Placeholder hides type T, so it doesn't appear as a template
......
...@@ -45,8 +45,8 @@ ENDIF() ...@@ -45,8 +45,8 @@ ENDIF()
# memcpy depends on device_context, here add deps individually for # memcpy depends on device_context, here add deps individually for
# avoiding cycle dependencies # avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc init.cc DEPS malloc cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc
place eigen3 stringpiece cpu_helper ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}) place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS})
nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info) nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info)
cc_test(init_test SRCS init_test.cc DEPS device_context) cc_test(init_test SRCS init_test.cc DEPS device_context)
......
...@@ -159,11 +159,14 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { ...@@ -159,11 +159,14 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
} else { } else {
cudnn_handle_ = nullptr; cudnn_handle_ = nullptr;
} }
callback_manager_.reset(new StreamCallbackManager(stream_));
} }
CUDADeviceContext::~CUDADeviceContext() { CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId(place_.device); SetDeviceId(place_.device);
Wait(); Wait();
WaitStreamCallback();
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
if (cudnn_handle_ != nullptr) { if (cudnn_handle_ != nullptr) {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
......
...@@ -31,8 +31,13 @@ limitations under the License. */ ...@@ -31,8 +31,13 @@ limitations under the License. */
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/stream_callback_manager.h"
#endif
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
DECLARE_bool(clear_gpu_memory_when_unused);
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -106,6 +111,17 @@ class CUDADeviceContext : public DeviceContext { ...@@ -106,6 +111,17 @@ class CUDADeviceContext : public DeviceContext {
PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
} }
template <typename Callback>
void AddStreamCallback(Callback&& callback) const {
std::lock_guard<std::mutex> guard(callback_mtx_);
callback_manager_->AddCallback(callback);
}
void WaitStreamCallback() const {
std::lock_guard<std::mutex> guard(callback_mtx_);
callback_manager_->Wait();
}
private: private:
CUDAPlace place_; CUDAPlace place_;
...@@ -119,7 +135,12 @@ class CUDADeviceContext : public DeviceContext { ...@@ -119,7 +135,12 @@ class CUDADeviceContext : public DeviceContext {
int multi_process; int multi_process;
int max_threads_per_mp; int max_threads_per_mp;
std::mutex mtx_; mutable std::mutex mtx_;
// This lock is only used by callback
// If we use mtx_ for StreamCallbackManager, deadlock may occur sometimes
mutable std::mutex callback_mtx_;
std::unique_ptr<StreamCallbackManager> callback_manager_;
}; };
template <> template <>
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <functional>
#include <memory>
#include "ThreadPool.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
using StreamCallback = std::function<void(cudaStream_t, cudaError_t)>;
class StreamCallbackManager;
struct StreamCallbackContext {
template <typename Callback>
inline StreamCallbackContext(const StreamCallbackManager *manager,
Callback &&callback)
: manager_(manager), callback_(callback) {}
const StreamCallbackManager *manager_; // do not own
StreamCallback callback_;
};
class StreamCallbackManager {
public:
explicit inline StreamCallbackManager(cudaStream_t stream = nullptr)
: stream_(stream), thread_pool_(new ThreadPool(1)) {}
template <typename Callback>
inline void AddCallback(Callback &&callback) const {
AddCallbackWithStreamAndErrorInfo(
[=](cudaStream_t, cudaError_t) { callback(); });
}
template <typename Callback>
inline void AddCallbackWithStreamAndErrorInfo(Callback &&callback) const {
auto *stream_callback_context = new StreamCallbackContext(this, callback);
PADDLE_ENFORCE(cudaStreamAddCallback(
stream_, StreamCallbackManager::StreamCallbackFunc,
stream_callback_context, 0));
}
void Wait() const { thread_pool_.reset(new ThreadPool(1)); }
private:
const cudaStream_t stream_;
mutable std::unique_ptr<ThreadPool> thread_pool_;
// cudaStreamCallback cannot call CUDA API inside, so we have to use
// thread_pool here
static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
cudaError_t status,
void *user_data) {
auto *callback_context_ptr =
reinterpret_cast<StreamCallbackContext *>(user_data);
callback_context_ptr->manager_->thread_pool_->enqueue([=]() {
std::unique_ptr<StreamCallbackContext> callback_context(
callback_context_ptr);
callback_context->callback_(stream, status);
});
}
};
} // namespace platform
} // namespace paddle
...@@ -117,9 +117,19 @@ def __bootstrap__(): ...@@ -117,9 +117,19 @@ def __bootstrap__():
os.environ['OMP_NUM_THREADS'] = str(num_threads) os.environ['OMP_NUM_THREADS'] = str(num_threads)
read_env_flags = [ read_env_flags = [
'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir', 'use_pinned_memory',
'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb', 'check_nan_inf',
'init_allocated_mem' 'benchmark',
'warpctc_dir',
'eager_delete_scope',
'use_mkldnn',
'initial_cpu_memory_in_mb',
'init_allocated_mem',
'free_idle_memory',
'paddle_num_threads',
"dist_threadpool_size",
'cpu_deterministic',
'eager_delete_tensor_GB',
] ]
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
read_env_flags += [ read_env_flags += [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册