diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index d8bc72e6b2fa38db06cb077ada9d7ec180299e8c..d6b5ad4570c1d8402dedb8596cc75d9eae5a91c7 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -1,5 +1,6 @@ cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node) cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor) +cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base) cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) @@ -30,7 +31,9 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) -if(WITH_GPU) +cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper) + +if (WITH_GPU) cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass) endif() @@ -40,12 +43,13 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle) -if(WITH_GPU) - cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass sequential_execution_pass) -else() - cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto sequential_execution_pass) +set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass) +if (WITH_GPU) + list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass) endif() +cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS}) + cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index bc19bd36610bf144f163c8ebf582d4afbc6592e3..48f94a1f05614d4b797562ac67cdb9828fd0456e 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -69,6 +69,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { // Verify that the graph is correct for multi-device executor. AppendPass("multi_devices_check_pass"); + + if (strategy_.remove_unnecessary_lock_) { + AppendPass("modify_op_lock_and_record_event_pass"); + } } private: @@ -136,3 +140,4 @@ USE_PASS(multi_devices_pass); USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_print_pass); USE_PASS(sequential_execution_pass); +USE_PASS(modify_op_lock_and_record_event_pass); diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 88459320b0eb6d6c4405bff4c8b13c99aa7edb0d..6c7b54db8f610aa34cd51dcbc13063290cae3ac0 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -73,6 +73,8 @@ struct BuildStrategy { bool fuse_broadcast_op_{false}; + bool remove_unnecessary_lock_{false}; + // User normally doesn't need to call this API. // The PassBuilder allows for more customized insert, remove of passes // from python side. diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index f9bbfe0016ce0ea0d15a83cb532c44518549b8ad..7ad1e40c600c6e70cea822fac777ff20163078e6 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -29,9 +29,15 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope, void ComputationOpHandle::RunImpl() { WaitInputVarGenerated(place_); - this->RunAndRecordEvent([this] { + auto run_func = [this]() { op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get(), place_); - }); + }; + + if (is_lock_and_record_event_free_) { + run_func(); + } else { + this->RunAndRecordEvent(run_func); + } } bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) { diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index e98f1ab148db083ac63a1afd43e334fbfae62539..662a91d6b4dfcfed563fdf2e46c22f83f90b40af 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase { const platform::Place &GetPlace() const { return place_; } + void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; } + protected: void RunImpl() override; @@ -45,6 +47,7 @@ struct ComputationOpHandle : public OpHandleBase { std::unique_ptr op_; Scope *scope_; platform::Place place_; + bool is_lock_and_record_event_free_{false}; }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc b/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..169ce3ae7ca497e40d99b1c16633e35e1e4f1009 --- /dev/null +++ b/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h" +#include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/details/op_graph_view.h" + +namespace paddle { +namespace framework { +namespace details { + +static bool IsLockAndRecordEventFreeComputationOpHandle( + ComputationOpHandle *op, const OpGraphView &graph_view) { + if (!platform::is_gpu_place(op->GetPlace())) return false; + for (auto &pending_op : graph_view.PendingOps(op)) { + auto *tmp = dynamic_cast(pending_op); + if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) { + return false; + } + } + return true; +} + +std::unique_ptr ModifyOpLockAndRecordEventPass::ApplyImpl( + std::unique_ptr ir_graph) const { + auto &all_ops = ir_graph->Get(kGraphOps); + OpGraphView graph_view(all_ops); + for (auto &op : all_ops) { + auto *compute_op = dynamic_cast(op.get()); + if (compute_op == nullptr) continue; + bool is_lock_and_record_event_free = + IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view); + compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free); + if (is_lock_and_record_event_free) { + VLOG(10) << "Set is_lock_and_record_event_free be true in op " + << compute_op->DebugString(); + } + } + return ir_graph; +} + +} // namespace details +} // namespace framework +} // namespace paddle + +REGISTER_PASS(modify_op_lock_and_record_event_pass, + paddle::framework::details::ModifyOpLockAndRecordEventPass); diff --git a/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h b/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..b54e1b318be95e1e0abf6830f8c918895df02718 --- /dev/null +++ b/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h @@ -0,0 +1,32 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace details { + +class ModifyOpLockAndRecordEventPass : public ir::Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/op_graph_view.cc b/paddle/fluid/framework/details/op_graph_view.cc new file mode 100644 index 0000000000000000000000000000000000000000..65dafd376f7c687410270e35f105ff595fe78f59 --- /dev/null +++ b/paddle/fluid/framework/details/op_graph_view.cc @@ -0,0 +1,77 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/op_graph_view.h" +#include +#include + +namespace paddle { +namespace framework { +namespace details { + +OpGraphView::OpGraphView( + const std::vector> &ops) { + Build(ops); +} + +void OpGraphView::Build(const std::vector> &ops) { + for (auto &op : ops) { + preceding_ops_[op.get()]; + pending_ops_[op.get()]; + for (auto &var : op->Outputs()) { + for (auto &pending_op : var->PendingOps()) { + preceding_ops_[pending_op].insert(op.get()); + pending_ops_[op.get()].insert(pending_op); + } + } + } + PADDLE_ENFORCE( + preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(), + "There are duplicate ops in graph."); +} + +size_t OpGraphView::OpNumber() const { return preceding_ops_.size(); } + +std::unordered_set OpGraphView::AllOps() const { + std::unordered_set ret; + for (auto &pair : preceding_ops_) { + ret.insert(pair.first); + } + return ret; +} + +bool OpGraphView::HasOp(OpHandleBase *op) const { + return preceding_ops_.count(op) != 0; +} + +void OpGraphView::EnforceHasOp(OpHandleBase *op) const { + PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpGraphView", + op == nullptr ? "nullptr" : op->DebugString()); +} + +const std::unordered_set &OpGraphView::PrecedingOps( + OpHandleBase *op) const { + EnforceHasOp(op); + return preceding_ops_.at(op); +} + +const std::unordered_set &OpGraphView::PendingOps( + OpHandleBase *op) const { + EnforceHasOp(op); + return pending_ops_.at(op); +} + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/op_graph_view.h b/paddle/fluid/framework/details/op_graph_view.h new file mode 100644 index 0000000000000000000000000000000000000000..398c019be00a6ff5f5b39fdcbe97339341b1685b --- /dev/null +++ b/paddle/fluid/framework/details/op_graph_view.h @@ -0,0 +1,54 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "paddle/fluid/framework/details/op_handle_base.h" + +namespace paddle { +namespace framework { +namespace details { + +class OpGraphView { + public: + explicit OpGraphView(const std::vector> &ops); + + size_t OpNumber() const; + + std::unordered_set AllOps() const; + + const std::unordered_set &PrecedingOps( + OpHandleBase *op) const; + + const std::unordered_set &PendingOps(OpHandleBase *op) const; + + bool HasOp(OpHandleBase *op) const; + + private: + void Build(const std::vector> &ops); + void EnforceHasOp(OpHandleBase *op) const; + + std::unordered_map> + preceding_ops_; + std::unordered_map> + pending_ops_; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/reference_count_op_handle.h b/paddle/fluid/framework/details/reference_count_op_handle.h index fc479a4c4a1e7d5c824d3c202e0cccf743dd52c9..cc4ccfbdfc720284e683a8f3f59a4aa57a3a9eb1 100644 --- a/paddle/fluid/framework/details/reference_count_op_handle.h +++ b/paddle/fluid/framework/details/reference_count_op_handle.h @@ -51,7 +51,7 @@ class ReferenceCountOpHandle : public OpHandleBase { dev_ctx_ = static_cast( platform::DeviceContextPool::Instance().Get(place)); if (IsStreamGarabageCollector()) { - PADDLE_ENFORCE(cudaSetDevice(place.device)); + platform::SetDeviceId(place.device); PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); } @@ -61,7 +61,7 @@ class ReferenceCountOpHandle : public OpHandleBase { ~ReferenceCountOpHandle() { if (IsStreamGarabageCollector()) { auto gpu_place = boost::get(dev_ctx_->GetPlace()); - PADDLE_ENFORCE(cudaSetDevice(gpu_place.device)); + platform::SetDeviceId(gpu_place.device); PADDLE_ENFORCE(cudaEventDestroy(event_)); } } diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc index 2d1f688d64ece3322e253b0c070264b9eb73d678..0b994ced7f751f056fec076e3dea8d14d0bed991 100644 --- a/paddle/fluid/framework/details/reference_count_pass.cc +++ b/paddle/fluid/framework/details/reference_count_pass.cc @@ -43,6 +43,23 @@ static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) { return nullptr; } +static void AddDependencyBetween(OpHandleBase *in, OpHandleBase *out, + ir::Graph *graph) { + auto it = std::find_if( + in->Outputs().begin(), in->Outputs().end(), [](VarHandleBase *var) { + return dynamic_cast(var) != nullptr; + }); + + if (it != in->Outputs().end()) { + out->AddInput(*it); + } else { + auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); + graph->Get(kGraphDepVars).emplace(dep_var); + in->AddOutput(dep_var); + out->AddInput(dep_var); + } +} + std::unique_ptr ReferenceCountPass::ApplyImpl( std::unique_ptr graph) const { auto &ref_cnts = Get(kGlobalReferenceCount); @@ -133,12 +150,7 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( auto *ref_cnt_handle = new ReferenceCountOpHandle( ref_cnt_node, next_compute_op->GetScope(), place, {var_name}, gcs[place.device].get(), cur_ref_cnts[place.device].get()); - if (next_compute_op->Outputs().empty()) { - auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); - next_compute_op->AddOutput(dep_var); - graph->Get(kGraphDepVars).emplace(dep_var); - } - ref_cnt_handle->AddInput(next_compute_op->Outputs().front()); + AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get()); compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle); } } @@ -160,12 +172,7 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( auto *ref_cnt_handle = new ReferenceCountOpHandle( ref_cnt_node, compute_op->GetScope(), place, in_var_names, gcs[place.device].get(), cur_ref_cnts[place.device].get()); - if (compute_op->Outputs().empty()) { - auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); - compute_op->AddOutput(dep_var); - graph->Get(kGraphDepVars).emplace(dep_var); - } - ref_cnt_handle->AddInput(compute_op->Outputs().front()); + AddDependencyBetween(compute_op, ref_cnt_handle, graph.get()); compute_ref_cnt_map[compute_op].reset(ref_cnt_handle); } diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index 4a7a6bcf7154d5680de751e3c933be46fb09fd74..c37032bf090a34077f0f706307c07a0c0fd1185d 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -160,6 +160,7 @@ class CUDNNConvOpKernel : public framework::OpKernel { // ------------------- cudnn conv forward --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); for (int i = 0; i < groups; i++) { auto cudnn_func = [&](void* cudnn_workspace) { CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( @@ -168,7 +169,7 @@ class CUDNNConvOpKernel : public framework::OpKernel { cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_output_desc, output_data + i * group_offset_out)); }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); } } }; @@ -314,6 +315,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { // ------------------- cudnn conv backward data --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); if (input_grad) { T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); // Because beta is zero, it is unnecessary to reset input_grad. @@ -327,7 +329,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { data_algo, cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc, input_grad_data + i * group_offset_in)); }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); } } // ------------------- cudnn conv backward filter --------------------- @@ -343,7 +345,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { filter_algo, cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_filter_desc, filter_grad_data + i * group_offset_filter)); }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); } } } diff --git a/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc b/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc index 73831611d01b8c5b8d2d9f7f15634a0094e4a608..f44094ca6b7b7f23f2e7593ad79e4e2a6f0d3070 100644 --- a/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc @@ -104,6 +104,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { int output_offset = output->numel() / output->dims()[0] / groups; int filter_offset = filter->numel() / groups; T alpha = 1.0f, beta = 0.0f; + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); for (int g = 0; g < groups; g++) { auto cudnn_func = [&](void* cudnn_workspace) { CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( @@ -112,7 +113,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { algo, cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_output_desc, output_data + output_offset * g)); }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); } } }; @@ -208,6 +209,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { output_grad->numel() / output_grad->dims()[0] / groups; int filter_offset = filter->numel() / groups; T alpha = 1.0f, beta = 0.0f; + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); if (input_grad) { T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); // Because beta is zero, it is unnecessary to reset input_grad. @@ -220,7 +222,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc, input_grad_data + input_offset * g)); }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); } } @@ -238,7 +240,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_filter_desc, filter_grad_data + filter_offset * g)); }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); } } } diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 2c7f6c9d5f12c03401910e146bba2fe47166e5b4..ff49a1d57fd977a6d6b4502b44e48aad34cde872 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -153,55 +153,31 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { mutable unsigned int* semaphore_; }; -class CudnnHolder { - public: - CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place) - : workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) { - PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); - PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_)); - } - - cudnnHandle_t cudnn_handle() const { return cudnn_handle_; } +CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place) + : workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) { + PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); + PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_)); +} - void RunFunc(const std::function& cudnn_func, - size_t required_workspace_len) { - std::lock_guard lock(mtx_); - if (required_workspace_len > workspace_len_) { - ReallocateWorkspace(required_workspace_len); - } - cudnn_func(workspace_); +CudnnHolder::~CudnnHolder() { + PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); + if (workspace_ != nullptr) { + paddle::memory::Free(place_, workspace_); } +} - ~CudnnHolder() { - PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); - if (workspace_ != nullptr) { - paddle::memory::Free(place_, workspace_); - } +void CudnnHolder::ReallocateWorkspace(size_t required_workspace_len) { + if (required_workspace_len <= workspace_len_) { + return; } - - private: - void ReallocateWorkspace(size_t required_workspace_len) { - if (required_workspace_len <= workspace_len_) { - return; - } - if (workspace_ != nullptr) { - // Maybe someone is using the current workspace - PADDLE_ENFORCE(cudaStreamSynchronize(*stream_)); - paddle::memory::Free(place_, workspace_); - } - workspace_ = paddle::memory::Alloc(place_, required_workspace_len); - workspace_len_ = required_workspace_len; + if (workspace_ != nullptr) { + // Maybe someone is using the current workspace + PADDLE_ENFORCE(cudaStreamSynchronize(*stream_)); + paddle::memory::Free(place_, workspace_); } - - cudnnHandle_t cudnn_handle_; - void* workspace_; - size_t workspace_len_; - - const cudaStream_t* stream_; // not owned; - const CUDAPlace place_; - - std::mutex mtx_; -}; + workspace_ = paddle::memory::Alloc(place_, required_workspace_len); + workspace_len_ = required_workspace_len; +} CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place), cudnn_holder_(nullptr) { @@ -269,9 +245,8 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_holder_->cudnn_handle(); } -void CUDADeviceContext::RunCudnnFuncWithWorkspace( - const std::function& cudnn_func, size_t workspace_len) const { - cudnn_holder_->RunFunc(cudnn_func, workspace_len); +CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { + return CudnnWorkspaceHandle(cudnn_holder_.get()); } cudaStream_t CUDADeviceContext::stream() const { return stream_; } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 0240b9380f3213b2a030061007e04abe1d73c6e3..df248f9bb15591d5015ad01278797ec7e31ef9d1 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -73,7 +73,60 @@ struct DefaultDeviceContextType { #ifdef PADDLE_WITH_CUDA class EigenCudaStreamDevice; -class CudnnHolder; +class CudnnHolder { + public: + CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place); + ~CudnnHolder(); + cudnnHandle_t cudnn_handle() const { return cudnn_handle_; } + + private: + friend class CudnnWorkspaceHandle; + void ReallocateWorkspace(size_t required_workspace_len); + + template + void RunFuncImpl(Callback&& cudnn_func, size_t required_workspace_len) { + if (required_workspace_len > workspace_len_) { + ReallocateWorkspace(required_workspace_len); + } + cudnn_func(workspace_); + } + + std::mutex& Mutex() { return mtx_; } + + cudnnHandle_t cudnn_handle_; + void* workspace_; + size_t workspace_len_; + + const cudaStream_t* stream_; // not owned; + const CUDAPlace place_; + + std::mutex mtx_; +}; + +class CudnnWorkspaceHandle { + public: + /*! \brief The lock would not be acquired when constructor calls. + * The lock would be acquired when RunFunc() is called first time. */ + inline explicit CudnnWorkspaceHandle(CudnnHolder* holder) : holder_(holder) {} + + /*! \brief Thread which call RunFunc() would acquire the lock first + * before invoking cudnn functions. */ + template + inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_len) { + if (!guard_) { + guard_.reset(new std::lock_guard(holder_->Mutex())); + } + holder_->RunFuncImpl(std::forward(cudnn_func), + required_workspace_len); + } + + CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default; + CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete; + + private: + CudnnHolder* holder_; // not own + std::unique_ptr> guard_; +}; class CUDADeviceContext : public DeviceContext { public: @@ -101,10 +154,14 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle() const; - /*! \brief Run a cudnn function with the workspace provided by - * CUDADeviceContext */ - void RunCudnnFuncWithWorkspace(const std::function& cudnn_func, - size_t workspace_len) const; + /*! \brief Return a cudnn workspace handle to call multiple cudnn + * functions without interrupting by other threads. + * Once the first cudnn function is called by the handle, a lock + * would be acquired to prevent other threads from accessing the + * workspace. Once the handle is destructed, the lock would be released. + * CudnnWorkspaceHandle is an RAII object to implement thread-safe + * sequential cudnn function calls. */ + CudnnWorkspaceHandle cudnn_workspace_handle() const; /*! \brief Return cuda stream in the device context. */ cudaStream_t stream() const; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 7c7b14df6618bd636f3636612486884b573309fb..fc821e04a0baf9278295da18ee5a69afcf2c4605 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -821,13 +821,24 @@ All parameter, weight, gradient are variables in Paddle. [](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; }) // FIXME(chengudo): enable_data_balance seems not important - .def_property("enable_sequential_execution", - [](const BuildStrategy &self) { - return self.enable_sequential_execution_; - }, - [](BuildStrategy &self, bool b) { - self.enable_sequential_execution_ = b; - }) + .def_property( + "enable_sequential_execution", + [](const BuildStrategy &self) { + return self.enable_sequential_execution_; + }, + [](BuildStrategy &self, bool b) { + self.enable_sequential_execution_ = b; + }, + R"DOC(The type is BOOL. If set True, the execution order of ops would be the same as what is in the program. Default False.)DOC") + .def_property( + "remove_unnecessary_lock", + [](const BuildStrategy &self) { + return self.remove_unnecessary_lock_; + }, + [](BuildStrategy &self, bool b) { + self.remove_unnecessary_lock_ = b; + }, + R"DOC(The type is BOOL. If set True, some locks in GPU ops would be released and ParallelExecutor would run faster. Default False.)DOC") .def_property( "fuse_elewise_add_act_ops", [](const BuildStrategy &self) { diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py index a3fe5e0a0591c8da787e3c2fdb030f3912548316..86f861674c26fe61e624103c2a0d70f816a1aebc 100644 --- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py +++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py @@ -18,6 +18,7 @@ import multiprocessing import os import unittest import paddle.fluid as fluid +import paddle.fluid.core as core import time import numpy as np import math @@ -82,6 +83,8 @@ class TestParallelExecutorBase(unittest.TestCase): if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops build_strategy.enable_sequential_execution = enable_sequential_execution + if use_cuda and core.is_compiled_with_cuda(): + build_strategy.remove_unnecessary_lock = True if use_parallel_executor: exe = fluid.ParallelExecutor(