From 096673f67527b0fed1aab1843041b9d929fd0fb5 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Thu, 29 Nov 2018 13:20:29 +0000 Subject: [PATCH] refactor eager deletion test=develop --- paddle/fluid/framework/details/CMakeLists.txt | 12 +- .../details/computation_op_handle.cc | 6 +- .../framework/details/computation_op_handle.h | 6 +- .../details/eager_deletion_op_handle.cc | 117 ++++++++++ .../details/eager_deletion_op_handle.h | 64 ++++++ .../framework/details/eager_deletion_pass.cc | 96 ++++++++ .../framework/details/eager_deletion_pass.h | 32 +++ .../details/multi_devices_graph_pass.cc | 6 +- .../details/reference_count_op_handle.h | 138 ------------ .../framework/details/reference_count_pass.cc | 213 +++++------------- .../framework/details/reference_count_pass.h | 5 - .../details/reference_count_pass_helper.h | 49 ++++ .../scope_buffered_ssa_graph_executor.cc | 30 +-- .../scope_buffered_ssa_graph_executor.h | 4 + paddle/fluid/framework/garbage_collector.h | 12 +- paddle/fluid/framework/ir/graph.h | 11 +- paddle/fluid/framework/ir/pass.h | 11 +- paddle/fluid/framework/parallel_executor.cc | 106 ++++++--- paddle/fluid/framework/parallel_executor.h | 24 +- paddle/fluid/platform/CMakeLists.txt | 9 +- .../fluid/platform/stream_callback_manager.cc | 70 ++++++ .../fluid/platform/stream_callback_manager.h | 51 +---- 22 files changed, 631 insertions(+), 441 deletions(-) create mode 100644 paddle/fluid/framework/details/eager_deletion_op_handle.cc create mode 100644 paddle/fluid/framework/details/eager_deletion_op_handle.h create mode 100644 paddle/fluid/framework/details/eager_deletion_pass.cc create mode 100644 paddle/fluid/framework/details/eager_deletion_pass.h delete mode 100644 paddle/fluid/framework/details/reference_count_op_handle.h create mode 100644 paddle/fluid/framework/details/reference_count_pass_helper.h create mode 100644 paddle/fluid/platform/stream_callback_manager.cc diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 93288936fea..8cf97d667d4 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -33,10 +33,9 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper) -if (WITH_GPU) - cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle - all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass) -endif() +cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows op_handle_base) +cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass) +cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass) cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass) cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass) @@ -44,10 +43,7 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle) -set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass) -if (WITH_GPU) - list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass) -endif() +set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS}) diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index 7ad1e40c600..7beb8c8de9f 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -20,11 +20,13 @@ namespace paddle { namespace framework { namespace details { ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope, - platform::Place place) + platform::Place place, + size_t scope_idx) : OpHandleBase(node), op_(framework::OpRegistry::CreateOp(*node->Op())), scope_(scope), - place_(place) {} + place_(place), + scope_idx_(scope_idx) {} void ComputationOpHandle::RunImpl() { WaitInputVarGenerated(place_); diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index 662a91d6b4d..601ae4f8c6d 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -28,7 +28,8 @@ namespace framework { namespace details { struct ComputationOpHandle : public OpHandleBase { public: - ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place); + ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place, + size_t scope_idx); std::string Name() const override; @@ -38,6 +39,8 @@ struct ComputationOpHandle : public OpHandleBase { void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; } + size_t GetScopeIdx() const { return scope_idx_; } + protected: void RunImpl() override; @@ -47,6 +50,7 @@ struct ComputationOpHandle : public OpHandleBase { std::unique_ptr op_; Scope *scope_; platform::Place place_; + size_t scope_idx_; bool is_lock_and_record_event_free_{false}; }; } // namespace details diff --git a/paddle/fluid/framework/details/eager_deletion_op_handle.cc b/paddle/fluid/framework/details/eager_deletion_op_handle.cc new file mode 100644 index 00000000000..cd262033760 --- /dev/null +++ b/paddle/fluid/framework/details/eager_deletion_op_handle.cc @@ -0,0 +1,117 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" +#include "paddle/fluid/framework/lod_tensor_array.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" + +namespace paddle { +namespace framework { +namespace details { + +EagerDeletionOpHandle::EagerDeletionOpHandle( + ir::Node *node, const Scope *scope, const platform::Place &place, + const std::vector &var_names, GarbageCollector *gc, + AtomicReferenceCountMap *ref_cnts) + : OpHandleBase(node), scope_(scope), gc_(gc), ref_cnts_(ref_cnts) { +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(place)) { + dev_ctx_ = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + if (dynamic_cast *>(gc_)) { + platform::SetDeviceId(boost::get(place).device); + PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); + } + } +#endif + + for (auto &name : var_names) AddVar(name); +} + +EagerDeletionOpHandle::~EagerDeletionOpHandle() { +#ifdef PADDLE_WITH_CUDA + if (event_) { + auto gpu_place = boost::get(dev_ctx_->GetPlace()); + platform::SetDeviceId(gpu_place.device); + PADDLE_ENFORCE(cudaEventDestroy(event_)); + } +#endif +} + +std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; } + +void EagerDeletionOpHandle::AddVar(const std::string &name) { + var_names_.insert(name); +} + +void EagerDeletionOpHandle::RunImpl() { + auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get(); + std::vector tensors; + for (auto &name : var_names_) { + auto it = ref_cnts_->find(name); + if (it == ref_cnts_->end()) { + continue; + } + + auto *var = exec_scope->FindVar(name); + if (var == nullptr) { + continue; + } + + if (var->IsType()) { + if (it->second.fetch_sub(1) == 1) { + tensors.emplace_back(var->GetMutable()); + } + } else if (var->IsType()) { + if (it->second.fetch_sub(1) == 1) { + tensors.emplace_back(var->GetMutable()->mutable_value()); + } + } else if (var->IsType()) { + if (it->second.fetch_sub(1) == 1) { + auto *tensor_arr = var->GetMutable(); + for (auto &t : *tensor_arr) { + tensors.emplace_back(&t); + } + } + } + } + + if (!tensors.empty()) { + ClearTensors(tensors); + } +} + +void EagerDeletionOpHandle::ClearTensors(const std::vector &tensors) { +#ifdef PADDLE_WITH_CUDA + if (event_) { + auto compute_stream = dev_ctx_->stream(); + auto callback_stream = + static_cast *>(gc_)->stream(); + auto callback_func = [=]() { + PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream)); + PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0)); + }; + gc_->Add(tensors, callback_func); + } else { +#endif + gc_->Add(tensors); +#ifdef PADDLE_WITH_CUDA + } +#endif +} + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/eager_deletion_op_handle.h b/paddle/fluid/framework/details/eager_deletion_op_handle.h new file mode 100644 index 00000000000..8254f21bdfc --- /dev/null +++ b/paddle/fluid/framework/details/eager_deletion_op_handle.h @@ -0,0 +1,64 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/details/reference_count_pass_helper.h" + +namespace paddle { +namespace framework { +class Scope; + +namespace details { + +class EagerDeletionPass; + +class EagerDeletionOpHandle : public OpHandleBase { + public: + EagerDeletionOpHandle(ir::Node *node, const Scope *scope, + const platform::Place &place, + const std::vector &var_names, + GarbageCollector *gc, + AtomicReferenceCountMap *ref_cnts); + + ~EagerDeletionOpHandle(); + + std::string Name() const override; + + protected: + void RunImpl() override; + + private: + void ClearTensors(const std::vector &tensors); + + void AddVar(const std::string &name); + + const Scope *scope_; + std::unordered_set var_names_; + GarbageCollector *gc_; // not own + AtomicReferenceCountMap *ref_cnts_; // not own +#ifdef PADDLE_WITH_CUDA + platform::CUDADeviceContext *dev_ctx_{nullptr}; + cudaEvent_t event_{nullptr}; +#endif + + friend class EagerDeletionPass; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/eager_deletion_pass.cc b/paddle/fluid/framework/details/eager_deletion_pass.cc new file mode 100644 index 00000000000..f877c2881cd --- /dev/null +++ b/paddle/fluid/framework/details/eager_deletion_pass.cc @@ -0,0 +1,96 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" +#include "paddle/fluid/framework/details/eager_deletion_pass.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/ir/graph_helper.h" + +namespace paddle { +namespace framework { +namespace details { + +static void AddDependencyBetween(OpHandleBase *in, OpHandleBase *out, + ir::Graph *graph) { + auto it = std::find_if( + in->Outputs().begin(), in->Outputs().end(), [](VarHandleBase *var) { + return dynamic_cast(var) != nullptr; + }); + + if (it != in->Outputs().end()) { + out->AddInput(*it); + } else { + auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); + graph->Get(kGraphDepVars).emplace(dep_var); + in->AddOutput(dep_var); + out->AddInput(dep_var); + } + + // Add leaf node to eager_deletion_node + if (out->Outputs().empty()) { + auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); + graph->Get(kGraphDepVars).emplace(dummy_leaf); + out->AddOutput(dummy_leaf); + } +} + +std::unique_ptr EagerDeletionPass::ApplyImpl( + std::unique_ptr graph) const { + auto &vars = graph->Get(kGraphVars); + + auto &ref_cnts = + Get>(kCurReferenceCount); + auto &last_live_ops = Get>(kLastLiveOpsOfVars); + auto &gcs = Get(kGarbageCollector); + + ref_cnts = std::vector(vars.size()); + + std::unordered_map op_map; + for (auto &var_ops_map : last_live_ops) { + for (auto &var_ops_pair : var_ops_map) { + const std::string &var_name = var_ops_pair.first; + for (ComputationOpHandle *op : var_ops_pair.second) { + auto it = op_map.find(op); + if (it != op_map.end()) { + it->second->AddVar(var_name); + } else { + auto *eager_deletion_node = graph->CreateEmptyNode( + "eager_deletion", ir::Node::Type::kOperation); + auto *eager_deletion_op = new EagerDeletionOpHandle( + eager_deletion_node, op->GetScope(), op->GetPlace(), {var_name}, + gcs[op->GetScopeIdx()].get(), &(ref_cnts[op->GetScopeIdx()])); + AddDependencyBetween(op, eager_deletion_op, graph.get()); + op_map[op] = eager_deletion_op; + } + } + } + } + VLOG(10) << "Create " << op_map.size() << " EagerDeletionOpHandle(s)"; + return graph; +} + +} // namespace details +} // namespace framework +} // namespace paddle + +REGISTER_PASS(eager_deletion_pass, + paddle::framework::details::EagerDeletionPass) + .RequirePassAttr(paddle::framework::details::kCurReferenceCount) + .RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars) + .RequirePassAttr(paddle::framework::details::kGarbageCollector); diff --git a/paddle/fluid/framework/details/eager_deletion_pass.h b/paddle/fluid/framework/details/eager_deletion_pass.h new file mode 100644 index 00000000000..d7a7a9709d9 --- /dev/null +++ b/paddle/fluid/framework/details/eager_deletion_pass.h @@ -0,0 +1,32 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace details { + +class EagerDeletionPass : public ir::Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index a36ad259265..97830386e42 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -562,7 +562,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, int dev_id) const { result->Get(kGraphOps).emplace_back( new ComputationOpHandle(result->CreateOpNode(node->Op()), - local_scopes_[dev_id], places_[dev_id])); + local_scopes_[dev_id], places_[dev_id], dev_id)); CreateOpHandleIOs(result, node, dev_id); } @@ -685,8 +685,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result, for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; - result->Get(kGraphOps).emplace_back( - new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p)); + result->Get(kGraphOps).emplace_back(new ComputationOpHandle( + result->CreateOpNode(node->Op()), s, p, scope_idx)); CreateOpHandleIOs(result, node, scope_idx); } } diff --git a/paddle/fluid/framework/details/reference_count_op_handle.h b/paddle/fluid/framework/details/reference_count_op_handle.h deleted file mode 100644 index cc4ccfbdfc7..00000000000 --- a/paddle/fluid/framework/details/reference_count_op_handle.h +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include - -#include "paddle/fluid/framework/details/op_handle_base.h" -#include "paddle/fluid/framework/garbage_collector.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/framework/tensor.h" - -namespace paddle { -namespace framework { -namespace details { - -using ReferenceCountMap = std::unordered_map; -using AtomicReferenceCountMap = - std::unordered_map>; -using DeviceReferenceCountMap = - std::unordered_map>; -using AtomicDeviceReferenceCountMap = - std::unordered_map>; -using DeviceGarbageCollectorMap = - std::unordered_map>>; - -class ReferenceCountOpHandle : public OpHandleBase { - public: - ReferenceCountOpHandle(ir::Node *node, const Scope *scope, - const platform::CUDAPlace &place, - const std::vector &var_names, - GarbageCollector *gc, - AtomicReferenceCountMap *ref_cnts) - : OpHandleBase(node), scope_(scope), gc_(gc), ref_cnts_(ref_cnts) { - dev_ctx_ = static_cast( - platform::DeviceContextPool::Instance().Get(place)); - if (IsStreamGarabageCollector()) { - platform::SetDeviceId(place.device); - PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); - } - - for (auto &name : var_names) AddVar(name); - } - - ~ReferenceCountOpHandle() { - if (IsStreamGarabageCollector()) { - auto gpu_place = boost::get(dev_ctx_->GetPlace()); - platform::SetDeviceId(gpu_place.device); - PADDLE_ENFORCE(cudaEventDestroy(event_)); - } - } - - std::string Name() const override { return "reference_count"; } - - void AddVar(const std::string &name) { - auto it = var_names_.find(name); - if (it != var_names_.end()) - ++(it->second); - else - var_names_[name] = 1; - } - - protected: - void RunImpl() override { - auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get(); - std::vector tensors; - for (auto &pair : var_names_) { - auto &name = pair.first; - auto it = ref_cnts_->find(name); - if (it == ref_cnts_->end()) continue; - - auto *var = exec_scope->FindVar(name); - if (var == nullptr) continue; - - if (var->IsType()) { - if (it->second.fetch_sub(pair.second) <= pair.second) { - tensors.emplace_back(var->GetMutable()); - } - } else if (var->IsType()) { - if (it->second.fetch_sub(pair.second) <= pair.second) { - tensors.emplace_back( - var->GetMutable()->mutable_value()); - } - } - } - - if (!tensors.empty()) { - ClearTensors(tensors); - } - } - - private: - void ClearTensors(const std::vector &tensors) { - auto *gc = dynamic_cast *>(gc_); - if (gc != nullptr) { - auto compute_stream = dev_ctx_->stream(); - auto callback_stream = gc->stream(); - auto callback_func = [=]() { - PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream)); - PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0)); - }; - gc_->Add(tensors, callback_func); - } else { - gc_->Add(tensors); - } - } - - bool IsStreamGarabageCollector() const { - return dynamic_cast *>(gc_) != nullptr; - } - - const Scope *scope_; - platform::CUDADeviceContext *dev_ctx_; - std::unordered_map var_names_; - GarbageCollector *gc_; // not own - AtomicReferenceCountMap *ref_cnts_; // not own - cudaEvent_t event_; -}; - -} // namespace details -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc index 08783fb5f8b..f094c7afa9f 100644 --- a/paddle/fluid/framework/details/reference_count_pass.cc +++ b/paddle/fluid/framework/details/reference_count_pass.cc @@ -17,184 +17,96 @@ #include #include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/reference_count_pass.h" +#include "paddle/fluid/framework/details/reference_count_pass_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h" namespace paddle { namespace framework { namespace details { -static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) { - std::queue queue; - queue.push(var_in); +static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself( + OpHandleBase *op, size_t scope_idx) { + std::queue q; + std::unordered_set visited; + q.push(op); do { - auto *var = queue.front(); - queue.pop(); - for (auto *op : var->PendingOps()) { - auto *compute_op = dynamic_cast(op); - if (compute_op != nullptr && compute_op->GetPlace() == var_in->place_) { - return compute_op; - } - for (auto *out_var : op->Outputs()) { - queue.push(out_var); + auto *op = q.front(); + q.pop(); + auto *compute_op = dynamic_cast(op); + if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) { + return compute_op; + } + for (auto *out_var : op->Outputs()) { + for (auto *pending_op : out_var->PendingOps()) { + if (visited.count(pending_op)) continue; + visited.insert(pending_op); } } - } while (!queue.empty()); + } while (!q.empty()); return nullptr; } -static void AddDependencyBetween(OpHandleBase *in, OpHandleBase *out, - ir::Graph *graph) { - auto it = std::find_if( - in->Outputs().begin(), in->Outputs().end(), [](VarHandleBase *var) { - return dynamic_cast(var) != nullptr; - }); - - if (it != in->Outputs().end()) { - out->AddInput(*it); - } else { - auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); - graph->Get(kGraphDepVars).emplace(dep_var); - in->AddOutput(dep_var); - out->AddInput(dep_var); - } -} - std::unique_ptr ReferenceCountPass::ApplyImpl( std::unique_ptr graph) const { - auto &ref_cnts = Get(kGlobalReferenceCount); - auto &cur_ref_cnts = Get(kCurReferenceCount); - auto &gcs = Get(kGarbageCollector); - - // It is not easy to find the right reference counts of varaibles in graph - // Step 1: Find all variables in computation ops - // Step 2: Find all variables in non-computation ops which refers to variables - // in computation ops - std::unordered_set names; - std::unordered_map - compute_ref_cnt_map; - - auto get_ref_cnts_from_compute_op = [&]( - OpHandleBase *op, const std::vector &vars) { - std::vector var_names_in_op; - auto *compute_op = dynamic_cast(op); - if (compute_op == nullptr || - !platform::is_gpu_place(compute_op->GetPlace())) - return var_names_in_op; - auto place = boost::get(compute_op->GetPlace()); - for (VarHandleBase *var_handle_base : vars) { - auto *var_handle = dynamic_cast(var_handle_base); - if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue; - - if (!platform::is_gpu_place(var_handle->place_) || - boost::get(var_handle->place_) != place) + auto &vars = graph->Get(kGraphVars); + auto &ref_cnts = Get>(kGlobalReferenceCount); + auto &last_live_ops_of_vars = + Get>(kLastLiveOpsOfVars); + + last_live_ops_of_vars = std::vector(vars.size()); + ref_cnts = std::vector(vars.size()); + + for (size_t i = 0; i < vars.size(); ++i) { + for (auto &name_var_pair : vars[i]) { + if (name_var_pair.second.empty()) continue; + auto *last_ver_var = name_var_pair.second.back(); + + VarDesc *var_desc = nullptr; + std::find_if(name_var_pair.second.rbegin(), name_var_pair.second.rend(), + [&](VarHandle *var_handle) -> bool { + var_desc = var_handle->Node()->Var(); + return var_desc != nullptr; + }); + + if (var_desc == nullptr || var_desc->Persistable()) { continue; - - VarDesc *var_desc = var_handle->Node()->Var(); - auto var_name = var_handle->Node()->Name(); - - // This is weird but there is really some variables without var_desc - // in computation_op - if (var_desc == nullptr) { - var_desc = compute_op->Node()->Op()->Block()->FindVar(var_name); - if (var_desc == nullptr) continue; } - if (var_desc->Persistable()) continue; auto var_type = var_desc->Proto()->type().type(); if (var_type != proto::VarType::LOD_TENSOR && - var_type != proto::VarType::SELECTED_ROWS) { + var_type != proto::VarType::SELECTED_ROWS && + var_type != proto::VarType::LOD_TENSOR_ARRAY) { continue; } - // compute op only runs in one device - if (ref_cnts[place.device]->count(var_name)) - ++(*ref_cnts[place.device])[var_name]; - else - (*ref_cnts[place.device])[var_name] = 1; - - names.insert(var_name); - var_names_in_op.push_back(var_name); - } - return var_names_in_op; - }; - - auto update_ref_cnts_from_non_compute_op = [&]( - OpHandleBase *op, const std::vector &vars) { - if (dynamic_cast(op) != nullptr) return; - for (VarHandleBase *var_handle_base : vars) { - auto *var_handle = dynamic_cast(var_handle_base); - if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue; - - auto var_name = var_handle->Node()->Name(); - auto var_place = var_handle->place_; - if (!platform::is_gpu_place(var_place)) continue; - auto place = boost::get(var_place); - if (names.count(var_name) == 0) continue; - if (ref_cnts.count(place.device) && - ref_cnts[place.device]->count(var_name)) { - ++(*ref_cnts[place.device])[var_name]; - - auto *next_compute_op = FindNextComputationOpHandle(var_handle); - if (next_compute_op != nullptr) { - if (compute_ref_cnt_map.count(next_compute_op)) { - compute_ref_cnt_map[next_compute_op]->AddVar(var_name); - VLOG(5) << "Add reference count of " << var_name << " to Operator " - << next_compute_op->Name(); - } else { - // Create new reference_count_op_handle - ir::Node *ref_cnt_node = graph->CreateEmptyNode( - "reference_count", ir::Node::Type::kOperation); - auto *ref_cnt_handle = new ReferenceCountOpHandle( - ref_cnt_node, next_compute_op->GetScope(), place, {var_name}, - gcs[place.device].get(), cur_ref_cnts[place.device].get()); - AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get()); - compute_ref_cnt_map[next_compute_op] = ref_cnt_handle; - } + std::unordered_set last_live_op; + auto add_last_live_op = [&](OpHandleBase *op) { + auto *compute_op = FindNextComputationOpHandleOrReturnItself(op, i); + if (compute_op) { + last_live_op.insert(compute_op); + } + }; + const std::string &var_name = name_var_pair.first; + auto &pending_ops = last_ver_var->PendingOps(); + if (pending_ops.empty()) { + auto *generated_op = last_ver_var->GeneratedOp(); + if (generated_op) { + ref_cnts[i].emplace(var_name, 1); + add_last_live_op(generated_op); + } + } else { + ref_cnts[i].emplace(var_name, pending_ops.size()); + for (auto *pending_op : pending_ops) { + add_last_live_op(pending_op); } } - } - }; - auto all_ops = ir::FilterByNodeWrapper(*graph); - for (auto &op : all_ops) { - auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs()); - auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs()); - if (in_var_names.empty() && out_var_names.empty()) continue; - in_var_names.insert(in_var_names.end(), out_var_names.begin(), - out_var_names.end()); - auto *compute_op = dynamic_cast(op); - auto place = boost::get(compute_op->GetPlace()); - ir::Node *ref_cnt_node = - graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation); - auto *ref_cnt_handle = new ReferenceCountOpHandle( - ref_cnt_node, compute_op->GetScope(), place, in_var_names, - gcs[place.device].get(), cur_ref_cnts[place.device].get()); - AddDependencyBetween(compute_op, ref_cnt_handle, graph.get()); - compute_ref_cnt_map[compute_op] = ref_cnt_handle; - } - - for (auto &op : all_ops) { - update_ref_cnts_from_non_compute_op(op, op->Inputs()); - update_ref_cnts_from_non_compute_op(op, op->Outputs()); - } - - std::vector new_all_ops; - new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size()); - for (auto &op : all_ops) { - new_all_ops.emplace_back(std::move(op)); - auto it = compute_ref_cnt_map.find(new_all_ops.back()); - if (it != compute_ref_cnt_map.end()) { - // Add LeafNode to ReferenceCountOpHandle - auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); - graph->Get(kGraphDepVars).emplace(dummy_leaf); - it->second->AddOutput(dummy_leaf); - new_all_ops.emplace_back(std::move(it->second)); + last_live_ops_of_vars[i].emplace(var_name, std::move(last_live_op)); } } - - all_ops.swap(new_all_ops); return graph; } @@ -205,5 +117,4 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( REGISTER_PASS(reference_count_pass, paddle::framework::details::ReferenceCountPass) .RequirePassAttr(paddle::framework::details::kGlobalReferenceCount) - .RequirePassAttr(paddle::framework::details::kCurReferenceCount) - .RequirePassAttr(paddle::framework::details::kGarbageCollector); + .RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars); diff --git a/paddle/fluid/framework/details/reference_count_pass.h b/paddle/fluid/framework/details/reference_count_pass.h index 7081280b060..bcbef027354 100644 --- a/paddle/fluid/framework/details/reference_count_pass.h +++ b/paddle/fluid/framework/details/reference_count_pass.h @@ -14,7 +14,6 @@ #pragma once -#include "paddle/fluid/framework/details/reference_count_op_handle.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/pass.h" @@ -22,10 +21,6 @@ namespace paddle { namespace framework { namespace details { -constexpr char kGlobalReferenceCount[] = "reference_count"; -constexpr char kCurReferenceCount[] = "current_reference_count"; -constexpr char kGarbageCollector[] = "garbage_collector"; - class ReferenceCountPass : public ir::Pass { protected: std::unique_ptr ApplyImpl( diff --git a/paddle/fluid/framework/details/reference_count_pass_helper.h b/paddle/fluid/framework/details/reference_count_pass_helper.h new file mode 100644 index 00000000000..77846f7bdfc --- /dev/null +++ b/paddle/fluid/framework/details/reference_count_pass_helper.h @@ -0,0 +1,49 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "paddle/fluid/framework/garbage_collector.h" +#include "paddle/fluid/framework/tensor.h" + +namespace paddle { +namespace framework { +namespace details { + +class ComputationOpHandle; + +using ReferenceCountMap = std::unordered_map; + +using AtomicReferenceCountMap = + std::unordered_map>; + +using GarbageCollectorList = + std::vector>>; + +const char kGlobalReferenceCount[] = "reference_count"; +const char kCurReferenceCount[] = "current_reference_count"; +const char kGarbageCollector[] = "garbage_collector"; + +using LastLiveOpsOfVars = + std::unordered_map>; +const char kLastLiveOpsOfVars[] = "last_live_ops_of_var"; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc index e5b1eaa7318..f1bf6542a30 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc @@ -18,9 +18,6 @@ #include #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/platform/profiler.h" -#ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/framework/details/reference_count_op_handle.h" -#endif namespace paddle { namespace framework { @@ -33,7 +30,11 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor( underlying_executor_(std::move(underlying_executor)), local_scopes_(std::move(local_scopes)), var_infos_(std::move(var_infos)), - places_(std::move(places)) {} + places_(std::move(places)) { + if (Graph().Has(details::kGarbageCollector)) { + gc_ = &(Graph().Get(details::kGarbageCollector)); + } +} FeedFetchList ScopeBufferedSSAGraphExecutor::Run( const std::vector &fetch_tensors) { @@ -69,27 +70,16 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun", nullptr); drop_scope_counter_ += 1; -#ifdef PADDLE_WITH_CUDA - const std::string gc_name = "garbage_collector"; - DeviceGarbageCollectorMap *gc = - Graph().Has(gc_name) ? &(Graph().Get(gc_name)) - : nullptr; -#endif - if (!fetch_tensors.empty() || drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) { drop_scope_counter_ = 0; // Wait All computational streams - for (auto p : places_) { - platform::DeviceContextPool::Instance().Get(p)->Wait(); -#ifdef PADDLE_WITH_CUDA - if (gc != nullptr && platform::is_gpu_place(p)) { - auto gpu_place = boost::get(p); - auto &gc_at_place = gc->at(gpu_place.device); - gc_at_place->Wait(); - gc_at_place->Reset(); + for (size_t i = 0; i < places_.size(); ++i) { + platform::DeviceContextPool::Instance().Get(places_[i])->Wait(); + if (gc_) { + (*gc_)[i]->Wait(); + (*gc_)[i]->Reset(); } -#endif } for (auto &scope : local_scopes_) { auto &local_scope = diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h index 5e87e0bf50b..ce3061d6e61 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h @@ -21,9 +21,11 @@ #include "paddle/fluid/framework/details/var_handle.h" #include "paddle/fluid/framework/details/execution_strategy.h" +#include "paddle/fluid/framework/details/reference_count_pass_helper.h" #include "paddle/fluid/framework/details/ssa_graph_executor.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/platform/place.h" + namespace paddle { namespace framework { namespace details { @@ -55,6 +57,8 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { std::vector local_scopes_; std::vector var_infos_; std::vector places_; + + GarbageCollectorList* gc_{nullptr}; }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/garbage_collector.h b/paddle/fluid/framework/garbage_collector.h index 818b3334ea4..cbe8f606efe 100644 --- a/paddle/fluid/framework/garbage_collector.h +++ b/paddle/fluid/framework/garbage_collector.h @@ -65,7 +65,7 @@ class GarbageCollector { if (clear_deque != nullptr) { callback(); - ClearCallback([=]() { + ClearCallback([clear_deque]() { for (auto *obj : *clear_deque) obj->clear(); }); } @@ -109,7 +109,6 @@ class DefaultStreamGarbageCollector : public GarbageCollector { } void Wait() const override { - this->dev_ctx_->Wait(); static_cast(this->dev_ctx_) ->WaitStreamCallback(); } @@ -127,14 +126,14 @@ class StreamGarbageCollector : public GarbageCollector { StreamGarbageCollector(const platform::CUDAPlace &place, size_t max_memory_size) : GarbageCollector(place, max_memory_size) { - PADDLE_ENFORCE(cudaSetDevice(place.device)); + platform::SetDeviceId(place.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); callback_manager_.reset(new platform::StreamCallbackManager(stream_)); } ~StreamGarbageCollector() { auto place = boost::get(this->dev_ctx_->GetPlace()); - PADDLE_ENFORCE(cudaSetDevice(place.device)); + platform::SetDeviceId(place.device); PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); PADDLE_ENFORCE(cudaStreamDestroy(stream_)); } @@ -148,8 +147,11 @@ class StreamGarbageCollector : public GarbageCollector { cudaStream_t stream() const { return stream_; } protected: + // ClearCallback and Wait()/Reset() cannot be call in multiple threads + // But it is not important, because they would not be called in multiple + // threads + // either in Executor or ParallelExecutor void ClearCallback(const std::function &callback) override { - std::lock_guard guard(this->mutex_); callback_manager_->AddCallback(callback); } diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 947c934f0ff..7a2560c14df 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -73,14 +73,21 @@ class Graph { } bool Has(const std::string &attr_name) const { - return attrs_.find(attr_name) != attrs_.end(); + return attrs_.count(attr_name) > 0; } template AttrType &Get(const std::string &attr_name) const { PADDLE_ENFORCE(Has(attr_name), "%s attr not registered for graph.", attr_name); - return *boost::any_cast(attrs_.at(attr_name)); + try { + return *boost::any_cast(attrs_.at(attr_name)); + } catch (boost::bad_any_cast &) { + PADDLE_THROW( + "Invalid attribute type of %s error, expected: %s, actual: %s", + attr_name, typeid(AttrType *).name(), + attrs_.at(attr_name).type().name()); + } } template diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index a3559247db6..27746ff1453 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -51,11 +51,18 @@ class Pass { AttrType &Get(const std::string &attr_name) const { PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(), "%s attr not registered for pass.", attr_name); - return *boost::any_cast(attrs_.at(attr_name)); + try { + return *boost::any_cast(attrs_.at(attr_name)); + } catch (boost::bad_any_cast &) { + PADDLE_THROW( + "Invalid attribute type of %s error, expected: %s, actual: %s", + attr_name, typeid(AttrType *).name(), + attrs_.at(attr_name).type().name()); + } } bool Has(const std::string &attr_name) const { - return attrs_.find(attr_name) != attrs_.end(); + return attrs_.count(attr_name) > 0; } void Erase(const std::string &attr_name) { diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index b98408ee772..e71f93beefc 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -26,6 +26,7 @@ limitations under the License. */ #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/details/reference_count_pass_helper.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/platform/profiler.h" @@ -49,6 +50,15 @@ class ParallelExecutorPrivate { } } } + + void ResetRuntimeReferenceCount() { + for (size_t i = 0; i < rt_ref_cnts_.size(); ++i) { + for (auto &pair : rt_ref_cnts_[i]) { + rt_cur_ref_cnts_[i][pair.first] = pair.second; + } + } + } + std::vector places_; std::vector local_scopes_; Scope *global_scope_; // not owned @@ -60,6 +70,13 @@ class ParallelExecutorPrivate { bool own_local_scope_; bool use_cuda_; bool use_all_reduce_; + + // rt_ref_cnts_ is only initialized when ParallelExecutor constructs, and then + // keeps unchanged + // Before each iteration, rt_cur_ref_cnts_ is reset to ref_cnts_ + std::vector rt_ref_cnts_; + std::vector rt_cur_ref_cnts_; + details::GarbageCollectorList gcs_; }; std::vector &ParallelExecutor::GetLocalScopes() { @@ -128,35 +145,56 @@ ParallelExecutor::ParallelExecutor( std::unique_ptr graph = build_strategy.Apply( main_program, member_->places_, loss_var_name, params, member_->local_scopes_, member_->use_cuda_, member_->nccl_ctxs_.get()); +#else + std::unique_ptr graph = + build_strategy.Apply(main_program, member_->places_, loss_var_name, + params, member_->local_scopes_, member_->use_cuda_); +#endif auto max_memory_size = GetEagerDeletionThreshold(); if (max_memory_size >= 0) { - for (auto &place : member_->places_) { - if (!platform::is_gpu_place(place)) continue; - auto gpu_place = boost::get(place); - if (gcs_[gpu_place.device] == nullptr) { - ref_cnts_[gpu_place.device].reset(new details::ReferenceCountMap()); - cur_ref_cnts_[gpu_place.device].reset( - new details::AtomicReferenceCountMap()); - gcs_[gpu_place.device].reset( - new StreamGarbageCollector(gpu_place, max_memory_size)); + size_t place_num = member_->places_.size(); + for (size_t i = 0; i < place_num; ++i) { + auto &place = member_->places_[i]; +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(place)) { + member_->gcs_.emplace_back(new StreamGarbageCollector( + boost::get(place), max_memory_size)); + VLOG(10) << "Created " << i << "-th GarbageCollector at " << place; + } else if (platform::is_cpu_place(place)) { +#endif + member_->gcs_.emplace_back(new CPUGarbageCollector( + boost::get(place), max_memory_size)); + VLOG(10) << "Created " << i << "-th GarbageCollector at " << place; +#ifdef PADDLE_WITH_CUDA } - } - if (!gcs_.empty()) { - auto ref_cnt_pass = - ir::PassRegistry::Instance().Get("reference_count_pass"); - ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, &ref_cnts_); - ref_cnt_pass->SetNotOwned(details::kCurReferenceCount, &cur_ref_cnts_); - ref_cnt_pass->SetNotOwned(details::kGarbageCollector, &gcs_); - graph = ref_cnt_pass->Apply(std::move(graph)); - graph->SetNotOwned("garbage_collector", &gcs_); +#endif } } -#else - std::unique_ptr graph = - build_strategy.Apply(main_program, member_->places_, loss_var_name, - params, member_->local_scopes_, member_->use_cuda_); -#endif + + if (!member_->gcs_.empty()) { + std::vector last_live_ops_of_vars; + + auto ref_cnt_pass = + ir::PassRegistry::Instance().Get("reference_count_pass"); + ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, + &(member_->rt_ref_cnts_)); + ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars, + &last_live_ops_of_vars); + VLOG(10) << "ReferenceCountPass Applied"; + graph = ref_cnt_pass->Apply(std::move(graph)); + + auto eager_deletion_pass = + ir::PassRegistry::Instance().Get("eager_deletion_pass"); + eager_deletion_pass->SetNotOwned(details::kCurReferenceCount, + &(member_->rt_cur_ref_cnts_)); + eager_deletion_pass->SetNotOwned(details::kGarbageCollector, + &(member_->gcs_)); + eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars, + &last_live_ops_of_vars); + graph = eager_deletion_pass->Apply(std::move(graph)); + VLOG(10) << "EagerDeletionPass Applied"; + } // Step 3. Create vars in each scope. Passes may also create new vars. // skip control vars and empty vars @@ -271,18 +309,16 @@ void ParallelExecutor::BCastParamsToDevices( void ParallelExecutor::Run(const std::vector &fetch_tensors, const std::string &fetched_var_name) { platform::RecordBlock b(0); -#ifdef PADDLE_WITH_CUDA - if (!gcs_.empty()) { - ResetReferenceCount(); - for (auto &pair : cur_ref_cnts_) { - auto &name_map = *(pair.second); + if (!member_->gcs_.empty()) { + member_->ResetRuntimeReferenceCount(); + size_t n = member_->rt_ref_cnts_.size(); + for (size_t i = 0; i < n; ++i) { for (auto &fetch_name : fetch_tensors) { - name_map.erase(fetch_name); + member_->rt_cur_ref_cnts_[i].erase(fetch_name); } - name_map.erase(fetched_var_name); + member_->rt_cur_ref_cnts_[i].erase(fetched_var_name); } } -#endif auto fetch_data = member_->executor_->Run(fetch_tensors); *member_->global_scope_->Var(fetched_var_name)->GetMutable() = fetch_data; @@ -326,13 +362,11 @@ ParallelExecutor::~ParallelExecutor() { for (auto &p : member_->places_) { platform::DeviceContextPool::Instance().Get(p)->Wait(); } - // member_ must be destructed before gcs_ since the destructor of - // ReferenceCountOpHandle use raw pointers of gcs_ inside. - member_.reset(); + delete member_; } } // namespace framework } // namespace paddle -#ifdef PADDLE_WITH_CUDA + USE_PASS(reference_count_pass); -#endif +USE_PASS(eager_deletion_pass); diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index ef09b98b2aa..1fc17a0d64d 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -14,7 +14,6 @@ limitations under the License. */ #pragma once -#include #include #include #include @@ -29,10 +28,6 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device_context.h" -#ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/framework/details/reference_count_pass.h" -#endif - namespace paddle { namespace framework { @@ -75,24 +70,7 @@ class ParallelExecutor { private: void BCastParamsToDevices(const std::unordered_set &vars) const; - std::unique_ptr member_; - -#ifdef PADDLE_WITH_CUDA - // ref_cnts_ is only initialized when ParallelExecutor constructs, and then - // keeps unchanged - // Before each iteration, cur_ref_cnts_ is reset to ref_cnts_ - details::DeviceReferenceCountMap ref_cnts_; - details::AtomicDeviceReferenceCountMap cur_ref_cnts_; - details::DeviceGarbageCollectorMap gcs_; - - void ResetReferenceCount() { - for (auto &pair1 : ref_cnts_) { - for (auto &pair2 : *(pair1.second)) { - (*(cur_ref_cnts_[pair1.first]))[pair2.first] = pair2.second; - } - } - } -#endif + ParallelExecutorPrivate *member_; }; } // namespace framework diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 93cb5eb2dc0..23c7ebe8422 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -56,9 +56,16 @@ ELSE() set(MKLDNN_CTX_DEPS) ENDIF() +nv_library(stream_callback_manager SRCS stream_callback_manager.cc DEPS simple_threadpool enforce) +IF(WITH_GPU) + set(STREAM_CALLBACK_DEPS stream_callback_manager) +ELSE() + set(STREAM_CALLBACK_DEPS) +ENDIF() + # memcpy depends on device_context, here add deps individually for # avoiding cycle dependencies -cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc +cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc ${STREAM_CALLBACK_DEPS} place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}) nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info) diff --git a/paddle/fluid/platform/stream_callback_manager.cc b/paddle/fluid/platform/stream_callback_manager.cc new file mode 100644 index 00000000000..ae915365f8c --- /dev/null +++ b/paddle/fluid/platform/stream_callback_manager.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/stream_callback_manager.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { + +struct StreamCallbackContext { + inline StreamCallbackContext(const StreamCallbackManager *manager, + std::function callback) + : manager_(manager), callback_(std::move(callback)) {} + + const StreamCallbackManager *manager_; // do not own + std::function callback_; +}; + +StreamCallbackManager::StreamCallbackManager(const cudaStream_t stream) + : stream_(stream), thread_pool_(new ::ThreadPool(1)) {} + +void StreamCallbackManager::AddCallback(std::function callback) const { + auto *stream_callback_context = + new StreamCallbackContext(this, std::move(callback)); +#if CUDA_VERSION >= 10000 + PADDLE_ENFORCE(cudaLaunchHostFunc(stream_, + StreamCallbackManager::StreamCallbackFunc, + stream_callback_context)); +#else + PADDLE_ENFORCE( + cudaStreamAddCallback(stream_, StreamCallbackManager::StreamCallbackFunc, + stream_callback_context, 0)); +#endif +} + +void StreamCallbackManager::Wait() const { + thread_pool_.reset(new ::ThreadPool(1)); +} + +#if CUDA_VERSION >= 10000 +void CUDART_CB StreamCallbackManager::StreamCallbackFunc(void *user_data) +#else +void CUDART_CB StreamCallbackManager::StreamCallbackFunc(cudaStream_t stream, + cudaError_t status, + void *user_data) +#endif +{ + auto *callback_context_ptr = + reinterpret_cast(user_data); + callback_context_ptr->manager_->thread_pool_->enqueue( + [callback_context_ptr]() { + std::unique_ptr callback_context( + callback_context_ptr); + callback_context->callback_(); + }); +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/stream_callback_manager.h b/paddle/fluid/platform/stream_callback_manager.h index ed8734c98cb..eac4806d137 100644 --- a/paddle/fluid/platform/stream_callback_manager.h +++ b/paddle/fluid/platform/stream_callback_manager.h @@ -19,66 +19,29 @@ #include #include #include -#include "paddle/fluid/platform/enforce.h" namespace paddle { namespace platform { -class StreamCallbackManager; - -struct StreamCallbackContext { - template - inline StreamCallbackContext(const StreamCallbackManager *manager, - Callback &&callback) - : manager_(manager), callback_(callback) {} - - const StreamCallbackManager *manager_; // do not own - std::function callback_; -}; - +// NOTE(zjl): clean StreamCallback to make compilation faster class StreamCallbackManager { public: - explicit inline StreamCallbackManager(cudaStream_t stream = nullptr) - : stream_(stream), thread_pool_(new ThreadPool(1)) {} + explicit StreamCallbackManager(const cudaStream_t stream); - template - inline void AddCallback(Callback &&callback) const { - auto *stream_callback_context = - new StreamCallbackContext(this, std::forward(callback)); -#if CUDA_VERSION >= 10000 - PADDLE_ENFORCE(cudaLaunchHostFunc(stream_, - StreamCallbackManager::StreamCallbackFunc, - stream_callback_context)); // NOLINT -#else - PADDLE_ENFORCE(cudaStreamAddCallback( - stream_, StreamCallbackManager::StreamCallbackFunc, - stream_callback_context, 0)); // NOLINT -#endif - } + void AddCallback(std::function callback) const; - void Wait() const { thread_pool_.reset(new ThreadPool(1)); } + void Wait() const; private: const cudaStream_t stream_; - mutable std::unique_ptr thread_pool_; + mutable std::unique_ptr<::ThreadPool> thread_pool_; -// cudaStreamCallback cannot call CUDA API inside, so we have to use -// thread_pool here #if CUDA_VERSION >= 10000 - static void CUDART_CB StreamCallbackFunc(void *user_data) + static void CUDART_CB StreamCallbackFunc(void *user_data); #else static void CUDART_CB StreamCallbackFunc(cudaStream_t stream, - cudaError_t status, void *user_data) + cudaError_t status, void *user_data); #endif - { - auto *callback_context_ptr = - reinterpret_cast(user_data); - callback_context_ptr->manager_->thread_pool_->enqueue([=]() { - std::unique_ptr callback_context( - callback_context_ptr); - callback_context->callback_(); - }); - } }; } // namespace platform -- GitLab