diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index e0a3ef5a9c6c53c42ebea1a41cac0d18a77781b2..a9dddede784ded91143dbc78bb4f3277b811613b 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -1,5 +1,6 @@ cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node) cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor) +cc_library(op_handle_graph SRCS op_handle_graph.cc DEPS op_handle_base) cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) @@ -28,6 +29,8 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) +cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_handle_graph multi_devices_helper) + if(WITH_GPU) cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass) @@ -37,9 +40,9 @@ cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_ scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) if(WITH_GPU) - cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass) + cc_library(ssa_graph_executor SRCS 
ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass modify_op_lock_and_record_event_pass) else() - cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto) + cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto modify_op_lock_and_record_event_pass) endif() cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index b6282debdb4eb6b1f29c39e54ac4f3e2296838da..690d37211ec0de56c5ffbdeec551ad3c3d0c91ec 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -20,18 +20,26 @@ namespace paddle { namespace framework { namespace details { ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope, - platform::Place place) + platform::Place place, + size_t scope_idx) : OpHandleBase(node), op_(framework::OpRegistry::CreateOp(*node->Op())), scope_(scope), - place_(place) {} + place_(place), + scope_idx_(scope_idx) {} void ComputationOpHandle::RunImpl() { WaitInputVarGenerated(place_); - this->RunAndRecordEvent([this] { + auto run_func = [this]() { op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get(), place_); - }); + }; + + if (is_lock_and_record_event_free_) { + run_func(); + } else { + this->RunAndRecordEvent(run_func); + } } bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) { diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index e98f1ab148db083ac63a1afd43e334fbfae62539..fce9dc18492d1cab7eecbf1bdb154934547ae6d5 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -28,7 +28,8 @@ namespace framework { namespace details { struct ComputationOpHandle : public OpHandleBase { 
public: - ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place); + ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place, + size_t scope_idx); std::string Name() const override; @@ -36,6 +37,14 @@ struct ComputationOpHandle : public OpHandleBase { const platform::Place &GetPlace() const { return place_; } + size_t GetScopeIdx() const { return scope_idx_; } + + OperatorBase &GetOp() { return *op_; } + + const OperatorBase &GetOp() const { return *op_; } + + void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; } + protected: void RunImpl() override; @@ -45,6 +54,8 @@ struct ComputationOpHandle : public OpHandleBase { std::unique_ptr op_; Scope *scope_; platform::Place place_; + size_t scope_idx_{0}; + bool is_lock_and_record_event_free_{false}; }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc b/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..ed07d84fd643dd1aa711b227a0c4186985dd078b --- /dev/null +++ b/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc @@ -0,0 +1,62 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h" +#include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/details/op_handle_graph.h" + +namespace paddle { +namespace framework { +namespace details { + +static ComputationOpHandle *ConvertToComputationOpHandle(OpHandleBase *op) { + return dynamic_cast(op); +} + +static bool IsLockAndRecordEventFreeComputationOpHandle( + ComputationOpHandle *op, const OpHandleGraph &graph) { + for (auto &pending_op : graph.PendingOps(op)) { + auto *tmp = ConvertToComputationOpHandle(pending_op); + if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) { + return false; + } + } + return true; +} + +std::unique_ptr ModifyOpLockAndRecordEventPass::ApplyImpl( + std::unique_ptr ir_graph) const { + auto &all_ops = ir_graph->Get(kGraphOps); + OpHandleGraph graph(all_ops); + for (auto &op : all_ops) { + auto *compute_op = ConvertToComputationOpHandle(op.get()); + if (compute_op == nullptr) continue; + bool is_lock_and_record_event_free = + IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph); + compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free); + if (is_lock_and_record_event_free) { + VLOG(10) << "Set is_lock_and_record_event_free be true in op " + << compute_op->DebugString(); + } + } + return ir_graph; +} + +} // namespace details +} // namespace framework +} // namespace paddle + +REGISTER_PASS(modify_op_lock_and_record_event_pass, + paddle::framework::details::ModifyOpLockAndRecordEventPass); diff --git a/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h b/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..b54e1b318be95e1e0abf6830f8c918895df02718 --- /dev/null +++ b/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h @@ -0,0 +1,32 @@ +// 
Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace details { + +class ModifyOpLockAndRecordEventPass : public ir::Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 134fcee826715672a6e021e9bf694bb771ebb830..fb51cfdd19be6edcb6280045bd814f28f352897c 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -513,7 +513,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, int dev_id) const { result->Get(kGraphOps).emplace_back( new ComputationOpHandle(result->CreateOpNode(node->Op()), - local_scopes_[dev_id], places_[dev_id])); + local_scopes_[dev_id], places_[dev_id], dev_id)); CreateOpHandleIOs(result, node, dev_id); } @@ -630,8 +630,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result, for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; - result->Get(kGraphOps).emplace_back( - new 
ComputationOpHandle(result->CreateOpNode(node->Op()), s, p)); + result->Get(kGraphOps).emplace_back(new ComputationOpHandle( + result->CreateOpNode(node->Op()), s, p, scope_idx)); CreateOpHandleIOs(result, node, scope_idx); } } diff --git a/paddle/fluid/framework/details/op_handle_graph.cc b/paddle/fluid/framework/details/op_handle_graph.cc new file mode 100644 index 0000000000000000000000000000000000000000..0e70305cec04c43b05481f69401f9957b74c2f4b --- /dev/null +++ b/paddle/fluid/framework/details/op_handle_graph.cc @@ -0,0 +1,294 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/details/op_handle_graph.h" +#include +#include + +namespace paddle { +namespace framework { +namespace details { + +OpHandleGraph::OpHandleGraph( + const std::vector> &ops) { + BuildGraph(ops); +} + +void OpHandleGraph::BuildGraph( + const std::vector> &ops) { + for (auto &op : ops) { + preceding_ops_[op.get()]; + pending_ops_[op.get()]; + for (auto &var : op->Outputs()) { + for (auto &pending_op : var->PendingOps()) { + preceding_ops_[pending_op].insert(op.get()); + pending_ops_[op.get()].insert(pending_op); + } + } + } + PADDLE_ENFORCE( + preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(), + "There are duplicate ops in graph."); +} + +size_t OpHandleGraph::OpNumber() const { return preceding_ops_.size(); } + +std::unordered_set OpHandleGraph::AllOps() const { + std::unordered_set ret; + for (auto &pair : preceding_ops_) { + ret.insert(pair.first); + } + return ret; +} + +bool OpHandleGraph::HasOp(OpHandleBase *op) const { + return preceding_ops_.count(op) != 0; +} + +void OpHandleGraph::EnforceHasOp(OpHandleBase *op) const { + PADDLE_ENFORCE(HasOp(op), "Cannot found op %s in OpHandleGraph", + op == nullptr ? 
"nullptr" : op->DebugString()); +} + +const std::unordered_set &OpHandleGraph::PrecedingOps( + OpHandleBase *op) const { + EnforceHasOp(op); + return preceding_ops_.at(op); +} + +const std::unordered_set &OpHandleGraph::PendingOps( + OpHandleBase *op) const { + EnforceHasOp(op); + return pending_ops_.at(op); +} + +std::vector> OpHandleGraph::AllPrecedingOps( + OpHandleBase *op) const { + EnforceHasOp(op); + std::queue queue[2]; + int cur = 0; + std::unordered_set visited_ops; + std::vector> ret; + for (auto &tmp : preceding_ops_.at(op)) { + queue[cur].push(tmp); + visited_ops.insert(tmp); + } + + while (!queue[cur].empty()) { + std::unordered_set cur_level_ops; + // Drain the whole current level before switching queues. + while (!queue[cur].empty()) { + auto *tmp = queue[cur].front(); + queue[cur].pop(); + for (auto &preceding_op : preceding_ops_.at(tmp)) { + if (visited_ops.insert(preceding_op).second) { + queue[1 - cur].push(preceding_op); + cur_level_ops.insert(preceding_op); + } + } + } + if (!cur_level_ops.empty()) { + ret.emplace_back(std::move(cur_level_ops)); + } + cur = 1 - cur; + } + return ret; +} + +std::vector> OpHandleGraph::AllPendingOps( + OpHandleBase *op) const { + EnforceHasOp(op); + std::queue queue[2]; + int cur = 0; + std::unordered_set visited_ops; + std::vector> ret; + for (auto &tmp : pending_ops_.at(op)) { + queue[cur].push(tmp); + visited_ops.insert(tmp); + } + + while (!queue[cur].empty()) { + std::unordered_set cur_level_ops; + // Drain the whole current level before switching queues. + while (!queue[cur].empty()) { + auto *tmp = queue[cur].front(); + queue[cur].pop(); + for (auto &next_op : pending_ops_.at(tmp)) { + if (visited_ops.insert(next_op).second) { + queue[1 - cur].push(next_op); + cur_level_ops.insert(next_op); + } + } + } + if (!cur_level_ops.empty()) { + ret.emplace_back(std::move(cur_level_ops)); + } + cur = 1 - cur; + } + return ret; +} + +OpHandleGraph::Relation OpHandleGraph::RelationBetween( + OpHandleBase *op1, OpHandleBase *op2) const { + EnforceHasOp(op1); + EnforceHasOp(op2); + if (op1 == op2) { + return 
kSame; + } else if (IsBeforeOrSameImpl(op1, op2)) { + return kBefore; + } else if (IsBeforeOrSameImpl(op2, op1)) { + return kAfter; + } else { + return kNoDeps; + } +} + +bool OpHandleGraph::IsSame(OpHandleBase *op1, OpHandleBase *op2) const { + EnforceHasOp(op1); + EnforceHasOp(op2); + return op1 == op2; +} + +bool OpHandleGraph::IsBeforeOrSame(OpHandleBase *op1, OpHandleBase *op2) const { + EnforceHasOp(op1); + EnforceHasOp(op2); + return IsBeforeOrSameImpl(op1, op2); +} + +bool OpHandleGraph::IsBefore(OpHandleBase *op1, OpHandleBase *op2) const { + EnforceHasOp(op1); + EnforceHasOp(op2); + return op1 != op2 && IsBeforeOrSameImpl(op1, op2); +} + +bool OpHandleGraph::IsBeforeOrSameImpl(OpHandleBase *op1, + OpHandleBase *op2) const { + std::queue queue; + // BFS + queue.push(op1); + do { + auto *op = queue.front(); + queue.pop(); + if (op == op2) return true; + for (auto &pending_op : pending_ops_.at(op)) { + queue.push(pending_op); + } + } while (!queue.empty()); + return false; +} + +bool OpHandleGraph::IsAfterOrSame(OpHandleBase *op1, OpHandleBase *op2) const { + EnforceHasOp(op1); + EnforceHasOp(op2); + return IsBeforeOrSameImpl(op2, op1); +} + +bool OpHandleGraph::IsAfter(OpHandleBase *op1, OpHandleBase *op2) const { + return IsBefore(op2, op1); +} + +bool OpHandleGraph::IsNoDeps(OpHandleBase *op1, OpHandleBase *op2) const { + return RelationBetween(op1, op2) == kNoDeps; +} + +std::unordered_set OpHandleGraph::NoPendingOpSet() const { + std::unordered_set ret; + for (auto &pair : pending_ops_) { + if (pair.second.empty()) ret.insert(pair.first); + } + return ret; +} + +std::unordered_set OpHandleGraph::NoPrecedingOpSet() const { + std::unordered_set ret; + for (auto &pair : preceding_ops_) { + if (pair.second.empty()) ret.insert(pair.first); + } + return ret; +} + +OpHandleBase *OpHandleGraph::NearestCommonParent(OpHandleBase *op1, + OpHandleBase *op2) const { + EnforceHasOp(op1); + EnforceHasOp(op2); + // FIXME(zjl): A brute-force O(2*n) algorithm here + // 
First, BFS all preceding_ops of op1 and record them in set S + // Second, BFS all preceding_ops of op2 and find whether it is in set S + std::unordered_set all_preceding_ops; + std::queue queue; + queue.push(op1); + do { + auto *op = queue.front(); + queue.pop(); + all_preceding_ops.insert(op); + for (auto &preceding_op : preceding_ops_.at(op)) { + queue.push(preceding_op); + } + } while (!queue.empty()); + + queue.push(op2); + do { + auto *op = queue.front(); + queue.pop(); + if (all_preceding_ops.count(op)) return op; + for (auto &preceding_op : preceding_ops_.at(op)) { + queue.push(preceding_op); + } + } while (!queue.empty()); + return nullptr; +} + +OpHandleBase *OpHandleGraph::NearestCommonParentAfter(OpHandleBase *op, + OpHandleBase *op1, + OpHandleBase *op2) const { + EnforceHasOp(op); + EnforceHasOp(op1); + EnforceHasOp(op2); + std::unordered_map all_preceding_ops; + int max_depth = -1; + std::queue> queue; + queue.push(std::make_pair(op1, 0)); + do { + auto tmp = queue.front(); + queue.pop(); + all_preceding_ops.insert(tmp); + if (tmp.first == op) { + max_depth = tmp.second; + break; + } + for (auto &preceding_op : preceding_ops_.at(tmp.first)) { + queue.push(std::make_pair(preceding_op, tmp.second + 1)); + } + } while (!queue.empty()); + + if (max_depth == -1) return nullptr; + std::queue queue2; + queue2.push(op2); + do { + auto *tmp = queue2.front(); + queue2.pop(); + if (all_preceding_ops.count(tmp) && + (tmp == op || all_preceding_ops[tmp] < max_depth)) { + return tmp; + } + for (auto &preceding_op : preceding_ops_.at(tmp)) { + queue2.push(preceding_op); + } + } while (!queue2.empty()); + return nullptr; +} + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/op_handle_graph.h b/paddle/fluid/framework/details/op_handle_graph.h new file mode 100644 index 0000000000000000000000000000000000000000..803edce048e7aee7a70021ee75fcd66576bf4a42 --- /dev/null +++ b/paddle/fluid/framework/details/op_handle_graph.h @@ -0,0 +1,87 @@ +// Copyright (c) 2018 PaddlePaddle 
Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "paddle/fluid/framework/details/op_handle_base.h" + +namespace paddle { +namespace framework { +namespace details { + +class OpHandleGraph { + public: + enum Relation { kSame = 0, kBefore = 1, kAfter = 2, kNoDeps = 3 }; + + explicit OpHandleGraph(const std::vector> &ops); + + size_t OpNumber() const; + + std::unordered_set AllOps() const; + + const std::unordered_set &PrecedingOps( + OpHandleBase *op) const; + + const std::unordered_set &PendingOps(OpHandleBase *op) const; + + std::vector> AllPrecedingOps( + OpHandleBase *op) const; + + std::vector> AllPendingOps( + OpHandleBase *op) const; + + bool HasOp(OpHandleBase *op) const; + + Relation RelationBetween(OpHandleBase *op1, OpHandleBase *op2) const; + + bool IsSame(OpHandleBase *op1, OpHandleBase *op2) const; + + bool IsBeforeOrSame(OpHandleBase *op1, OpHandleBase *op2) const; + + bool IsBefore(OpHandleBase *op1, OpHandleBase *op2) const; + + bool IsAfterOrSame(OpHandleBase *op1, OpHandleBase *op2) const; + + bool IsAfter(OpHandleBase *op1, OpHandleBase *op2) const; + + bool IsNoDeps(OpHandleBase *op1, OpHandleBase *op2) const; + + OpHandleBase *NearestCommonParent(OpHandleBase *op1, OpHandleBase *op2) const; + + // Find an operator that is after op and before op1, op2 + OpHandleBase *NearestCommonParentAfter(OpHandleBase *op, OpHandleBase *op1, + OpHandleBase 
*op2) const; + + std::unordered_set NoPendingOpSet() const; + + std::unordered_set NoPrecedingOpSet() const; + + private: + void BuildGraph(const std::vector> &ops); + void EnforceHasOp(OpHandleBase *op) const; + bool IsBeforeOrSameImpl(OpHandleBase *op1, OpHandleBase *op2) const; + + std::unordered_map> + preceding_ops_; + std::unordered_map> + pending_ops_; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/reference_count_op_handle.h b/paddle/fluid/framework/details/reference_count_op_handle.h index fc479a4c4a1e7d5c824d3c202e0cccf743dd52c9..cc4ccfbdfc720284e683a8f3f59a4aa57a3a9eb1 100644 --- a/paddle/fluid/framework/details/reference_count_op_handle.h +++ b/paddle/fluid/framework/details/reference_count_op_handle.h @@ -51,7 +51,7 @@ class ReferenceCountOpHandle : public OpHandleBase { dev_ctx_ = static_cast( platform::DeviceContextPool::Instance().Get(place)); if (IsStreamGarabageCollector()) { - PADDLE_ENFORCE(cudaSetDevice(place.device)); + platform::SetDeviceId(place.device); PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); } @@ -61,7 +61,7 @@ class ReferenceCountOpHandle : public OpHandleBase { ~ReferenceCountOpHandle() { if (IsStreamGarabageCollector()) { auto gpu_place = boost::get(dev_ctx_->GetPlace()); - PADDLE_ENFORCE(cudaSetDevice(gpu_place.device)); + platform::SetDeviceId(gpu_place.device); PADDLE_ENFORCE(cudaEventDestroy(event_)); } } diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc index 2d1f688d64ece3322e253b0c070264b9eb73d678..0b994ced7f751f056fec076e3dea8d14d0bed991 100644 --- a/paddle/fluid/framework/details/reference_count_pass.cc +++ b/paddle/fluid/framework/details/reference_count_pass.cc @@ -43,6 +43,23 @@ static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) { return nullptr; } +static void AddDependencyBetween(OpHandleBase *in, OpHandleBase *out, 
+ ir::Graph *graph) { + auto it = std::find_if( + in->Outputs().begin(), in->Outputs().end(), [](VarHandleBase *var) { + return dynamic_cast(var) != nullptr; + }); + + if (it != in->Outputs().end()) { + out->AddInput(*it); + } else { + auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); + graph->Get(kGraphDepVars).emplace(dep_var); + in->AddOutput(dep_var); + out->AddInput(dep_var); + } +} + std::unique_ptr ReferenceCountPass::ApplyImpl( std::unique_ptr graph) const { auto &ref_cnts = Get(kGlobalReferenceCount); @@ -133,12 +150,7 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( auto *ref_cnt_handle = new ReferenceCountOpHandle( ref_cnt_node, next_compute_op->GetScope(), place, {var_name}, gcs[place.device].get(), cur_ref_cnts[place.device].get()); - if (next_compute_op->Outputs().empty()) { - auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); - next_compute_op->AddOutput(dep_var); - graph->Get(kGraphDepVars).emplace(dep_var); - } - ref_cnt_handle->AddInput(next_compute_op->Outputs().front()); + AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get()); compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle); } } @@ -160,12 +172,7 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( auto *ref_cnt_handle = new ReferenceCountOpHandle( ref_cnt_node, compute_op->GetScope(), place, in_var_names, gcs[place.device].get(), cur_ref_cnts[place.device].get()); - if (compute_op->Outputs().empty()) { - auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); - compute_op->AddOutput(dep_var); - graph->Get(kGraphDepVars).emplace(dep_var); - } - ref_cnt_handle->AddInput(compute_op->Outputs().front()); + AddDependencyBetween(compute_op, ref_cnt_handle, graph.get()); compute_ref_cnt_map[compute_op].reset(ref_cnt_handle); } diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 3368ae2ee4cf65b85abf1fcd89dee14f43522e1f..20cb752949ba48e5ba80b985f5085291231301ce 100644 --- 
a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -156,6 +156,10 @@ ParallelExecutor::ParallelExecutor( params, member_->local_scopes_, member_->use_cuda_); #endif + graph = ir::PassRegistry::Instance() + .Get("modify_op_lock_and_record_event_pass") + ->Apply(std::move(graph)); + // If the loss_var_name is given, the number of graph should be only one. if (loss_var_name.size()) { PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1, @@ -319,6 +323,8 @@ ParallelExecutor::~ParallelExecutor() { } // namespace framework } // namespace paddle + +USE_PASS(modify_op_lock_and_record_event_pass); #ifdef PADDLE_WITH_CUDA USE_PASS(reference_count_pass); #endif diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index 4a7a6bcf7154d5680de751e3c933be46fb09fd74..c37032bf090a34077f0f706307c07a0c0fd1185d 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -160,6 +160,7 @@ class CUDNNConvOpKernel : public framework::OpKernel { // ------------------- cudnn conv forward --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); for (int i = 0; i < groups; i++) { auto cudnn_func = [&](void* cudnn_workspace) { CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( @@ -168,7 +169,7 @@ class CUDNNConvOpKernel : public framework::OpKernel { cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_output_desc, output_data + i * group_offset_out)); }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); } } }; @@ -314,6 +315,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { // ------------------- cudnn conv backward data --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); if (input_grad) { 
T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); // Because beta is zero, it is unnecessary to reset input_grad. @@ -327,7 +329,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { data_algo, cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc, input_grad_data + i * group_offset_in)); }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); } } // ------------------- cudnn conv backward filter --------------------- @@ -343,7 +345,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { filter_algo, cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_filter_desc, filter_grad_data + i * group_offset_filter)); }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); } } } diff --git a/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc b/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc index 73831611d01b8c5b8d2d9f7f15634a0094e4a608..f44094ca6b7b7f23f2e7593ad79e4e2a6f0d3070 100644 --- a/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc @@ -104,6 +104,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { int output_offset = output->numel() / output->dims()[0] / groups; int filter_offset = filter->numel() / groups; T alpha = 1.0f, beta = 0.0f; + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); for (int g = 0; g < groups; g++) { auto cudnn_func = [&](void* cudnn_workspace) { CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( @@ -112,7 +113,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { algo, cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_output_desc, output_data + output_offset * g)); }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); } 
} }; @@ -208,6 +209,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { output_grad->numel() / output_grad->dims()[0] / groups; int filter_offset = filter->numel() / groups; T alpha = 1.0f, beta = 0.0f; + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); if (input_grad) { T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); // Because beta is zero, it is unnecessary to reset input_grad. @@ -220,7 +222,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc, input_grad_data + input_offset * g)); }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); } } @@ -238,7 +240,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_filter_desc, filter_grad_data + filter_offset * g)); }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); } } } diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 7d1cf57253819b34fedfb292ad1635650f53f20f..25540c71e0a6588f8ea6ba3bd754ddd67cf5f1b0 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -168,10 +168,7 @@ class CudnnHolder { void RunFunc(const std::function& cudnn_func, size_t required_workspace_len) { std::lock_guard lock(mtx_); - if (required_workspace_len > workspace_len_) { - ReallocateWorkspace(required_workspace_len); - } - cudnn_func(workspace_); + RunFuncImpl(cudnn_func, required_workspace_len); } ~CudnnHolder() { @@ -182,6 +179,16 @@ class CudnnHolder { } private: + std::mutex& Mutex() { return mtx_; } + + void RunFuncImpl(const std::function& cudnn_func, + size_t required_workspace_len) { + if (required_workspace_len > workspace_len_) { + 
ReallocateWorkspace(required_workspace_len); + } + cudnn_func(workspace_); + } + void ReallocateWorkspace(size_t required_workspace_len) { if (required_workspace_len <= workspace_len_) { return; @@ -195,6 +202,8 @@ class CudnnHolder { workspace_len_ = required_workspace_len; } + friend class CudnnWorkspaceHandle; + cudnnHandle_t cudnn_handle_; void* workspace_; size_t workspace_len_; @@ -205,6 +214,24 @@ class CudnnHolder { std::mutex mtx_; }; +CudnnWorkspaceHandle::CudnnWorkspaceHandle(CudnnHolder* holder) + : holder_(holder) {} + +void CudnnWorkspaceHandle::RunFunc(const std::function& cudnn_func, + size_t required_workspace_len) { + // defer lock when the function is invoked first time + BeginCallGuard(); + holder_->RunFuncImpl(cudnn_func, required_workspace_len); +} + +void CudnnWorkspaceHandle::BeginCallGuard() { + if (!guard_) { + guard_.reset(new std::lock_guard(holder_->Mutex())); + } +} + +void CudnnWorkspaceHandle::EndCallGuard() { guard_.reset(); } + CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place), cudnn_holder_(nullptr) { SetDeviceId(place_.device); @@ -271,6 +298,10 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_holder_->cudnn_handle(); } +CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { + return CudnnWorkspaceHandle(cudnn_holder_.get()); +} + void CUDADeviceContext::RunCudnnFuncWithWorkspace( const std::function& cudnn_func, size_t workspace_len) const { cudnn_holder_->RunFunc(cudnn_func, workspace_len); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 999bbe00f1659881050cb0dc89570b74b201aca7..0631a098c7561c790f61a3391b23b1644b257a96 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -74,6 +74,33 @@ struct DefaultDeviceContextType { class EigenCudaStreamDevice; class CudnnHolder; +class CudnnWorkspaceHandle { + public: + /*! 
\brief The lock would not be acquired when the constructor is called. + * The lock would be acquired when RunFunc() is called for the first time. */ + explicit CudnnWorkspaceHandle(CudnnHolder* holder); + + /*! \brief A thread that calls RunFunc() would acquire the lock first + * before invoking cudnn functions. */ + void RunFunc(const std::function& cudnn_func, + size_t required_workspace_len); + + /*! \brief User can call this method to acquire the lock manually, + * But it is usually unnecessary, because RunFunc() would + * acquire the lock first before invoking cudnn functions. */ + void BeginCallGuard(); + + /*! \brief User can call this method to release the lock manually, + * But it is usually unnecessary, because the lock would be + * released once the handle is destructed. But it can be used + * to manually release the lock as soon as possible. */ + void EndCallGuard(); + + private: + CudnnHolder* holder_; // not owned + std::unique_ptr> guard_; +}; + class CUDADeviceContext : public DeviceContext { public: explicit CUDADeviceContext(CUDAPlace place); @@ -100,6 +127,15 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle() const; + /*! \brief Return a cudnn workspace handle to call multiple cudnn + * functions without being interrupted by other threads. + * Once the first cudnn function is called by the handle, a lock + * would be acquired to prevent other threads from accessing the + * workspace. Once the handle is destructed, the lock would be released. + * CudnnWorkspaceHandle is an RAII object to implement thread-safe + * sequential cudnn function calls. */ + CudnnWorkspaceHandle cudnn_workspace_handle() const; + /*! \brief Run a cudnn function with the workspace provided by * CUDADeviceContext */ void RunCudnnFuncWithWorkspace(const std::function& cudnn_func,