diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index abd5459f6d47da6d1341284916b419325dc5977c..a8e0c4a3fedfd56e38de7568be6b3f2e76a4b25f 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -28,10 +28,20 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) +if(WITH_GPU) + cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle + all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass) +endif() + cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) -cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto) +if(WITH_GPU) + cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass) +else() + cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto) +endif() + cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index d9fcd92427ef38b131b4ce782c0ada37765682db..e98f1ab148db083ac63a1afd43e334fbfae62539 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -32,6 +32,10 @@ struct ComputationOpHandle : public OpHandleBase { std::string Name() const override; + const Scope *GetScope() const { return scope_; } + + const platform::Place &GetPlace() const { return place_; } + protected: void RunImpl() override; diff --git a/paddle/fluid/framework/details/reference_count_op_handle.h b/paddle/fluid/framework/details/reference_count_op_handle.h new file mode 100644 index 0000000000000000000000000000000000000000..71db8d952f4c205b875ad254dc19c0c1f74e61b3 --- /dev/null +++ b/paddle/fluid/framework/details/reference_count_op_handle.h @@ -0,0 +1,123 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
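// ReferenceCountOpHandle (below) is appended after each GPU computation op by
// the new reference_count_pass. When it runs, it atomically decrements the
// reference count of every variable the computation op touched; once a count
// drops to zero, the underlying tensor is handed to the per-device garbage
// collector so its memory can be released before the local scope is dropped.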
+
+#pragma once
+
+#include <atomic>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "paddle/fluid/framework/details/op_handle_base.h"
+#include "paddle/fluid/framework/garbage_collector.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/tensor.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+using ReferenceCountMap = std::unordered_map<std::string, int>;
+using AtomicReferenceCountMap =
+    std::unordered_map<std::string, std::atomic<int>>;
+using DeviceReferenceCountMap =
+    std::unordered_map<int, std::unique_ptr<ReferenceCountMap>>;
+using AtomicDeviceReferenceCountMap =
+    std::unordered_map<int, std::unique_ptr<AtomicReferenceCountMap>>;
+using DeviceGarbageCollectorMap =
+    std::unordered_map<int,
+                       std::unique_ptr<GarbageCollector<framework::Tensor>>>;
+
+class ReferenceCountOpHandle : public OpHandleBase {
+ public:
+  ReferenceCountOpHandle(ir::Node *node, const Scope *scope,
+                         const platform::CUDAPlace &place,
+                         const std::vector<std::string> &var_names,
+                         GarbageCollector<Tensor> *gc,
+                         AtomicReferenceCountMap *ref_cnts)
+      : OpHandleBase(node),
+        scope_(scope),
+        var_names_(var_names),
+        gc_(gc),
+        ref_cnts_(ref_cnts) {
+    dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
+        platform::DeviceContextPool::Instance().Get(place));
+    if (IsStreamGarbageCollector()) {
+      PADDLE_ENFORCE(cudaSetDevice(place.device));
+      PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
+    }
+  }
+
+  ~ReferenceCountOpHandle() {
+    if (IsStreamGarbageCollector()) {
+      auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
+      PADDLE_ENFORCE(cudaSetDevice(gpu_place.device));
+      PADDLE_ENFORCE(cudaEventDestroy(event_));
+    }
+  }
+
+  std::string Name() const override { return "reference_count"; }
+
+ protected:
+  void RunImpl() override {
+    auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
+    std::vector<Tensor *> tensors;
+    for (auto &name : var_names_) {
+      auto it = ref_cnts_->find(name);
+      if (it == ref_cnts_->end()) continue;
+
+      auto *var = exec_scope->FindVar(name);
+      if (var == nullptr || !var->IsType<LoDTensor>()) continue;
+
+      if (it->second.fetch_sub(1) <= 1) {
+        tensors.emplace_back(var->GetMutable<LoDTensor>());
+      }
+    }
+
+    if (!tensors.empty()) {
+      ClearTensors(tensors);
+    }
+  }
+
+ private:
+  void ClearTensors(const std::vector<Tensor *> &tensors) {
+    auto *gc = dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_);
+    if (gc != nullptr) {
+      auto compute_stream = dev_ctx_->stream();
+      auto callback_stream = gc->stream();
+      auto callback_func = [=]() {
+        PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
+        PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
+      };
+      gc_->Add(tensors, callback_func);
+    } else {
+      gc_->Add(tensors);
+    }
+  }
+
+  bool IsStreamGarbageCollector() const {
+    return dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_) != nullptr;
+  }
+
+  const Scope *scope_;
+  platform::CUDADeviceContext *dev_ctx_;
+  std::vector<std::string> var_names_;
+  GarbageCollector<Tensor> *gc_;       // not own
+  AtomicReferenceCountMap *ref_cnts_;  // not own
+  cudaEvent_t event_;
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
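The heart of the handle is the `fetch_sub` test in `RunImpl`: `fetch_sub(1)` returns the counter value held *before* the decrement, so a result of 1 (or less) means the current op held the last outstanding reference and the tensor may be reclaimed. A minimal, self-contained sketch of that idiom in plain standard C++ (the map contents and names are illustrative, not part of the patch):

#include <atomic>
#include <cstdio>
#include <string>
#include <unordered_map>

int main() {
  // Mirrors AtomicReferenceCountMap: variable name -> remaining uses.
  std::unordered_map<std::string, std::atomic<int>> ref_cnts;
  ref_cnts["x"] = 3;  // "x" is consumed by three ops in this iteration

  for (int op = 0; op < 3; ++op) {
    // fetch_sub returns the previous value; 1 means this was the last use.
    if (ref_cnts["x"].fetch_sub(1) <= 1) {
      std::printf("op %d releases x\n", op);  // hand the tensor to the GC here
    }
  }
  return 0;
}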
diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..344754d5a1e119c04cae08ad50126924b5824315
--- /dev/null
+++ b/paddle/fluid/framework/details/reference_count_pass.cc
@@ -0,0 +1,150 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/details/computation_op_handle.h"
+#include "paddle/fluid/framework/details/multi_devices_helper.h"
+#include "paddle/fluid/framework/details/reference_count_pass.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
+    std::unique_ptr<ir::Graph> graph) const {
+  auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount);
+  auto &cur_ref_cnts = Get<AtomicDeviceReferenceCountMap>(kCurReferenceCount);
+  auto &gcs = Get<DeviceGarbageCollectorMap>(kGarbageCollector);
+
+  // It is not easy to find the right reference counts of variables in the
+  // graph:
+  // Step 1: Find all variables in computation ops.
+  // Step 2: Find all variables in non-computation ops that refer to variables
+  //         in computation ops.
+  std::unordered_set<std::string> names;
+  auto get_ref_cnts_from_compute_op = [&](
+      const std::unique_ptr<OpHandleBase> &op,
+      const std::vector<VarHandleBase *> &vars) {
+    std::vector<std::string> var_names_in_op;
+    auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
+    if (compute_op == nullptr ||
+        !platform::is_gpu_place(compute_op->GetPlace()))
+      return var_names_in_op;
+    auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
+    for (VarHandleBase *var_handle_base : vars) {
+      auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
+      if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
+
+      if (!platform::is_gpu_place(var_handle->place_) ||
+          boost::get<platform::CUDAPlace>(var_handle->place_) != place)
+        continue;
+
+      VarDesc *var_desc = var_handle->Node()->Var();
+      auto var_name = var_handle->Node()->Name();
+
+      // This is weird, but there really are some variables without var_desc
+      // in computation_op
+      if (var_desc == nullptr) {
+        if (compute_op->Node()->Op()->Block()->FindVar(var_name) == nullptr)
+          continue;
+      } else {
+        if (var_desc->Persistable() ||
+            var_desc->Proto()->type().type() != proto::VarType::LOD_TENSOR)
+          continue;
+      }
+
+      // A compute op only runs on one device
+      if (ref_cnts[place.device]->count(var_name))
+        ++(*ref_cnts[place.device])[var_name];
+      else
+        (*ref_cnts[place.device])[var_name] = 1;
+
+      names.insert(var_name);
+      var_names_in_op.push_back(var_name);
+    }
+    return var_names_in_op;
+  };
+
+  auto update_ref_cnts_from_non_compute_op = [&](
+      const std::unique_ptr<OpHandleBase> &op,
+      const std::vector<VarHandleBase *> &vars) {
+    if (dynamic_cast<ComputationOpHandle *>(op.get()) != nullptr) return;
+    for (VarHandleBase *var_handle_base : vars) {
+      auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
+      if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
+
+      auto var_name = var_handle->Node()->Name();
+      auto var_place = var_handle->place_;
+      if (!platform::is_gpu_place(var_place)) continue;
+      auto place = boost::get<platform::CUDAPlace>(var_place);
+      if (names.count(var_name) == 0) continue;
+      if (ref_cnts.count(place.device) &&
+          ref_cnts[place.device]->count(var_name)) {
+        ++(*ref_cnts[place.device])[var_name];
+      }
+    }
+  };
+
+  std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
+      compute_ref_cnt_map;
+  auto &all_ops = graph->Get<GraphOps>(kGraphOps);
+  for (auto &op : all_ops) {
+    auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
+    auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs());
+    if (in_var_names.empty() && out_var_names.empty()) continue;
+    in_var_names.insert(in_var_names.end(), out_var_names.begin(),
+                        out_var_names.end());
+    auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
+    auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
+    ir::Node *ref_cnt_node =
+        graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation);
+    auto *ref_cnt_handle = new ReferenceCountOpHandle(
+        ref_cnt_node, compute_op->GetScope(), place, in_var_names,
+        gcs[place.device].get(), cur_ref_cnts[place.device].get());
+    auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
+    compute_op->AddOutput(dep_var);
+    ref_cnt_handle->AddInput(dep_var);
+    graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
+    compute_ref_cnt_map[compute_op] = ref_cnt_handle;
+  }
+
+  for (auto &op : all_ops) {
+    update_ref_cnts_from_non_compute_op(op, op->Inputs());
+    update_ref_cnts_from_non_compute_op(op, op->Outputs());
+  }
+
+  std::vector<std::unique_ptr<OpHandleBase>> new_all_ops;
+  new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size());
+  for (auto &op : all_ops) {
+    new_all_ops.emplace_back(std::move(op));
+    auto it = compute_ref_cnt_map.find(new_all_ops.back().get());
+    if (it != compute_ref_cnt_map.end()) {
+      new_all_ops.emplace_back(it->second);
+    }
+  }
+
+  all_ops.swap(new_all_ops);
+  return graph;
+}
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(reference_count_pass,
+              paddle::framework::details::ReferenceCountPass)
+    .RequirePassAttr(paddle::framework::details::kGlobalReferenceCount)
+    .RequirePassAttr(paddle::framework::details::kCurReferenceCount)
+    .RequirePassAttr(paddle::framework::details::kGarbageCollector);
diff --git a/paddle/fluid/framework/details/reference_count_pass.h b/paddle/fluid/framework/details/reference_count_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..7081280b0600b9c1985987d02d679c298ad4b8bd
--- /dev/null
+++ b/paddle/fluid/framework/details/reference_count_pass.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
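// Illustrative sketch of how the pass is wired up, mirroring the
// parallel_executor.cc changes later in this patch (the local variable names
// are placeholders):
//
//   auto pass = ir::PassRegistry::Instance().Get("reference_count_pass");
//   pass->SetNotOwned(details::kGlobalReferenceCount, &ref_cnts);
//   pass->SetNotOwned(details::kCurReferenceCount, &cur_ref_cnts);
//   pass->SetNotOwned(details::kGarbageCollector, &gcs);
//   graph = pass->Apply(std::move(graph));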
+ +#pragma once + +#include "paddle/fluid/framework/details/reference_count_op_handle.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace details { + +constexpr char kGlobalReferenceCount[] = "reference_count"; +constexpr char kCurReferenceCount[] = "current_reference_count"; +constexpr char kGarbageCollector[] = "garbage_collector"; + +class ReferenceCountPass : public ir::Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc index 5bd974d6b789a2f085c0a69de5e133187342f587..e5b1eaa7318aecde1dbf89de8fe242a3008db97c 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc @@ -18,6 +18,9 @@ #include #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/platform/profiler.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/framework/details/reference_count_op_handle.h" +#endif namespace paddle { namespace framework { @@ -65,12 +68,28 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun", nullptr); drop_scope_counter_ += 1; + +#ifdef PADDLE_WITH_CUDA + const std::string gc_name = "garbage_collector"; + DeviceGarbageCollectorMap *gc = + Graph().Has(gc_name) ? &(Graph().Get(gc_name)) + : nullptr; +#endif + if (!fetch_tensors.empty() || drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) { drop_scope_counter_ = 0; // Wait All computational streams for (auto p : places_) { platform::DeviceContextPool::Instance().Get(p)->Wait(); +#ifdef PADDLE_WITH_CUDA + if (gc != nullptr && platform::is_gpu_place(p)) { + auto gpu_place = boost::get(p); + auto &gc_at_place = gc->at(gpu_place.device); + gc_at_place->Wait(); + gc_at_place->Reset(); + } +#endif } for (auto &scope : local_scopes_) { auto &local_scope = diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index dad170ed78c64202b5c812bd8682887fe3b736d6..650d9086d423cc62de571fc9c83f4d045ed939c1 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -37,7 +37,11 @@ int kProgramId = -1; ExecutorPrepareContext::ExecutorPrepareContext( const framework::ProgramDesc& prog, size_t block_id) - : prog_(prog), block_id_(block_id) {} + : prog_(prog), block_id_(block_id) { + if (GetEagerDeletionThreshold() >= 0) { + ref_cnts_ = GetNonPersistableReferenceCount(prog_, block_id_); + } +} ExecutorPrepareContext::~ExecutorPrepareContext() { VLOG(5) << "destroy ExecutorPrepareContext"; @@ -329,15 +333,80 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, CreateVariables(ctx->prog_, local_scope, ctx->block_id_); } + int64_t max_memory_size = GetEagerDeletionThreshold(); + + std::unique_ptr> gc; + if (max_memory_size >= 0) { +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(place_)) { + gc.reset(new DefaultStreamGarbageCollector( + boost::get(place_), max_memory_size)); + } else { +#endif + gc.reset(new CPUGarbageCollector( + boost::get(place_), max_memory_size)); +#ifdef PADDLE_WITH_CUDA + } +#endif + } + for (auto& op : ctx->ops_) { op->Run(*local_scope, place_); + if (gc != nullptr) { + std::vector erase_vars; + for (auto& input : 
op->Inputs()) { + for (auto& input_name : input.second) { + auto it = ctx->ref_cnts_.find(input_name); + if (it == ctx->ref_cnts_.end()) continue; + if (it->second == 1) { // should delete it + erase_vars.emplace_back(input_name); + ctx->ref_cnts_.erase(input_name); + } else { + --(it->second); + } + } + } + + for (auto& output : op->Outputs()) { + for (auto& output_name : output.second) { + auto it = ctx->ref_cnts_.find(output_name); + if (it == ctx->ref_cnts_.end()) continue; + if (it->second == 1) { + erase_vars.emplace_back(output_name); + ctx->ref_cnts_.erase(output_name); + } else { + --(it->second); + } + } + } + + if (!erase_vars.empty()) { + std::vector erase_tensors; + for (auto& name : erase_vars) { + auto* var = local_scope->FindVar(name); + if (var == nullptr) continue; + if (var->IsType()) { + auto* tensor = var->GetMutable(); + erase_tensors.push_back(tensor); + } + } + if (!erase_tensors.empty()) gc->Add(erase_tensors); + } + } + if (FLAGS_benchmark) { VLOG(2) << "Memory used after operator " + op->Type() + " running: " << memory::memory_usage(place_); } } - platform::DeviceContextPool::Instance().Get(place_)->Wait(); + + if (gc != nullptr) { + gc->Wait(); + } else { + platform::DeviceContextPool::Instance().Get(place_)->Wait(); + } + if (local_scope != scope) { scope->DeleteScope(local_scope); } else { diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index f95808c199b9de693ec653c29374c9130be7fd59..b746268760570c56c720c6e3b8fe04f8e3f75b4e 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" @@ -27,6 +28,46 @@ namespace paddle { namespace framework { extern void InitializeVariable(Variable* var, proto::VarType::Type var_type); +template +std::unordered_map GetNonPersistableReferenceCount( + const ProgramDesc& prog, size_t block_id) { + auto& block = prog.Block(block_id); + std::unordered_set ignored_vars; + std::unordered_map ref_cnts; + + for (auto var_desc : block.AllVars()) { + auto type = var_desc->Proto()->type().type(); + if (type != proto::VarType::LOD_TENSOR || var_desc->Persistable()) { + ignored_vars.insert(var_desc->Name()); // ignore persistable vars + } + } + + for (auto op_desc : block.AllOps()) { + for (auto& input : op_desc->Inputs()) { + for (auto& input_name : input.second) { + if (!ignored_vars.count(input_name)) { + if (ref_cnts.count(input_name)) + ++ref_cnts[input_name]; + else + ref_cnts[input_name] = 1; + } + } + } + + for (auto& output : op_desc->Outputs()) { + for (auto output_name : output.second) { + if (!ignored_vars.count(output_name)) { + if (ref_cnts.count(output_name)) + ++ref_cnts[output_name]; + else + ref_cnts[output_name] = 1; + } + } + } + } + return ref_cnts; +} + struct ExecutorPrepareContext { ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id); ~ExecutorPrepareContext(); @@ -34,6 +75,8 @@ struct ExecutorPrepareContext { const framework::ProgramDesc& prog_; size_t block_id_; std::vector> ops_; + + std::unordered_map ref_cnts_; }; class Executor { diff --git a/paddle/fluid/framework/garbage_collector.h b/paddle/fluid/framework/garbage_collector.h new file mode 100644 index 0000000000000000000000000000000000000000..b403252c972d26da6deeca54ce88a9547ffe7afa --- /dev/null +++ 
b/paddle/fluid/framework/garbage_collector.h @@ -0,0 +1,163 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include // NOLINT +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace framework { + +// T should have memory_size() and clear() method +template +class GarbageCollector { + public: + GarbageCollector(const platform::Place &place, size_t max_memory_size) + : max_memory_size_(std::max(max_memory_size, static_cast(1))) { + garbages_.reset(new std::deque()); + dev_ctx_ = platform::DeviceContextPool::Instance().Get(place); + } + + virtual ~GarbageCollector() {} + + void Reset() { + std::lock_guard guard(mutex_); + garbages_.reset(new std::deque()); + cur_memory_size_ = 0; + } + + template + void Add(const Container &objs) { + Add(objs, []() {}); + } + + template + void Add(const Container &objs, Callback &&callback) { + std::shared_ptr> clear_deque; + { + std::lock_guard guard(mutex_); + for (auto *obj : objs) { + garbages_->push_back(obj); + cur_memory_size_ += obj->memory_size(); + } + if (cur_memory_size_ >= max_memory_size_) { + cur_memory_size_ = 0; + clear_deque = garbages_; + garbages_.reset(new std::deque()); + } + } + + if (clear_deque != nullptr) { + callback(); + ClearCallback([=]() { + for (auto *obj : *clear_deque) obj->clear(); + }); + } + } + + virtual void Wait() const {} + + protected: + virtual void ClearCallback(const std::function &callback) = 0; + + platform::DeviceContext *dev_ctx_; + std::shared_ptr> garbages_; + mutable std::mutex mutex_; + const size_t max_memory_size_; + size_t cur_memory_size_ = 0; +}; + +template +class CPUGarbageCollector : public GarbageCollector { + public: + CPUGarbageCollector(const platform::CPUPlace &place, size_t max_memory_size) + : GarbageCollector(place, max_memory_size) {} + + protected: + void ClearCallback(const std::function &callback) override { + callback(); + } +}; + +#ifdef PADDLE_WITH_CUDA +template +class DefaultStreamGarbageCollector : public GarbageCollector { + public: + DefaultStreamGarbageCollector(const platform::CUDAPlace &place, + size_t max_memory_size) + : GarbageCollector(place, max_memory_size) {} + + cudaStream_t stream() const { + return static_cast(this->dev_ctx_) + ->stream(); + } + + void Wait() const override { + this->dev_ctx_->Wait(); + static_cast(this->dev_ctx_) + ->WaitStreamCallback(); + } + + protected: + void ClearCallback(const std::function &callback) override { + static_cast(this->dev_ctx_) + ->AddStreamCallback(callback); + } +}; + +template +class StreamGarbageCollector : public GarbageCollector { + public: + StreamGarbageCollector(const platform::CUDAPlace &place, + size_t max_memory_size) + : GarbageCollector(place, max_memory_size) { + PADDLE_ENFORCE(cudaSetDevice(place.device)); + PADDLE_ENFORCE(cudaStreamCreate(&stream_)); + callback_manager_.reset(new platform::StreamCallbackManager(stream_)); + } + + 
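// StreamGarbageCollector owns a dedicated CUDA stream (created in the
// constructor above) and releases tensors from callbacks queued on that stream
// through StreamCallbackManager, so deallocation can overlap with work on the
// compute stream. DefaultStreamGarbageCollector, by contrast, simply reuses the
// device context's stream callback mechanism.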
~StreamGarbageCollector() { + auto place = boost::get(this->dev_ctx_->GetPlace()); + PADDLE_ENFORCE(cudaSetDevice(place.device)); + PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); + PADDLE_ENFORCE(cudaStreamDestroy(stream_)); + } + + void Wait() const override { + PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); + std::lock_guard guard(this->mutex_); + callback_manager_->Wait(); + } + + cudaStream_t stream() const { return stream_; } + + protected: + void ClearCallback(const std::function &callback) override { + std::lock_guard guard(this->mutex_); + callback_manager_->AddCallback(callback); + } + + private: + cudaStream_t stream_; + std::unique_ptr callback_manager_; +}; +#endif + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index ae8496204d4aeb88c04154d571325d440274e821..ab687e760a761d4e445726bd5149966adc2403d0 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -94,6 +94,14 @@ class Graph { }; } + template + void SetNotOwned(const std::string &attr_name, AttrType *attr) { + PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the graph", + attr_name); + attrs_[attr_name] = attr; + attr_dels_[attr_name] = []() {}; + } + const std::unordered_set &Nodes() const { return node_set_; } // Create a normal variable with non-null VarDesc. diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 5b8c75a93de2ddd8f7260d2191c22a5945b3d2d9..ae393d66a3b3ec0141667b44b5d9f3158e434e37 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -188,6 +188,30 @@ ParallelExecutor::ParallelExecutor( main_program, member_->places_, loss_var_name, params, member_->local_scopes_, member_->use_cuda_, build_strategy, member_->nccl_ctxs_.get()); + + auto max_memory_size = GetEagerDeletionThreshold(); + if (max_memory_size >= 0) { + for (auto &place : member_->places_) { + if (!platform::is_gpu_place(place)) continue; + auto gpu_place = boost::get(place); + if (gcs_[gpu_place.device] == nullptr) { + ref_cnts_[gpu_place.device].reset(new details::ReferenceCountMap()); + cur_ref_cnts_[gpu_place.device].reset( + new details::AtomicReferenceCountMap()); + gcs_[gpu_place.device].reset( + new StreamGarbageCollector(gpu_place, max_memory_size)); + } + } + if (!gcs_.empty()) { + auto ref_cnt_pass = + ir::PassRegistry::Instance().Get("reference_count_pass"); + ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, &ref_cnts_); + ref_cnt_pass->SetNotOwned(details::kCurReferenceCount, &cur_ref_cnts_); + ref_cnt_pass->SetNotOwned(details::kGarbageCollector, &gcs_); + graph = ref_cnt_pass->Apply(std::move(graph)); + graph->SetNotOwned("garbage_collector", &gcs_); + } + } #else std::unique_ptr graph = ApplyParallelExecutorPass( main_program, member_->places_, loss_var_name, params, @@ -310,6 +334,11 @@ void ParallelExecutor::BCastParamsToDevices( void ParallelExecutor::Run(const std::vector &fetch_tensors, const std::string &fetched_var_name) { platform::RecordBlock b(0); +#ifdef PADDLE_WITH_CUDA + if (!gcs_.empty()) { + ResetReferenceCount(); + } +#endif auto fetch_data = member_->executor_->Run(fetch_tensors); *member_->global_scope_->Var(fetched_var_name)->GetMutable() = fetch_data; @@ -367,3 +396,6 @@ USE_PASS(graph_viz_pass); USE_PASS(multi_devices_pass); USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_print_pass); +#ifdef PADDLE_WITH_CUDA +USE_PASS(reference_count_pass); 
+#endif diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index 5fb748fa205d5e9dbd2943b615c69aedd0e7a26f..88e2078454024c3a4d437615d3e6b15ee0c7d6a1 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -15,7 +15,9 @@ limitations under the License. */ #pragma once #include +#include #include +#include #include #include #include "paddle/fluid/framework/details/execution_strategy.h" @@ -27,6 +29,10 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device_context.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/framework/details/reference_count_pass.h" +#endif + namespace paddle { namespace framework { @@ -70,6 +76,23 @@ class ParallelExecutor { private: ParallelExecutorPrivate *member_; + +#ifdef PADDLE_WITH_CUDA + // ref_cnts_ is only initialized when ParallelExecutor constructs, and then + // keeps unchanged + // Before each iteration, cur_ref_cnts_ is reset to ref_cnts_ + details::DeviceReferenceCountMap ref_cnts_; + details::AtomicDeviceReferenceCountMap cur_ref_cnts_; + details::DeviceGarbageCollectorMap gcs_; + + void ResetReferenceCount() { + for (auto &pair1 : ref_cnts_) { + for (auto &pair2 : *(pair1.second)) { + (*(cur_ref_cnts_[pair1.first]))[pair2.first] = pair2.second; + } + } + } +#endif }; } // namespace framework diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index 2be655b89a4caf2bf9874dcab6bc0bdb2856a026..1a727a2c8c759d010606d5b605823b7252b35c69 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -31,9 +31,21 @@ DEFINE_bool( "Delete local scope eagerly. It will reduce GPU memory usage but " "slow down the destruction of variables.(around 1% performance harm)"); +DEFINE_double( + eager_delete_tensor_gb, -1.0, + "Memory size threshold (GB) when the garbage collector clear tensors." + "Disabled when this value is less than 0"); + namespace paddle { namespace framework { +int64_t GetEagerDeletionThreshold() { + return FLAGS_eager_delete_tensor_gb < 0 + ? -1 + : static_cast(FLAGS_eager_delete_tensor_gb * + (static_cast(1) << 30)); +} + Scope::~Scope() { DropKids(); } Scope& Scope::NewScope() const { diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index b6165a595d537c314a95685e8b1edbc42e387ab7..e42fff1d79d92fb7ed61768a614d8cd98f6775a0 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -26,6 +26,8 @@ limitations under the License. 
*/ namespace paddle { namespace framework { +int64_t GetEagerDeletionThreshold(); + class Scope; /** diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index 4cf95fa0ae07823289fbf337062190f05e6c6bcf..f1d268548578fea12082e2edb213a3749eccbfaf 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -151,6 +151,8 @@ class Tensor { void set_layout(const DataLayout layout) { layout_ = layout; } + void clear() { holder_ = nullptr; } + private: /** * @note Placeholder hides type T, so it doesn't appear as a template diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index e25efebe6c3555958f4f75e2b87b7dc45d4a4177..5af8af640e43a5b2e5ee9856f09f66a9fdf4463c 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -51,7 +51,7 @@ ENDIF() # memcpy depends on device_context, here add deps individually for # avoiding cycle dependencies -cc_library(device_context SRCS device_context.cc init.cc DEPS malloc +cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}) nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info) diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index c6f1d1f3d544117311821d980300dffea03891a5..dfc079e986e93c7f02f17b299e5d6293edbedd05 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -210,11 +210,14 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) if (dynload::HasCUDNN()) { cudnn_holder_.reset(new CudnnHolder(&stream_, place)); } + + callback_manager_.reset(new StreamCallbackManager(stream_)); } CUDADeviceContext::~CUDADeviceContext() { SetDeviceId(place_.device); Wait(); + WaitStreamCallback(); PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); eigen_stream_.reset(); eigen_device_.reset(); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 3ed49fc4233d4c0cd6cc16319eda08480ab9b434..79539195157d74d4d757edee5e008cbb76c93ee2 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -31,6 +31,9 @@ limitations under the License. 
*/ #include "glog/logging.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/stream_callback_manager.h" +#endif #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { @@ -112,6 +115,17 @@ class CUDADeviceContext : public DeviceContext { PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); } + template + void AddStreamCallback(Callback&& callback) const { + std::lock_guard guard(callback_mtx_); + callback_manager_->AddCallback(callback); + } + + void WaitStreamCallback() const { + std::lock_guard guard(callback_mtx_); + callback_manager_->Wait(); + } + private: CUDAPlace place_; @@ -125,7 +139,12 @@ class CUDADeviceContext : public DeviceContext { int multi_process; int max_threads_per_mp; - std::mutex mtx_; + mutable std::mutex mtx_; + + // This lock is only used by callback + // If we use mtx_ for StreamCallbackManager, deadlock may occur sometimes + mutable std::mutex callback_mtx_; + std::unique_ptr callback_manager_; }; template <> diff --git a/paddle/fluid/platform/stream_callback_manager.h b/paddle/fluid/platform/stream_callback_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..6c984065aa5fa1a8875aebe84051ab396bc417ec --- /dev/null +++ b/paddle/fluid/platform/stream_callback_manager.h @@ -0,0 +1,82 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
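// StreamCallbackManager (below) registers a host callback on a CUDA stream via
// cudaStreamAddCallback. Such callbacks may not invoke CUDA APIs themselves, so
// the registered function only forwards the user callback to a single-threaded
// ThreadPool; Wait() swaps in a fresh pool, which joins the worker thread and
// thereby drains every callback handed to the pool so far.
//
// Illustrative usage (`stream` is assumed to be an existing cudaStream_t):
//
//   platform::StreamCallbackManager manager(stream);
//   manager.AddCallback([] { /* runs after earlier work on the stream */ });
//   manager.Wait();  // blocks until the queued callbacks have finished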
+ +#pragma once + +#include +#include +#include +#include +#include "ThreadPool.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { + +using StreamCallback = std::function; + +class StreamCallbackManager; + +struct StreamCallbackContext { + template + inline StreamCallbackContext(const StreamCallbackManager *manager, + Callback &&callback) + : manager_(manager), callback_(callback) {} + + const StreamCallbackManager *manager_; // do not own + StreamCallback callback_; +}; + +class StreamCallbackManager { + public: + explicit inline StreamCallbackManager(cudaStream_t stream = nullptr) + : stream_(stream), thread_pool_(new ThreadPool(1)) {} + + template + inline void AddCallback(Callback &&callback) const { + AddCallbackWithStreamAndErrorInfo( + [=](cudaStream_t, cudaError_t) { callback(); }); + } + + template + inline void AddCallbackWithStreamAndErrorInfo(Callback &&callback) const { + auto *stream_callback_context = new StreamCallbackContext(this, callback); + PADDLE_ENFORCE(cudaStreamAddCallback( + stream_, StreamCallbackManager::StreamCallbackFunc, + stream_callback_context, 0)); + } + + void Wait() const { thread_pool_.reset(new ThreadPool(1)); } + + private: + const cudaStream_t stream_; + mutable std::unique_ptr thread_pool_; + + // cudaStreamCallback cannot call CUDA API inside, so we have to use + // thread_pool here + static void CUDART_CB StreamCallbackFunc(cudaStream_t stream, + cudaError_t status, + void *user_data) { + auto *callback_context_ptr = + reinterpret_cast(user_data); + callback_context_ptr->manager_->thread_pool_->enqueue([=]() { + std::unique_ptr callback_context( + callback_context_ptr); + callback_context->callback_(stream, status); + }); + } +}; + +} // namespace platform +} // namespace paddle diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 9aac3c7fc16ae1ded2700662764895385b043130..1ca2ac2ddc7daef3f4c0ea2004a62258ae4610ac 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -122,7 +122,7 @@ def __bootstrap__(): 'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir', 'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads', - "dist_threadpool_size", 'cpu_deterministic' + "dist_threadpool_size", 'cpu_deterministic', 'eager_delete_tensor_gb' ] if core.is_compiled_with_dist(): read_env_flags.append('rpc_deadline')
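With the flag exposed to Python, eager deletion can be enabled per process. A minimal usage sketch (the 0.0 threshold is arbitrary; this assumes the usual `FLAGS_<name>` environment-variable convention read by `__bootstrap__()` at import time):

import os

# 0.0 frees tensors as soon as their reference count reaches zero; a positive
# value batches deallocation until that many GB of garbage have accumulated.
os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'

import paddle.fluid as fluid  # __bootstrap__() picks the flag up here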