diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index 57573b37c3852c46a1e06ba7d6f57d8a56dad18e..d6b5ad4570c1d8402dedb8596cc75d9eae5a91c7 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -1,6 +1,6 @@
 cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node)
 cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
-cc_library(op_handle_graph SRCS op_handle_graph.cc DEPS op_handle_base)
+cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
 cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
@@ -31,9 +31,9 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
 cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)

-cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_handle_graph multi_devices_helper)
+cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)

-if(WITH_GPU)
+if (WITH_GPU)
   cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
 endif()
@@ -43,7 +43,7 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap

 cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)

-set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto modify_op_lock_and_record_event_pass sequential_execution_pass)
+set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass)
 if (WITH_GPU)
   list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
 endif()
diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index bc19bd36610bf144f163c8ebf582d4afbc6592e3..48f94a1f05614d4b797562ac67cdb9828fd0456e 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -69,6 +69,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     // Verify that the graph is correct for multi-device executor.
AppendPass("multi_devices_check_pass"); + + if (strategy_.remove_unnecessary_lock_) { + AppendPass("modify_op_lock_and_record_event_pass"); + } } private: @@ -136,3 +140,4 @@ USE_PASS(multi_devices_pass); USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_print_pass); USE_PASS(sequential_execution_pass); +USE_PASS(modify_op_lock_and_record_event_pass); diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 88459320b0eb6d6c4405bff4c8b13c99aa7edb0d..6c7b54db8f610aa34cd51dcbc13063290cae3ac0 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -73,6 +73,8 @@ struct BuildStrategy { bool fuse_broadcast_op_{false}; + bool remove_unnecessary_lock_{false}; + // User normally doesn't need to call this API. // The PassBuilder allows for more customized insert, remove of passes // from python side. diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index 7beb8c8de9fc49aebc66ca44de8736240aabbc30..7ad1e40c600c6e70cea822fac777ff20163078e6 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -20,13 +20,11 @@ namespace paddle { namespace framework { namespace details { ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope, - platform::Place place, - size_t scope_idx) + platform::Place place) : OpHandleBase(node), op_(framework::OpRegistry::CreateOp(*node->Op())), scope_(scope), - place_(place), - scope_idx_(scope_idx) {} + place_(place) {} void ComputationOpHandle::RunImpl() { WaitInputVarGenerated(place_); diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index 2d877f90583da30cc7cc6db31d565e99976ae68b..662a91d6b4dfcfed563fdf2e46c22f83f90b40af 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -28,8 +28,7 @@ namespace framework { namespace details { struct ComputationOpHandle : public OpHandleBase { public: - ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place, - size_t scope_idx); + ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place); std::string Name() const override; @@ -37,12 +36,6 @@ struct ComputationOpHandle : public OpHandleBase { const platform::Place &GetPlace() const { return place_; } - size_t GetScopeIdx() const { return scope_idx_; } - - OperatorBase &GetOp() { return *op_; } - - const OperatorBase &GetOp() const { return *op_; } - void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; } protected: @@ -54,7 +47,6 @@ struct ComputationOpHandle : public OpHandleBase { std::unique_ptr op_; Scope *scope_; platform::Place place_; - size_t scope_idx_; bool is_lock_and_record_event_free_{false}; }; } // namespace details diff --git a/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc b/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc index ed07d84fd643dd1aa711b227a0c4186985dd078b..169ce3ae7ca497e40d99b1c16633e35e1e4f1009 100644 --- a/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc +++ b/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc @@ -15,20 +15,17 @@ #include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h" #include 
"paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" -#include "paddle/fluid/framework/details/op_handle_graph.h" +#include "paddle/fluid/framework/details/op_graph_view.h" namespace paddle { namespace framework { namespace details { -static ComputationOpHandle *ConvertToComputationOpHandle(OpHandleBase *op) { - return dynamic_cast(op); -} - static bool IsLockAndRecordEventFreeComputationOpHandle( - ComputationOpHandle *op, const OpHandleGraph &graph) { - for (auto &pending_op : graph.PendingOps(op)) { - auto *tmp = ConvertToComputationOpHandle(pending_op); + ComputationOpHandle *op, const OpGraphView &graph_view) { + if (!platform::is_gpu_place(op->GetPlace())) return false; + for (auto &pending_op : graph_view.PendingOps(op)) { + auto *tmp = dynamic_cast(pending_op); if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) { return false; } @@ -39,12 +36,12 @@ static bool IsLockAndRecordEventFreeComputationOpHandle( std::unique_ptr ModifyOpLockAndRecordEventPass::ApplyImpl( std::unique_ptr ir_graph) const { auto &all_ops = ir_graph->Get(kGraphOps); - OpHandleGraph graph(all_ops); + OpGraphView graph_view(all_ops); for (auto &op : all_ops) { - auto *compute_op = ConvertToComputationOpHandle(op.get()); + auto *compute_op = dynamic_cast(op.get()); if (compute_op == nullptr) continue; bool is_lock_and_record_event_free = - IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph); + IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view); compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free); if (is_lock_and_record_event_free) { VLOG(10) << "Set is_lock_and_record_event_free be true in op " diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 7154385a4122022ffde5f47623c1c2471be39dc1..f3819887a196a7c8bf35897467bb9d68b428094e 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -556,7 +556,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, int dev_id) const { result->Get(kGraphOps).emplace_back( new ComputationOpHandle(result->CreateOpNode(node->Op()), - local_scopes_[dev_id], places_[dev_id], dev_id)); + local_scopes_[dev_id], places_[dev_id])); CreateOpHandleIOs(result, node, dev_id); } @@ -672,8 +672,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result, for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; - result->Get(kGraphOps).emplace_back(new ComputationOpHandle( - result->CreateOpNode(node->Op()), s, p, scope_idx)); + result->Get(kGraphOps).emplace_back( + new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p)); CreateOpHandleIOs(result, node, scope_idx); } } diff --git a/paddle/fluid/framework/details/op_graph_view.cc b/paddle/fluid/framework/details/op_graph_view.cc new file mode 100644 index 0000000000000000000000000000000000000000..65dafd376f7c687410270e35f105ff595fe78f59 --- /dev/null +++ b/paddle/fluid/framework/details/op_graph_view.cc @@ -0,0 +1,77 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/op_graph_view.h"
+#include <queue>
+#include <utility>
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+OpGraphView::OpGraphView(
+    const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
+  Build(ops);
+}
+
+void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
+  for (auto &op : ops) {
+    preceding_ops_[op.get()];
+    pending_ops_[op.get()];
+    for (auto &var : op->Outputs()) {
+      for (auto &pending_op : var->PendingOps()) {
+        preceding_ops_[pending_op].insert(op.get());
+        pending_ops_[op.get()].insert(pending_op);
+      }
+    }
+  }
+  PADDLE_ENFORCE(
+      preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(),
+      "There are duplicate ops in graph.");
+}
+
+size_t OpGraphView::OpNumber() const { return preceding_ops_.size(); }
+
+std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
+  std::unordered_set<OpHandleBase *> ret;
+  for (auto &pair : preceding_ops_) {
+    ret.insert(pair.first);
+  }
+  return ret;
+}
+
+bool OpGraphView::HasOp(OpHandleBase *op) const {
+  return preceding_ops_.count(op) != 0;
+}
+
+void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
+  PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpGraphView",
+                 op == nullptr ? "nullptr" : op->DebugString());
+}
+
+const std::unordered_set<OpHandleBase *> &OpGraphView::PrecedingOps(
+    OpHandleBase *op) const {
+  EnforceHasOp(op);
+  return preceding_ops_.at(op);
+}
+
+const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
+    OpHandleBase *op) const {
+  EnforceHasOp(op);
+  return pending_ops_.at(op);
+}
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/op_handle_graph.h b/paddle/fluid/framework/details/op_graph_view.h
similarity index 51%
rename from paddle/fluid/framework/details/op_handle_graph.h
rename to paddle/fluid/framework/details/op_graph_view.h
index 803edce048e7aee7a70021ee75fcd66576bf4a42..398c019be00a6ff5f5b39fdcbe97339341b1685b 100644
--- a/paddle/fluid/framework/details/op_handle_graph.h
+++ b/paddle/fluid/framework/details/op_graph_view.h
@@ -24,11 +24,9 @@ namespace paddle {
 namespace framework {
 namespace details {

-class OpHandleGraph {
+class OpGraphView {
  public:
-  enum Relation { kSame = 0, kBefore = 1, kAfter = 2, kNoDeps = 3 };
-
-  explicit OpHandleGraph(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
+  explicit OpGraphView(const std::vector<std::unique_ptr<OpHandleBase>> &ops);

   size_t OpNumber() const;

@@ -39,42 +37,11 @@ class OpGraphView {

   const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;

-  std::vector<std::unordered_set<OpHandleBase *>> AllPrecedingOps(
-      OpHandleBase *op) const;
-
-  std::vector<std::unordered_set<OpHandleBase *>> AllPendingOps(
-      OpHandleBase *op) const;
-
   bool HasOp(OpHandleBase *op) const;

-  Relation RelationBetween(OpHandleBase *op1, OpHandleBase *op2) const;
-
-  bool IsSame(OpHandleBase *op1, OpHandleBase *op2) const;
-
-  bool IsBeforeOrSame(OpHandleBase *op1, OpHandleBase *op2) const;
-
-  bool IsBefore(OpHandleBase *op1, OpHandleBase *op2) const;
-
-  bool IsAfterOrSame(OpHandleBase *op1, OpHandleBase *op2) const;
-
-  bool IsAfter(OpHandleBase *op1, OpHandleBase *op2) const;
-
-  bool IsNoDeps(OpHandleBase *op1, OpHandleBase *op2) const;
-
-  OpHandleBase *NearestCommonParent(OpHandleBase *op1, OpHandleBase *op2) const;
-
-  // Find an operator that is after op and before op1, op2
-  OpHandleBase *NearestCommonParentAfter(OpHandleBase *op, OpHandleBase *op1,
-                                         OpHandleBase *op2) const;
-
-  std::unordered_set<OpHandleBase *> NoPendingOpSet() const;
-
-  std::unordered_set<OpHandleBase *> NoPrecedingOpSet() const;
-
  private:
-  void BuildGraph(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
+  void Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
   void EnforceHasOp(OpHandleBase *op) const;
-  bool IsBeforeOrSameImpl(OpHandleBase *op1, OpHandleBase *op2) const;

   std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
       preceding_ops_;
diff --git a/paddle/fluid/framework/details/op_handle_graph.cc b/paddle/fluid/framework/details/op_handle_graph.cc
deleted file mode 100644
index 0e70305cec04c43b05481f69401f9957b74c2f4b..0000000000000000000000000000000000000000
--- a/paddle/fluid/framework/details/op_handle_graph.cc
+++ /dev/null
@@ -1,294 +0,0 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/fluid/framework/details/op_handle_graph.h"
-#include <queue>
-#include <utility>
-
-namespace paddle {
-namespace framework {
-namespace details {
-
-OpHandleGraph::OpHandleGraph(
-    const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
-  BuildGraph(ops);
-}
-
-void OpHandleGraph::BuildGraph(
-    const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
-  for (auto &op : ops) {
-    preceding_ops_[op.get()];
-    pending_ops_[op.get()];
-    for (auto &var : op->Outputs()) {
-      for (auto &pending_op : var->PendingOps()) {
-        preceding_ops_[pending_op].insert(op.get());
-        pending_ops_[op.get()].insert(pending_op);
-      }
-    }
-  }
-  PADDLE_ENFORCE(
-      preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(),
-      "There are duplicate ops in graph.");
-}
-
-size_t OpHandleGraph::OpNumber() const { return preceding_ops_.size(); }
-
-std::unordered_set<OpHandleBase *> OpHandleGraph::AllOps() const {
-  std::unordered_set<OpHandleBase *> ret;
-  for (auto &pair : preceding_ops_) {
-    ret.insert(pair.first);
-  }
-  return ret;
-}
-
-bool OpHandleGraph::HasOp(OpHandleBase *op) const {
-  return preceding_ops_.count(op) != 0;
-}
-
-void OpHandleGraph::EnforceHasOp(OpHandleBase *op) const {
-  PADDLE_ENFORCE(HasOp(op), "Cannot found op %s in OpHandleGraph",
-                 op == nullptr ? "nullptr" : op->DebugString());
-}
-
-const std::unordered_set<OpHandleBase *> &OpHandleGraph::PrecedingOps(
-    OpHandleBase *op) const {
-  EnforceHasOp(op);
-  return preceding_ops_.at(op);
-}
-
-const std::unordered_set<OpHandleBase *> &OpHandleGraph::PendingOps(
-    OpHandleBase *op) const {
-  EnforceHasOp(op);
-  return pending_ops_.at(op);
-}
-
-std::vector<std::unordered_set<OpHandleBase *>> OpHandleGraph::AllPrecedingOps(
-    OpHandleBase *op) const {
-  EnforceHasOp(op);
-  std::queue<OpHandleBase *> queue[2];
-  int cur = 0;
-  std::unordered_set<OpHandleBase *> visited_ops;
-  std::vector<std::unordered_set<OpHandleBase *>> ret;
-  for (auto &tmp : preceding_ops_.at(op)) {
-    queue[cur].push(tmp);
-    visited_ops.insert(tmp);
-  }
-
-  while (!queue[cur].empty()) {
-    std::unordered_set<OpHandleBase *> cur_level_ops;
-    auto *tmp = queue[cur].front();
-    queue[cur].pop();
-    for (auto &preceding_op : preceding_ops_.at(tmp)) {
-      if (visited_ops.count(preceding_op)) {
-        continue;
-      } else {
-        queue[1 - cur].push(preceding_op);
-        cur_level_ops.insert(preceding_op);
-        visited_ops.insert(preceding_op);
-      }
-    }
-    if (!cur_level_ops.empty()) {
-      ret.emplace_back(std::move(cur_level_ops));
-    }
-    cur = 1 - cur;
-  }
-  return ret;
-}
-
-std::vector<std::unordered_set<OpHandleBase *>> OpHandleGraph::AllPendingOps(
-    OpHandleBase *op) const {
-  EnforceHasOp(op);
-  std::queue<OpHandleBase *> queue[2];
-  int cur = 0;
-  std::unordered_set<OpHandleBase *> visited_ops;
-  std::vector<std::unordered_set<OpHandleBase *>> ret;
-  for (auto &tmp : preceding_ops_.at(op)) {
-    queue[cur].push(tmp);
-    visited_ops.insert(tmp);
-  }
-
-  while (!queue[cur].empty()) {
-    std::unordered_set<OpHandleBase *> cur_level_ops;
-    auto *tmp = queue[cur].front();
-    queue[cur].pop();
-    for (auto &next_op : pending_ops_.at(tmp)) {
-      if (visited_ops.count(next_op)) {
-        continue;
-      } else {
-        queue[1 - cur].push(next_op);
-        cur_level_ops.insert(next_op);
-        visited_ops.insert(next_op);
-      }
-    }
-    if (!cur_level_ops.empty()) {
-      ret.emplace_back(std::move(cur_level_ops));
-    }
-    cur = 1 - cur;
-  }
-  return ret;
-}
-
-OpHandleGraph::Relation OpHandleGraph::RelationBetween(
-    OpHandleBase *op1, OpHandleBase *op2) const {
-  EnforceHasOp(op1);
-  EnforceHasOp(op2);
-  if (op1 == op2) {
-    return kSame;
-  } else if (IsBeforeOrSameImpl(op1, op2)) {
-    return kBefore;
-  } else if (IsBeforeOrSameImpl(op2, op1)) {
-    return kAfter;
-  } else {
-    return kNoDeps;
-  }
-}
-
-bool OpHandleGraph::IsSame(OpHandleBase *op1, OpHandleBase *op2) const {
-  EnforceHasOp(op1);
-  EnforceHasOp(op2);
-  return op1 == op2;
-}
-
-bool OpHandleGraph::IsBeforeOrSame(OpHandleBase *op1, OpHandleBase *op2) const {
-  EnforceHasOp(op1);
-  EnforceHasOp(op2);
-  return IsBeforeOrSameImpl(op1, op2);
-}
-
-bool OpHandleGraph::IsBefore(OpHandleBase *op1, OpHandleBase *op2) const {
-  EnforceHasOp(op1);
-  EnforceHasOp(op2);
-  return op1 != op2 && IsBeforeOrSameImpl(op1, op2);
-}
-
-bool OpHandleGraph::IsBeforeOrSameImpl(OpHandleBase *op1,
-                                       OpHandleBase *op2) const {
-  std::queue<OpHandleBase *> queue;
-  // BFS
-  queue.push(op1);
-  do {
-    auto *op = queue.front();
-    queue.pop();
-    if (op == op2) return true;
-    for (auto &pending_op : pending_ops_.at(op)) {
-      queue.push(pending_op);
-    }
-  } while (!queue.empty());
-  return false;
-}
-
-bool OpHandleGraph::IsAfterOrSame(OpHandleBase *op1, OpHandleBase *op2) const {
-  EnforceHasOp(op1);
-  EnforceHasOp(op2);
-  return IsBeforeOrSameImpl(op2, op1);
-}
-
-bool OpHandleGraph::IsAfter(OpHandleBase *op1, OpHandleBase *op2) const {
-  return IsBefore(op2, op1);
-}
-
-bool OpHandleGraph::IsNoDeps(OpHandleBase *op1, OpHandleBase *op2) const {
-  return RelationBetween(op1, op2) == kNoDeps;
-}
-
-std::unordered_set<OpHandleBase *> OpHandleGraph::NoPendingOpSet() const {
-  std::unordered_set<OpHandleBase *> ret;
-  for (auto &pair : pending_ops_) {
-    if (pair.second.empty()) ret.insert(pair.first);
-  }
-  return ret;
-}
-
-std::unordered_set<OpHandleBase *> OpHandleGraph::NoPrecedingOpSet() const {
-  std::unordered_set<OpHandleBase *> ret;
-  for (auto &pair : preceding_ops_) {
-    if (pair.second.empty()) ret.insert(pair.first);
-  }
-  return ret;
-}
-
-OpHandleBase *OpHandleGraph::NearestCommonParent(OpHandleBase *op1,
-                                                 OpHandleBase *op2) const {
-  EnforceHasOp(op1);
-  EnforceHasOp(op2);
-  // FIXME(zjl): A brute-force O(2*n) algorithm here
-  // First, BFS all preceding_ops of op1 and record them in set S
-  // Second, BFS all preceding_ops of op2 and found whether it is in set S
-  std::unordered_set<OpHandleBase *> all_preceding_ops;
-  std::queue<OpHandleBase *> queue;
-  queue.push(op1);
-  do {
-    auto *op = queue.front();
-    queue.pop();
-    all_preceding_ops.insert(op);
-    for (auto &preceding_op : preceding_ops_.at(op)) {
-      queue.push(preceding_op);
-    }
-  } while (!queue.empty());
-
-  queue.push(op2);
-  do {
-    auto *op = queue.front();
-    queue.pop();
-    if (all_preceding_ops.count(op)) return op;
-    for (auto &preceding_op : preceding_ops_.at(op)) {
-      queue.push(preceding_op);
-    }
-  } while (!queue.empty());
-  return nullptr;
-}
-
-OpHandleBase *OpHandleGraph::NearestCommonParentAfter(OpHandleBase *op,
-                                                      OpHandleBase *op1,
-                                                      OpHandleBase *op2) const {
-  EnforceHasOp(op);
-  EnforceHasOp(op1);
-  EnforceHasOp(op2);
-  std::unordered_map<OpHandleBase *, int> all_preceding_ops;
-  int max_depth = -1;
-  std::queue<std::pair<OpHandleBase *, int>> queue;
-  queue.push(std::make_pair(op1, 0));
-  do {
-    auto tmp = queue.front();
-    queue.pop();
-    all_preceding_ops.insert(tmp);
-    if (tmp.first == op1) {
-      max_depth = tmp.second;
-      break;
-    }
-    for (auto &preceding_op : preceding_ops_.at(tmp.first)) {
-      queue.push(std::make_pair(preceding_op, tmp.second + 1));
-    }
-  } while (!queue.empty());
-
-  if (max_depth == -1) {
-    return nullptr;
-  }
-
-  std::queue<OpHandleBase *> queue2;
-  queue2.push(op2);
-  do {
-    auto *tmp = queue2.front();
-    queue2.pop();
-    if (all_preceding_ops.count(tmp) &&
-        (tmp == op || all_preceding_ops[tmp] < max_depth)) {
-      return tmp;
-    }
-  } while (!queue2.empty());
-  return nullptr;
-}
-
-}  // namespace details
-}  // namespace framework
-}  // namespace paddle
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 47f914e98f1f1de92b2aa1e90658022274f7b958..a45b9ec7a20ac3629d182f009b735d4d82fb5dc2 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -118,10 +118,6 @@ ParallelExecutor::ParallelExecutor(
       main_program, member_->places_, loss_var_name, params,
       member_->local_scopes_, member_->use_cuda_, member_->nccl_ctxs_.get());

-  graph = ir::PassRegistry::Instance()
-              .Get("modify_op_lock_and_record_event_pass")
-              ->Apply(std::move(graph));
-
   auto max_memory_size = GetEagerDeletionThreshold();
   if (max_memory_size >= 0) {
     for (auto &place : member_->places_) {
@@ -149,10 +145,6 @@ ParallelExecutor::ParallelExecutor(
   std::unique_ptr<ir::Graph> graph =
       build_strategy.Apply(main_program, member_->places_, loss_var_name,
                            params, member_->local_scopes_, member_->use_cuda_);
-
-  graph = ir::PassRegistry::Instance()
-              .Get("modify_op_lock_and_record_event_pass")
-              ->Apply(std::move(graph));
 #endif

   // Step 3. Create vars in each scope. Passes may also create new vars.
@@ -331,8 +323,6 @@ ParallelExecutor::~ParallelExecutor() {

 }  // namespace framework
 }  // namespace paddle
-
-USE_PASS(modify_op_lock_and_record_event_pass);
 #ifdef PADDLE_WITH_CUDA
 USE_PASS(reference_count_pass);
 #endif
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index ae18c4310bc8aac9b1f6f0087ccfc999264d2aac..7fc73d23fc3b01ab410a3375649b830678b122b0 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -153,83 +153,32 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
   mutable unsigned int* semaphore_;
 };

-class CudnnHolder {
- public:
-  CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
-      : workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
-    PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
-    PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
-  }
-
-  cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }
-
-  void RunFunc(const std::function<void(void*)>& cudnn_func,
-               size_t required_workspace_len) {
-    std::lock_guard<std::mutex> lock(mtx_);
-    RunFuncImpl(cudnn_func, required_workspace_len);
-  }
-
-  ~CudnnHolder() {
-    PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
-    if (workspace_ != nullptr) {
-      paddle::memory::Free(place_, workspace_);
-    }
-  }
-
- private:
-  std::mutex& Mutex() { return mtx_; }
+CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
+    : workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
+  PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
+  PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
+}

-  void RunFuncImpl(const std::function<void(void*)>& cudnn_func,
-                   size_t required_workspace_len) {
-    if (required_workspace_len > workspace_len_) {
-      ReallocateWorkspace(required_workspace_len);
-    }
-    cudnn_func(workspace_);
+CudnnHolder::~CudnnHolder() {
+  PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
+  if (workspace_ != nullptr) {
+    paddle::memory::Free(place_, workspace_);
   }
-
-  void ReallocateWorkspace(size_t required_workspace_len) {
-    if (required_workspace_len <= workspace_len_) {
-      return;
-    }
-    if (workspace_ != nullptr) {
-      // Maybe someone is using the current workspace
-      PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
-      paddle::memory::Free(place_, workspace_);
-    }
-    workspace_ = paddle::memory::Alloc(place_, required_workspace_len);
-    workspace_len_ = required_workspace_len;
-  }
-
-  friend class CudnnWorkspaceHandle;
-
-  cudnnHandle_t cudnn_handle_;
-  void* workspace_;
-  size_t workspace_len_;
-
-  const cudaStream_t* stream_;  // not owned;
-  const CUDAPlace place_;
-
-  std::mutex mtx_;
-};
-
-CudnnWorkspaceHandle::CudnnWorkspaceHandle(CudnnHolder* holder)
-    : holder_(holder) {}
-
-void CudnnWorkspaceHandle::RunFunc(const std::function<void(void*)>& cudnn_func,
-                                   size_t required_workspace_len) {
-  // defer lock when the function is invoked first time
-  BeginCallGuard();
-  holder_->RunFuncImpl(cudnn_func, required_workspace_len);
 }

-void CudnnWorkspaceHandle::BeginCallGuard() {
-  if (!guard_) {
-    guard_.reset(new std::lock_guard<std::mutex>(holder_->Mutex()));
+void CudnnHolder::ReallocateWorkspace(size_t required_workspace_len) {
+  if (required_workspace_len <= workspace_len_) {
+    return;
+  }
+  if (workspace_ != nullptr) {
+    // Maybe someone is using the current workspace
+    PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
+    paddle::memory::Free(place_, workspace_);
   }
+  workspace_ = paddle::memory::Alloc(place_, required_workspace_len);
+  workspace_len_ = required_workspace_len;
 }
-void CudnnWorkspaceHandle::EndCallGuard() { guard_.reset(); }
-
 CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
     : place_(place), cudnn_holder_(nullptr) {
   SetDeviceId(place_.device);
@@ -300,11 +249,6 @@ CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
   return CudnnWorkspaceHandle(cudnn_holder_.get());
 }

-void CUDADeviceContext::RunCudnnFuncWithWorkspace(
-    const std::function<void(void*)>& cudnn_func, size_t workspace_len) const {
-  cudnn_holder_->RunFunc(cudnn_func, workspace_len);
-}
-
 cudaStream_t CUDADeviceContext::stream() const { return stream_; }

 CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index b54cb61064ccd4d930eea5205045ed54661ebb90..df248f9bb15591d5015ad01278797ec7e31ef9d1 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -73,29 +73,55 @@ struct DefaultDeviceContextType {
 #ifdef PADDLE_WITH_CUDA

 class EigenCudaStreamDevice;
-class CudnnHolder;
+class CudnnHolder {
+ public:
+  CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place);
+  ~CudnnHolder();
+  cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }
+
+ private:
+  friend class CudnnWorkspaceHandle;
+  void ReallocateWorkspace(size_t required_workspace_len);
+
+  template <typename Callback>
+  void RunFuncImpl(Callback&& cudnn_func, size_t required_workspace_len) {
+    if (required_workspace_len > workspace_len_) {
+      ReallocateWorkspace(required_workspace_len);
+    }
+    cudnn_func(workspace_);
+  }
+
+  std::mutex& Mutex() { return mtx_; }
+
+  cudnnHandle_t cudnn_handle_;
+  void* workspace_;
+  size_t workspace_len_;
+
+  const cudaStream_t* stream_;  // not owned;
+  const CUDAPlace place_;
+
+  std::mutex mtx_;
+};

 class CudnnWorkspaceHandle {
  public:
   /*! \brief The lock would not be acquired when constructor calls.
    *  The lock would be acquired when RunFunc() is called first time. */
-  explicit CudnnWorkspaceHandle(CudnnHolder* holder);
+  inline explicit CudnnWorkspaceHandle(CudnnHolder* holder) : holder_(holder) {}

   /*! \brief Thread which call RunFunc() would acquire the lock first
    *  before invoking cudnn functions. */
-  void RunFunc(const std::function<void(void*)>& cudnn_func,
-               size_t required_workspace_len);
-
-  /*! \brief User can call this method to acquire the lock manually,
-   *  But it is usually unnecessary, because RunFunc() would
-   *  acquire the lock first before invoking cudnn functions. */
-  void BeginCallGuard();
+  template <typename Callback>
+  inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_len) {
+    if (!guard_) {
+      guard_.reset(new std::lock_guard<std::mutex>(holder_->Mutex()));
+    }
+    holder_->RunFuncImpl(std::forward<Callback>(cudnn_func),
+                         required_workspace_len);
+  }

-  /*! \brief User can call this method to release the lock manually,
-   *  But it is usually unnecssary, because the lock would be
-   *  release once the handle is destructed. But it can be used
-   *  to manually release the lock as soon as possible. */
-  void EndCallGuard();
+  CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
+  CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;

  private:
   CudnnHolder* holder_;  // not own
@@ -137,11 +163,6 @@ class CUDADeviceContext : public DeviceContext {
    *  sequential cudnn function calls. */
   CudnnWorkspaceHandle cudnn_workspace_handle() const;

-  /*! \brief Run a cudnn function with the workspace provided by
-   *  CUDADeviceContext */
-  void RunCudnnFuncWithWorkspace(const std::function<void(void*)>& cudnn_func,
-                                 size_t workspace_len) const;
-
   /*! \brief Return cuda stream in the device context. */
   cudaStream_t stream() const;
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 7c7b14df6618bd636f3636612486884b573309fb..fc821e04a0baf9278295da18ee5a69afcf2c4605 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -821,13 +821,24 @@ All parameter, weight, gradient are variables in Paddle.
                     [](BuildStrategy &self, bool b) {
                       self.enable_data_balance_ = b;
                     })  // FIXME(chengudo): enable_data_balance seems not important
-      .def_property("enable_sequential_execution",
-                    [](const BuildStrategy &self) {
-                      return self.enable_sequential_execution_;
-                    },
-                    [](BuildStrategy &self, bool b) {
-                      self.enable_sequential_execution_ = b;
-                    })
+      .def_property(
+          "enable_sequential_execution",
+          [](const BuildStrategy &self) {
+            return self.enable_sequential_execution_;
+          },
+          [](BuildStrategy &self, bool b) {
+            self.enable_sequential_execution_ = b;
+          },
+          R"DOC(The type is BOOL. If set True, the execution order of ops would be the same as what is in the program. Default False.)DOC")
+      .def_property(
+          "remove_unnecessary_lock",
+          [](const BuildStrategy &self) {
+            return self.remove_unnecessary_lock_;
+          },
+          [](BuildStrategy &self, bool b) {
+            self.remove_unnecessary_lock_ = b;
+          },
+          R"DOC(The type is BOOL. If set True, some locks in GPU ops would be released and ParallelExecutor would run faster. Default False.)DOC")
       .def_property(
           "fuse_elewise_add_act_ops",
           [](const BuildStrategy &self) {
diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
index a3fe5e0a0591c8da787e3c2fdb030f3912548316..86f861674c26fe61e624103c2a0d70f816a1aebc 100644
--- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
+++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
@@ -18,6 +18,7 @@ import multiprocessing
 import os
 import unittest
 import paddle.fluid as fluid
+import paddle.fluid.core as core
 import time
 import numpy as np
 import math
@@ -82,6 +83,8 @@ class TestParallelExecutorBase(unittest.TestCase):
             if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
         build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
         build_strategy.enable_sequential_execution = enable_sequential_execution
+        if use_cuda and core.is_compiled_with_cuda():
+            build_strategy.remove_unnecessary_lock = True

         if use_parallel_executor:
             exe = fluid.ParallelExecutor(
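Usage sketch (not part of the patch): the snippet below shows how a user program might opt into the new remove_unnecessary_lock build strategy that this change exposes through BuildStrategy and pybind. It is a minimal sketch assuming the Fluid Python API of this release; the toy fc network, the variable name x, and the random feed data are illustrative assumptions, not code from this diff.

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core

# Build a tiny program: one fc layer and a mean loss (illustrative only).
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
loss = fluid.layers.mean(fluid.layers.fc(input=x, size=1))
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
fluid.Executor(place).run(fluid.default_startup_program())

build_strategy = fluid.BuildStrategy()
# The optimization only affects GPU ops, so guard it the same way
# parallel_executor_test_base.py does in this patch.
if core.is_compiled_with_cuda():
    build_strategy.remove_unnecessary_lock = True

pe = fluid.ParallelExecutor(use_cuda=core.is_compiled_with_cuda(),
                            loss_name=loss.name,
                            build_strategy=build_strategy)
feed = {'x': np.random.random((32, 13)).astype('float32')}
loss_value = pe.run(feed=feed, fetch_list=[loss.name])[0]

On a CPU-only build the flag is simply left at its default (False), which matches the guard added to the unit test base class above.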