diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 4fb4ec38ee965a2790d11378a1ce6befa0ef5a00..8404bf4a3e12bdd33c063678d9288bbb24bd8aea 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -29,13 +29,20 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) -cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle - scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) +if(WITH_GPU) + cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle + all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass) +endif() +cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle + scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) -cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) +if(WITH_GPU) + cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass) +else() + cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto) +endif() -cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index f048f973fdeb6cf7d1485cda8cea7d530d9ba465..401ebb7953bb5d6c81d1e5206598c4b0ee5904c8 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -23,6 +23,8 @@ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/framework/details/reference_count_op_handle.h" + namespace paddle { namespace framework { namespace details { @@ -33,6 +35,10 @@ struct ComputationOpHandle : public OpHandleBase { std::string Name() const override; + const Scope *GetScope() const { return scope_; } + + const platform::Place &GetPlace() const { return place_; } + protected: void RunImpl() override; diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 6aec178831161f8ac1306fc3ed72e3267ca3c7e5..3de22a0235ffaae220a00c7253271b395970082a 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -82,6 +82,13 @@ class OpHandleBase { size_t NoDummyInputSize() const; + ir::Node *Node() { return node_; } + + const std::map + &GetDeviceContexts() const { + return dev_ctxes_; + } + protected: void RunAndRecordEvent(const std::function &callback); diff --git a/paddle/fluid/framework/details/reference_count_op_handle.h b/paddle/fluid/framework/details/reference_count_op_handle.h new file mode 100644 index 
0000000000000000000000000000000000000000..b76fc646c297295795d782ac869299642e2114ca --- /dev/null +++ b/paddle/fluid/framework/details/reference_count_op_handle.h @@ -0,0 +1,123 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/garbage_collector.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" + +namespace paddle { +namespace framework { +namespace details { + +using ReferenceCountMap = std::unordered_map; +using AtomicReferenceCountMap = + std::unordered_map>; +using DeviceReferenceCountMap = + std::unordered_map>; +using AtomicDeviceReferenceCountMap = + std::unordered_map>; +using DeviceGarbageCollectorMap = + std::unordered_map>>; + +class ReferenceCountOpHandle : public OpHandleBase { + public: + ReferenceCountOpHandle(ir::Node *node, const Scope *scope, + const platform::CUDAPlace &place, + const std::vector &var_names, + GarbageCollector *gc, + AtomicReferenceCountMap *ref_cnts) + : OpHandleBase(node), + scope_(scope), + var_names_(var_names), + gc_(gc), + ref_cnts_(ref_cnts) { + dev_ctx_ = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + if (IsStreamGarabageCollector()) { + PADDLE_ENFORCE(cudaSetDevice(place.device)); + PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); + } + } + + ~ReferenceCountOpHandle() { + if (IsStreamGarabageCollector()) { + auto gpu_place = boost::get(dev_ctx_->GetPlace()); + PADDLE_ENFORCE(cudaSetDevice(gpu_place.device)); + PADDLE_ENFORCE(cudaEventDestroy(event_)); + } + } + + std::string Name() const override { return "reference_count"; } + + // protected: + void RunImpl() override { + auto *exec_scope_ = scope_->FindVar(kLocalExecScopeName)->Get(); + std::vector tensors; + for (auto &name : var_names_) { + auto it = ref_cnts_->find(name); + if (it == ref_cnts_->end()) continue; + + auto *var = exec_scope_->FindVar(name); + if (var == nullptr || !var->IsType()) continue; + + if (it->second.fetch_sub(1) <= 1) { + tensors.emplace_back(var->GetMutable()); + } + } + + if (!tensors.empty()) { + ClearTensors(tensors); + } + } + + private: + void ClearTensors(const std::vector &tensors) const { + auto *gc = dynamic_cast *>(gc_); + if (gc != nullptr) { + auto compute_stream = dev_ctx_->stream(); + auto callback_stream = gc->stream(); + auto callback_func = [=]() { + PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream)); + PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0)); + }; + gc_->Add(tensors, callback_func); + } else { + gc_->Add(tensors); + } + } + + bool IsStreamGarabageCollector() const { + return dynamic_cast *>(gc_) != nullptr; + } + + const Scope *scope_; + platform::CUDADeviceContext *dev_ctx_; + std::vector var_names_; + GarbageCollector *gc_; // not own + AtomicReferenceCountMap *ref_cnts_; // not own + cudaEvent_t event_; +}; 
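The RunImpl above boils down to an atomic decrement-and-collect step: each ReferenceCountOpHandle decrements the counters of the variables its compute op touched, and the handle that drops a counter to zero hands the tensor over to the garbage collector. A minimal standalone sketch of that idea (FakeTensor, DecrementAndCollect and the plain-map "scope" are illustrative stand-ins, not part of this patch):

// Sketch only: the decrement-and-collect pattern used by RunImpl above.
#include <atomic>
#include <string>
#include <unordered_map>
#include <vector>

struct FakeTensor {};  // stands in for framework::LoDTensor

using RefCounts = std::unordered_map<std::string, std::atomic<int>>;

// Returns the tensors whose reference count dropped to zero; the caller
// would hand them to the garbage collector, as ClearTensors does above.
std::vector<FakeTensor *> DecrementAndCollect(
    const std::vector<std::string> &names,
    std::unordered_map<std::string, FakeTensor> *scope, RefCounts *ref_cnts) {
  std::vector<FakeTensor *> garbage;
  for (const auto &name : names) {
    auto it = ref_cnts->find(name);
    if (it == ref_cnts->end()) continue;
    // fetch_sub returns the previous value, so <= 1 means this was the last
    // pending use of the variable.
    if (it->second.fetch_sub(1) <= 1) {
      auto var_it = scope->find(name);
      if (var_it != scope->end()) garbage.push_back(&var_it->second);
    }
  }
  return garbage;
}

The counters are atomic because several reference_count ops on the same device may run concurrently in the threaded executor, which is presumably why the map type above is AtomicReferenceCountMap.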
+ +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..892e6ea48a1460b83f92d9d8e7aa14b52513727f --- /dev/null +++ b/paddle/fluid/framework/details/reference_count_pass.cc @@ -0,0 +1,152 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/details/reference_count_pass.h" + +namespace paddle { +namespace framework { +namespace details { + +std::unique_ptr ReferenceCountPass::ApplyImpl( + std::unique_ptr graph) const { + auto &ref_cnts = Get(kGlobalReferenceCount); + auto &cur_ref_cnts = Get(kCurReferenceCount); + auto &gcs = Get(kGarbageCollector); + + // It is not easy to find the right reference counts of variables in the graph + // Step 1: Find all variables in computation ops + // Step 2: Find all variables in non-computation ops which refer to variables + // in computation ops + std::unordered_set names; + auto get_ref_cnts_from_compute_op = [&]( + const std::unique_ptr &op, + const std::vector &vars) { + std::vector var_names_in_op; + auto *compute_op = dynamic_cast(op.get()); + if (compute_op == nullptr || + !platform::is_gpu_place(compute_op->GetPlace())) + return var_names_in_op; + auto place = boost::get(compute_op->GetPlace()); + for (VarHandleBase *var_handle_base : vars) { + auto *var_handle = dynamic_cast(var_handle_base); + if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue; + + if (!platform::is_gpu_place(var_handle->place_) || + boost::get(var_handle->place_) != place) + continue; + + VarDesc *var_desc = var_handle->Node()->Var(); + auto var_name = var_handle->Node()->Name(); + + // This is weird, but there really are some variables without var_desc + // in computation_op + if (var_desc == nullptr) { + if (compute_op->Node()->Op()->Block()->FindVar(var_name) == nullptr) + continue; + } else { + if (var_desc->Persistable() || + var_desc->Proto()->type().type() != proto::VarType::LOD_TENSOR) + continue; + } + + // a compute op only runs on one device + if (ref_cnts[place.device]->count(var_name)) + ++(*ref_cnts[place.device])[var_name]; + else + (*ref_cnts[place.device])[var_name] = 1; + + names.insert(var_name); + var_names_in_op.push_back(var_name); + } + return var_names_in_op; + }; + + auto update_ref_cnts_from_non_compute_op = [&]( + const std::unique_ptr &op, + const std::vector &vars) { + if (dynamic_cast(op.get()) != nullptr) return; + for (VarHandleBase *var_handle_base : vars) { + auto *var_handle = dynamic_cast(var_handle_base); + if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue; + + auto var_name = var_handle->Node()->Name(); + auto var_place = var_handle->place_; + if
(!platform::is_gpu_place(var_place)) continue; + auto place = boost::get(var_place); + if (names.count(var_name) == 0) continue; + if (ref_cnts.count(place.device) && + ref_cnts[place.device]->count(var_name)) { + ++(*ref_cnts[place.device])[var_name]; + } + } + }; + + std::unordered_map + compute_ref_cnt_map; + auto &all_ops = graph->Get(kGraphOps); + for (auto &op : all_ops) { + auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs()); + auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs()); + if (in_var_names.empty() && out_var_names.empty()) continue; + in_var_names.insert(in_var_names.end(), out_var_names.begin(), + out_var_names.end()); + auto *compute_op = dynamic_cast(op.get()); + auto place = boost::get(compute_op->GetPlace()); + ir::Node *ref_cnt_node = + graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation); + auto *ref_cnt_handle = new ReferenceCountOpHandle( + ref_cnt_node, compute_op->GetScope(), place, in_var_names, + gcs[place.device].get(), cur_ref_cnts[place.device].get()); + auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); + compute_op->AddOutput(dep_var); + ref_cnt_handle->AddInput(dep_var); + graph->Get(kGraphDepVars).emplace(dep_var); + compute_ref_cnt_map[compute_op] = ref_cnt_handle; + } + + for (auto &op : all_ops) { + update_ref_cnts_from_non_compute_op(op, op->Inputs()); + update_ref_cnts_from_non_compute_op(op, op->Outputs()); + } + + std::vector> new_all_ops; + new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size()); + for (auto &op : all_ops) { + auto it = compute_ref_cnt_map.find(op.get()); + if (it != compute_ref_cnt_map.end()) { + new_all_ops.emplace_back(std::move(op)); + new_all_ops.emplace_back(std::unique_ptr(it->second)); + } else { + new_all_ops.emplace_back(std::move(op)); + } + } + + all_ops.swap(new_all_ops); + return graph; +} + +} // namespace details +} // namespace framework +} // namespace paddle + +REGISTER_PASS(reference_count_pass, + paddle::framework::details::ReferenceCountPass) + .RequirePassAttr(paddle::framework::details::kGlobalReferenceCount) + .RequirePassAttr(paddle::framework::details::kCurReferenceCount) + .RequirePassAttr(paddle::framework::details::kGarbageCollector); diff --git a/paddle/fluid/framework/details/reference_count_pass.h b/paddle/fluid/framework/details/reference_count_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..7081280b0600b9c1985987d02d679c298ad4b8bd --- /dev/null +++ b/paddle/fluid/framework/details/reference_count_pass.h @@ -0,0 +1,37 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
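The REGISTER_PASS block above also declares three required attributes, so a caller has to inject them with SetNotOwned before calling Apply; that is exactly how parallel_executor.cc wires the pass up later in this diff. A hedged fragment of that call pattern (the member names ref_cnts_, cur_ref_cnts_ and gcs_ are the ParallelExecutor members added below; the fragment is not self-contained):

// Fragment mirroring the ParallelExecutor changes later in this diff.
auto ref_cnt_pass = ir::PassRegistry::Instance().Get("reference_count_pass");
// SetNotOwned: the pass only borrows these maps; ownership stays with the caller.
ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, &ref_cnts_);
ref_cnt_pass->SetNotOwned(details::kCurReferenceCount, &cur_ref_cnts_);
ref_cnt_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
graph = ref_cnt_pass->Apply(std::move(graph));  // inserts reference_count ops
graph->SetNotOwned("garbage_collector", &gcs_);  // lets the executor reach the GCs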
+ +#pragma once + +#include "paddle/fluid/framework/details/reference_count_op_handle.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace details { + +constexpr char kGlobalReferenceCount[] = "reference_count"; +constexpr char kCurReferenceCount[] = "current_reference_count"; +constexpr char kGarbageCollector[] = "garbage_collector"; + +class ReferenceCountPass : public ir::Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc index eb4e7ec52f907f9403e21ec2734d61824f51a58b..51e840ffa6c8e9818058cdbb87d631f0004e9d93 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc @@ -16,6 +16,10 @@ #include #include #include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/platform/profiler.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/framework/details/reference_count_op_handle.h" +#endif namespace paddle { namespace framework { @@ -56,12 +60,28 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( auto fetch_data = underlying_executor_->Run(fetch_tensors); drop_scope_counter_ += 1; + +#ifdef PADDLE_WITH_CUDA + const std::string gc_name = "garbage_collector"; + DeviceGarbageCollectorMap *gc = + Graph().Has(gc_name) ? &(Graph().Get(gc_name)) + : nullptr; +#endif + if (!fetch_tensors.empty() || drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) { drop_scope_counter_ = 0; // Wait All computational streams for (auto p : places_) { platform::DeviceContextPool::Instance().Get(p)->Wait(); +#ifdef PADDLE_WITH_CUDA + if (gc != nullptr && platform::is_gpu_place(p)) { + auto gpu_place = boost::get(p); + auto &gc_at_place = gc->at(gpu_place.device); + gc_at_place->Wait(); + gc_at_place->Reset(); + } +#endif } for (auto &scope : local_scopes_) { auto &local_scope = diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 84f67fafa19ac545ebb7a1019059e3c74c363c56..6868f639a03bef9acdd9d6418883e7c502761ec5 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -37,7 +37,9 @@ int kProgramId = -1; ExecutorPrepareContext::ExecutorPrepareContext( const framework::ProgramDesc& prog, size_t block_id) - : prog_(prog), block_id_(block_id) {} + : prog_(prog), + block_id_(block_id), + ref_cnts_(GetNonPersistableReferenceCount(prog, block_id)) {} ExecutorPrepareContext::~ExecutorPrepareContext() { VLOG(5) << "destroy ExecutorPrepareContext"; @@ -335,20 +337,84 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, CreateVariables(ctx->prog_, local_scope, ctx->block_id_); } + std::shared_ptr> erase_tensors( + new std::vector()); + int64_t max_memory_size = GetEagerDeletionThreshold(); + + std::unique_ptr> gc; + if (max_memory_size >= 0) { +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(place_)) { + gc.reset(new DefaultStreamGarbageCollector( + boost::get(place_), max_memory_size)); + } else { +#endif + gc.reset(new CPUGarbageCollector( + boost::get(place_), max_memory_size)); +#ifdef PADDLE_WITH_CUDA + } +#endif + } + for (auto& op : ctx->ops_) { VLOG(4) << place_ << " " << op->DebugStringEx(local_scope); op->Run(*local_scope, place_); - // 
NOTE! Please do not delete this line, it's usefull because the debug - // string before and after op.run are different, after run the output - // will have right shape which is usefull for debug. - VLOG(3) << place_ << " " << op->DebugStringEx(local_scope); + +#ifdef PADDLE_WITH_CUDA + if (gc != nullptr) { + std::vector erase_vars; + for (auto& input : op->Inputs()) { + for (auto& input_name : input.second) { + auto it = ctx->ref_cnts_.find(input_name); + if (it == ctx->ref_cnts_.end()) continue; + if (it->second == 1) { // should delete it + erase_vars.emplace_back(input_name); + ctx->ref_cnts_.erase(input_name); + } else { + --(it->second); + } + } + } + + for (auto& output : op->Outputs()) { + for (auto& output_name : output.second) { + auto it = ctx->ref_cnts_.find(output_name); + if (it == ctx->ref_cnts_.end()) continue; + if (it->second == 1) { + erase_vars.emplace_back(output_name); + ctx->ref_cnts_.erase(output_name); + } else { + --(it->second); + } + } + } + + if (!erase_vars.empty()) { + std::vector erase_tensors; + for (auto& name : erase_vars) { + auto* var = local_scope->FindVar(name); + if (var == nullptr) continue; + if (var->IsType()) { + auto* tensor = var->GetMutable(); + erase_tensors.push_back(tensor); + } + } + if (!erase_tensors.empty()) gc->Add(erase_tensors); + } + } +#endif if (FLAGS_benchmark) { VLOG(2) << "Memory used after operator " + op->Type() + " running: " << memory::memory_usage(place_); } } - platform::DeviceContextPool::Instance().Get(place_)->Wait(); + + if (gc != nullptr) + gc->Wait(); + else + platform::DeviceContextPool::Instance().Get(place_)->Wait(); + if (local_scope != scope) { scope->DeleteScope(local_scope); } else { diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 563a4b2bb65dad481a755f67c7f23939816ce8e8..81d83ecea50e360b6c1935777dd246f012160d5a 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -17,6 +17,7 @@ limitations under the License. 
*/ #include #include #include +#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" @@ -27,6 +28,48 @@ namespace paddle { namespace framework { extern void InitializeVariable(Variable* var, proto::VarType::Type var_type); +int64_t GetEagerDeletionThreshold(); + +template +std::unordered_map GetNonPersistableReferenceCount( + const ProgramDesc& prog, size_t block_id) { + auto& block = prog.Block(block_id); + std::unordered_set ignored_vars; + std::unordered_map ref_cnts; + + for (auto var_desc : block.AllVars()) { + auto type = var_desc->Proto()->type().type(); + if (type != proto::VarType::LOD_TENSOR || var_desc->Persistable()) { + ignored_vars.insert(var_desc->Name()); // ignore persistable vars + } + } + + for (auto op_desc : block.AllOps()) { + for (auto& input : op_desc->Inputs()) { + for (auto& input_name : input.second) { + if (!ignored_vars.count(input_name)) { + if (ref_cnts.count(input_name)) + ++ref_cnts[input_name]; + else + ref_cnts[input_name] = 1; + } + } + } + + for (auto& output : op_desc->Outputs()) { + for (auto output_name : output.second) { + if (!ignored_vars.count(output_name)) { + if (ref_cnts.count(output_name)) + ++ref_cnts[output_name]; + else + ref_cnts[output_name] = 1; + } + } + } + } + return ref_cnts; +} + struct ExecutorPrepareContext { ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id); ~ExecutorPrepareContext(); @@ -34,6 +77,8 @@ struct ExecutorPrepareContext { const framework::ProgramDesc& prog_; size_t block_id_; std::vector> ops_; + + std::unordered_map ref_cnts_; }; class Executor { diff --git a/paddle/fluid/framework/garbage_collector.h b/paddle/fluid/framework/garbage_collector.h new file mode 100644 index 0000000000000000000000000000000000000000..b403252c972d26da6deeca54ce88a9547ffe7afa --- /dev/null +++ b/paddle/fluid/framework/garbage_collector.h @@ -0,0 +1,163 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include // NOLINT +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace framework { + +// T should have memory_size() and clear() method +template +class GarbageCollector { + public: + GarbageCollector(const platform::Place &place, size_t max_memory_size) + : max_memory_size_(std::max(max_memory_size, static_cast(1))) { + garbages_.reset(new std::deque()); + dev_ctx_ = platform::DeviceContextPool::Instance().Get(place); + } + + virtual ~GarbageCollector() {} + + void Reset() { + std::lock_guard guard(mutex_); + garbages_.reset(new std::deque()); + cur_memory_size_ = 0; + } + + template + void Add(const Container &objs) { + Add(objs, []() {}); + } + + template + void Add(const Container &objs, Callback &&callback) { + std::shared_ptr> clear_deque; + { + std::lock_guard guard(mutex_); + for (auto *obj : objs) { + garbages_->push_back(obj); + cur_memory_size_ += obj->memory_size(); + } + if (cur_memory_size_ >= max_memory_size_) { + cur_memory_size_ = 0; + clear_deque = garbages_; + garbages_.reset(new std::deque()); + } + } + + if (clear_deque != nullptr) { + callback(); + ClearCallback([=]() { + for (auto *obj : *clear_deque) obj->clear(); + }); + } + } + + virtual void Wait() const {} + + protected: + virtual void ClearCallback(const std::function &callback) = 0; + + platform::DeviceContext *dev_ctx_; + std::shared_ptr> garbages_; + mutable std::mutex mutex_; + const size_t max_memory_size_; + size_t cur_memory_size_ = 0; +}; + +template +class CPUGarbageCollector : public GarbageCollector { + public: + CPUGarbageCollector(const platform::CPUPlace &place, size_t max_memory_size) + : GarbageCollector(place, max_memory_size) {} + + protected: + void ClearCallback(const std::function &callback) override { + callback(); + } +}; + +#ifdef PADDLE_WITH_CUDA +template +class DefaultStreamGarbageCollector : public GarbageCollector { + public: + DefaultStreamGarbageCollector(const platform::CUDAPlace &place, + size_t max_memory_size) + : GarbageCollector(place, max_memory_size) {} + + cudaStream_t stream() const { + return static_cast(this->dev_ctx_) + ->stream(); + } + + void Wait() const override { + this->dev_ctx_->Wait(); + static_cast(this->dev_ctx_) + ->WaitStreamCallback(); + } + + protected: + void ClearCallback(const std::function &callback) override { + static_cast(this->dev_ctx_) + ->AddStreamCallback(callback); + } +}; + +template +class StreamGarbageCollector : public GarbageCollector { + public: + StreamGarbageCollector(const platform::CUDAPlace &place, + size_t max_memory_size) + : GarbageCollector(place, max_memory_size) { + PADDLE_ENFORCE(cudaSetDevice(place.device)); + PADDLE_ENFORCE(cudaStreamCreate(&stream_)); + callback_manager_.reset(new platform::StreamCallbackManager(stream_)); + } + + ~StreamGarbageCollector() { + auto place = boost::get(this->dev_ctx_->GetPlace()); + PADDLE_ENFORCE(cudaSetDevice(place.device)); + PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); + PADDLE_ENFORCE(cudaStreamDestroy(stream_)); + } + + void Wait() const override { + PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); + std::lock_guard guard(this->mutex_); + callback_manager_->Wait(); + } + + cudaStream_t stream() const { return stream_; } + + protected: + void ClearCallback(const std::function &callback) override { + std::lock_guard guard(this->mutex_); + callback_manager_->AddCallback(callback); + } + + private: + cudaStream_t stream_; + std::unique_ptr callback_manager_; +}; +#endif + +} // namespace framework 
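Taken together, the collectors above batch deletions: Add() only queues tensors and accumulates their memory_size(), and the queued clear() calls fire once the running total crosses max_memory_size_; Reset() drops the bookkeeping and Wait() flushes pending stream work on GPU. A small CPU-side usage sketch, assuming the collector is instantiated with framework::Tensor as the Executor changes elsewhere in this diff do (Demo and the 1 MB threshold are illustrative):

// Usage sketch only; assumes a CPU build and the header added above.
#include <vector>
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/tensor.h"

void Demo(const std::vector<paddle::framework::Tensor *> &finished) {
  paddle::framework::CPUGarbageCollector<paddle::framework::Tensor> gc(
      paddle::platform::CPUPlace(), 1 << 20 /* ~1 MB threshold */);
  // Nothing is freed here until the queued memory_size() total reaches the
  // threshold; then clear() runs on every queued tensor.
  gc.Add(finished);
  gc.Wait();   // no-op for the CPU collector; GPU collectors wait on streams
  gc.Reset();  // drop bookkeeping, as ScopeBufferedSSAGraphExecutor does when
               // it drops local scopes
}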
+} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h new file mode 100644 index 0000000000000000000000000000000000000000..ab687e760a761d4e445726bd5149966adc2403d0 --- /dev/null +++ b/paddle/fluid/framework/ir/graph.h @@ -0,0 +1,183 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/variant.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * The graph is a Directed Acyclic Static Single Assignment Graph. + * + * In more detail, the following properties must hold: + * + * The graph shouldn't contain a cycle. Each node is a black-box to the graph + * so the node itself could be a loop operator. + * + * Each Variable-type node has only one input (thus static single assignment). + * + * The output/input of an operator is a variable, and the output/input of a + * variable is an operator. + * + * The following data hazards in the Program are addressed in the Graph: + * + * Write-After-Read + * a = op1(x) + * x = op2(b) + * A control-dependency connection is created between op1 and op2 such that + * op1->op2, so as to ensure correct order. + * + * Write-After-Write + * x = op1(a) + * x = op2(b) + * A control-dependency connection is created between op1 and op2 such that + * op1->op2, so as to ensure correct order. + * + * Other properties currently hold, but are not enforced yet: + * + * Variable-type nodes (not control dep) with the same variable name share + * the same underlying VarDesc. + */ +class Graph { + public: + explicit Graph(const ProgramDesc &program); + + virtual ~Graph() { + for (auto &attr : attrs_) { + attr_dels_[attr.first](); + } + attrs_.clear(); + attr_dels_.clear(); + } + + bool Has(const std::string &attr_name) const { + return attrs_.find(attr_name) != attrs_.end(); + } + + template + AttrType &Get(const std::string &attr_name) const { + PADDLE_ENFORCE(Has(attr_name), "%s attr not registered for graph.", + attr_name); + return *boost::any_cast(attrs_.at(attr_name)); + } + + template + void Set(const std::string &attr_name, AttrType *attr) { + PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the graph", + attr_name); + attrs_[attr_name] = attr; + attr_dels_[attr_name] = [attr, attr_name]() { + VLOG(3) << "deleting " << attr_name; + delete attr; + }; + } + + template + void SetNotOwned(const std::string &attr_name, AttrType *attr) { + PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the graph", + attr_name); + attrs_[attr_name] = attr; + attr_dels_[attr_name] = []() {}; + } + + const std::unordered_set &Nodes() const { return node_set_; } + + // Create a normal variable with non-null VarDesc.
+ ir::Node *CreateVarNode(VarDesc *var_desc) { + PADDLE_ENFORCE(var_desc); + return AddNode(new ir::Node(var_desc)); + } + + // Create a normal runnable operator with OpDesc. + ir::Node *CreateOpNode(OpDesc *op_desc) { + PADDLE_ENFORCE(op_desc); + return AddNode(new ir::Node(op_desc)); + } + + // Create a control dependency var that connects 2 operations. The + // var doesn't hold any data. Other than that, it's no different from + // other var, considering dependency analysis. + ir::Node *CreateControlDepVar() { + // TODO(panyx0718): control var name should be really unique. + const std::string name = string::Sprintf( + "%s@%llu", ir::Node::kControlDepVarName, node_set_.size()); + return AddNode(new ir::Node(name, ir::Node::Type::kVariable)); + } + + // A more free style way of creating a graph node. Mostly use for test + // or "copy" from another node. Avoid using it if possible. + ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) { + return AddNode(new ir::Node(name, type)); + } + + // Clear all node information of the graph and return the ownership of the + // nodes. + std::vector> ReleaseNodes() { + std::vector> ret; + for (auto &n : nodes_) { + ret.emplace_back(n.second.release()); + } + nodes_.clear(); + node_set_.clear(); + return ret; + } + + void RemoveNode(ir::Node *node) { + PADDLE_ENFORCE(node_set_.find(node) != node_set_.end()); + node_set_.erase(node); + nodes_.erase(node); + } + + // NOTE low performance, but simple and secure. + Node *RetriveNode(int id) { + for (auto &node : nodes_) { + if (node.second->id() == id) { + return node.second.get(); + } + } + return nullptr; + } + + private: + // This method takes ownership of `node`. + ir::Node *AddNode(ir::Node *node) { + PADDLE_ENFORCE(node_set_.find(node) == node_set_.end()); + nodes_[node].reset(node); + node_set_.insert(node); + return node; + } + + // NOTE: program_ shouldn't be exposed to user. + const ProgramDesc program_; + std::map attrs_; + std::map> attr_dels_; + std::map> nodes_; + std::unordered_set node_set_; +}; + +bool IsControlDepVar(const ir::Node &var); +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index b53a6f43fbd1f23e69d23ad0fcc54d5c25d352a3..5a19e7f1bf94abaac6d13e963cab3779c0789b82 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -19,9 +19,15 @@ limitations under the License. 
*/ #include #ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/framework/details/reference_count_pass.h" #include "paddle/fluid/platform/nccl_helper.h" #endif +#include "paddle/fluid/framework/details/all_reduce_op_handle.h" +#include "paddle/fluid/framework/details/broadcast_op_handle.h" +#include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/reduce_op_handle.h" +#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" #include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" @@ -115,17 +121,39 @@ ParallelExecutor::ParallelExecutor( build_strategy); if (member_->use_cuda_) { #ifdef PADDLE_WITH_CUDA - builder_factory.SetNCCLContextMap(member_->nccl_ctxs_.get()); + std::unique_ptr graph = ApplyParallelExecutorPass( + main_program, member_->places_, loss_var_name, params, + member_->local_scopes_, member_->use_cuda_, build_strategy, + member_->nccl_ctxs_.get()); + + auto max_memory_size = GetEagerDeletionThreshold(); + if (max_memory_size >= 0) { + for (auto &place : member_->places_) { + if (!platform::is_gpu_place(place)) continue; + auto gpu_place = boost::get(place); + if (gcs_[gpu_place.device] == nullptr) { + ref_cnts_[gpu_place.device].reset(new details::ReferenceCountMap()); + cur_ref_cnts_[gpu_place.device].reset( + new details::AtomicReferenceCountMap()); + gcs_[gpu_place.device].reset( + new StreamGarbageCollector(gpu_place, max_memory_size)); + } + } + if (!gcs_.empty()) { + auto ref_cnt_pass = + ir::PassRegistry::Instance().Get("reference_count_pass"); + ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, &ref_cnts_); + ref_cnt_pass->SetNotOwned(details::kCurReferenceCount, &cur_ref_cnts_); + ref_cnt_pass->SetNotOwned(details::kGarbageCollector, &gcs_); + graph = ref_cnt_pass->Apply(std::move(graph)); + graph->SetNotOwned("garbage_collector", &gcs_); + } + } #else PADDLE_THROW("Not compiled with CUDA"); #endif } - builder_ = builder_factory.Create(); - member_->executor_.reset(new details::ThreadedSSAGraphExecutor( - exec_strategy, member_->local_scopes_, places, - builder_->Build(main_program))); - member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( exec_strategy, member_->local_scopes_, std::move(var_infos), member_->places_, std::move(member_->executor_))); @@ -216,6 +244,11 @@ void ParallelExecutor::BCastParamsToGPUs( void ParallelExecutor::Run(const std::vector &fetch_tensors, const std::string &fetched_var_name) { platform::RecordBlock b(0); +#ifdef PADDLE_WITH_CUDA + if (!gcs_.empty()) { + ResetReferenceCount(); + } +#endif auto fetch_data = member_->executor_->Run(fetch_tensors); *member_->global_scope_->Var(fetched_var_name)->GetMutable() = fetch_data; @@ -265,3 +298,11 @@ ParallelExecutor::~ParallelExecutor() { } // namespace framework } // namespace paddle + +USE_PASS(graph_viz_pass); +USE_PASS(multi_devices_pass); +USE_PASS(multi_devices_check_pass); +USE_PASS(multi_devices_print_pass); +#ifdef PADDLE_WITH_CUDA +USE_PASS(reference_count_pass); +#endif diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index 058f83f07c26224e3180d140630c08a24c40cd80..2aa438e320a0f191f78a6274b8ad8453f1736ef4 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -15,7 +15,9 @@ limitations under the License. 
*/ #pragma once #include +#include #include +#include #include #include #include "paddle/fluid/framework/details/execution_strategy.h" @@ -70,7 +72,23 @@ class ParallelExecutor { private: ParallelExecutorPrivate *member_; - std::unique_ptr builder_; + +#ifdef PADDLE_WITH_CUDA + // ref_cnts_ is only initialized when the ParallelExecutor is constructed, + // and then remains unchanged. + // Before each iteration, cur_ref_cnts_ is reset to ref_cnts_. + details::DeviceReferenceCountMap ref_cnts_; + details::AtomicDeviceReferenceCountMap cur_ref_cnts_; + details::DeviceGarbageCollectorMap gcs_; + + void ResetReferenceCount() { + for (auto &pair1 : ref_cnts_) { + for (auto &pair2 : *(pair1.second)) { + (*(cur_ref_cnts_[pair1.first]))[pair2.first] = pair2.second; + } + } + } +#endif }; } // namespace framework diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index 50f374e3703a97f6c1fdb4b14fdeb0b603f9ac86..caea191cb3513fbe701df0dca668d28fefb6a1d3 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -31,9 +31,21 @@ DEFINE_bool( "Delete local scope eagerly. It will reduce GPU memory usage but " "slow down the destruction of variables.(around 1% performance harm)"); +DEFINE_double( + eager_delete_tensor_GB, -1.0, + "Memory size threshold (GB) at which the garbage collector clears tensors. " + "Disabled when this value is less than 0."); + namespace paddle { namespace framework { +int64_t GetEagerDeletionThreshold() { + return FLAGS_eager_delete_tensor_GB < 0 + ? -1 + : static_cast(FLAGS_eager_delete_tensor_GB * + (static_cast(1) << 30)); +} + Scope::~Scope() { DropKids(); } Scope& Scope::NewScope() const { diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index e246241c0abfbc7bdcaf38d073cc58fc36a4f737..47d040240a213f65153252419ebb429461e866c5 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -26,6 +26,8 @@ limitations under the License.
*/ namespace paddle { namespace framework { +int64_t GetEagerDeletionThreshold(); + class Scope; /** diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index ef224d68f1fc561f45e9d7a81425e62655457648..775c01765c96ecdc7c3aef5174b90c52ed281e69 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -149,6 +149,8 @@ class Tensor { void set_layout(const DataLayout layout) { layout_ = layout; } + void clear() { holder_ = nullptr; } + private: /** * @note Placeholder hides type T, so it doesn't appear as a template diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 20037d0764056c2a093af801c9cc1eb788dd46d6..ac9bf9a505d2f03ba511c9a65ec6851cf605ab8b 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -45,8 +45,8 @@ ENDIF() # memcpy depends on device_context, here add deps individually for # avoiding cycle dependencies -cc_library(device_context SRCS device_context.cc init.cc DEPS malloc - place eigen3 stringpiece cpu_helper ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}) +cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc + place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}) nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info) cc_test(init_test SRCS init_test.cc DEPS device_context) diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 2cc26da013f59f5b7ee1747d57baca9c1c0efe2c..a57ee2d8f5e598adc33e7fd8f1f354d9a372cd12 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -159,11 +159,14 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { } else { cudnn_handle_ = nullptr; } + + callback_manager_.reset(new StreamCallbackManager(stream_)); } CUDADeviceContext::~CUDADeviceContext() { SetDeviceId(place_.device); Wait(); + WaitStreamCallback(); PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); if (cudnn_handle_ != nullptr) { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 88e0383146c1adf2752a362091996bad9cfcce5e..0fb53383685221a3415a396ff6b712ccddd011c3 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -31,8 +31,13 @@ limitations under the License. 
*/ #include "glog/logging.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/stream_callback_manager.h" +#endif #include "unsupported/Eigen/CXX11/Tensor" +DECLARE_bool(clear_gpu_memory_when_unused); + namespace paddle { namespace platform { @@ -106,6 +111,17 @@ class CUDADeviceContext : public DeviceContext { PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); } + template + void AddStreamCallback(Callback&& callback) const { + std::lock_guard guard(callback_mtx_); + callback_manager_->AddCallback(callback); + } + + void WaitStreamCallback() const { + std::lock_guard guard(callback_mtx_); + callback_manager_->Wait(); + } + private: CUDAPlace place_; @@ -119,7 +135,12 @@ class CUDADeviceContext : public DeviceContext { int multi_process; int max_threads_per_mp; - std::mutex mtx_; + mutable std::mutex mtx_; + + // This lock is only used by callback + // If we use mtx_ for StreamCallbackManager, deadlock may occur sometimes + mutable std::mutex callback_mtx_; + std::unique_ptr callback_manager_; }; template <> diff --git a/paddle/fluid/platform/stream_callback_manager.h b/paddle/fluid/platform/stream_callback_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..6c984065aa5fa1a8875aebe84051ab396bc417ec --- /dev/null +++ b/paddle/fluid/platform/stream_callback_manager.h @@ -0,0 +1,82 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include "ThreadPool.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { + +using StreamCallback = std::function; + +class StreamCallbackManager; + +struct StreamCallbackContext { + template + inline StreamCallbackContext(const StreamCallbackManager *manager, + Callback &&callback) + : manager_(manager), callback_(callback) {} + + const StreamCallbackManager *manager_; // do not own + StreamCallback callback_; +}; + +class StreamCallbackManager { + public: + explicit inline StreamCallbackManager(cudaStream_t stream = nullptr) + : stream_(stream), thread_pool_(new ThreadPool(1)) {} + + template + inline void AddCallback(Callback &&callback) const { + AddCallbackWithStreamAndErrorInfo( + [=](cudaStream_t, cudaError_t) { callback(); }); + } + + template + inline void AddCallbackWithStreamAndErrorInfo(Callback &&callback) const { + auto *stream_callback_context = new StreamCallbackContext(this, callback); + PADDLE_ENFORCE(cudaStreamAddCallback( + stream_, StreamCallbackManager::StreamCallbackFunc, + stream_callback_context, 0)); + } + + void Wait() const { thread_pool_.reset(new ThreadPool(1)); } + + private: + const cudaStream_t stream_; + mutable std::unique_ptr thread_pool_; + + // cudaStreamCallback cannot call CUDA API inside, so we have to use + // thread_pool here + static void CUDART_CB StreamCallbackFunc(cudaStream_t stream, + cudaError_t status, + void *user_data) { + auto *callback_context_ptr = + reinterpret_cast(user_data); + callback_context_ptr->manager_->thread_pool_->enqueue([=]() { + std::unique_ptr callback_context( + callback_context_ptr); + callback_context->callback_(stream, status); + }); + } +}; + +} // namespace platform +} // namespace paddle diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 3034c1a0875a71421bcba172c16ee32d809df152..74b268aedece0983177616581f0755ea16916697 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -117,9 +117,19 @@ def __bootstrap__(): os.environ['OMP_NUM_THREADS'] = str(num_threads) read_env_flags = [ - 'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir', - 'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb', - 'init_allocated_mem' + 'use_pinned_memory', + 'check_nan_inf', + 'benchmark', + 'warpctc_dir', + 'eager_delete_scope', + 'use_mkldnn', + 'initial_cpu_memory_in_mb', + 'init_allocated_mem', + 'free_idle_memory', + 'paddle_num_threads', + "dist_threadpool_size", + 'cpu_deterministic', + 'eager_delete_tensor_GB', ] if core.is_compiled_with_cuda(): read_env_flags += [