未验证 提交 8ac2242b 编写于 作者: Z Zeng Jinle 提交者: GitHub

Merge pull request #14075 from sneaxiy/remove_some_locks_in_pe

Remove some locks in ParallelExecutor
cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
......@@ -30,7 +31,9 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
if(WITH_GPU)
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
if (WITH_GPU)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
endif()
......@@ -40,12 +43,13 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
if(WITH_GPU)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass sequential_execution_pass)
else()
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto sequential_execution_pass)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass)
if (WITH_GPU)
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
endif()
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)
......
......@@ -69,6 +69,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Verify that the graph is correct for multi-device executor.
AppendPass("multi_devices_check_pass");
if (strategy_.remove_unnecessary_lock_) {
AppendPass("modify_op_lock_and_record_event_pass");
}
}
private:
......@@ -136,3 +140,4 @@ USE_PASS(multi_devices_pass);
USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass);
USE_PASS(sequential_execution_pass);
USE_PASS(modify_op_lock_and_record_event_pass);
......@@ -73,6 +73,8 @@ struct BuildStrategy {
bool fuse_broadcast_op_{false};
bool remove_unnecessary_lock_{false};
// User normally doesn't need to call this API.
// The PassBuilder allows for more customized insert, remove of passes
// from python side.
......
......@@ -29,9 +29,15 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
void ComputationOpHandle::RunImpl() {
WaitInputVarGenerated(place_);
this->RunAndRecordEvent([this] {
auto run_func = [this]() {
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
});
};
if (is_lock_and_record_event_free_) {
run_func();
} else {
this->RunAndRecordEvent(run_func);
}
}
bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
......
......@@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase {
const platform::Place &GetPlace() const { return place_; }
void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; }
protected:
void RunImpl() override;
......@@ -45,6 +47,7 @@ struct ComputationOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_;
Scope *scope_;
platform::Place place_;
bool is_lock_and_record_event_free_{false};
};
} // namespace details
} // namespace framework
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
namespace paddle {
namespace framework {
namespace details {
static bool IsLockAndRecordEventFreeComputationOpHandle(
ComputationOpHandle *op, const OpGraphView &graph_view) {
if (!platform::is_gpu_place(op->GetPlace())) return false;
for (auto &pending_op : graph_view.PendingOps(op)) {
auto *tmp = dynamic_cast<ComputationOpHandle *>(pending_op);
if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
return false;
}
}
return true;
}
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
std::unique_ptr<ir::Graph> ir_graph) const {
auto &all_ops = ir_graph->Get<GraphOps>(kGraphOps);
OpGraphView graph_view(all_ops);
for (auto &op : all_ops) {
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
if (compute_op == nullptr) continue;
bool is_lock_and_record_event_free =
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free);
if (is_lock_and_record_event_free) {
VLOG(10) << "Set is_lock_and_record_event_free be true in op "
<< compute_op->DebugString();
}
}
return ir_graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(modify_op_lock_and_record_event_pass,
paddle::framework::details::ModifyOpLockAndRecordEventPass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
class ModifyOpLockAndRecordEventPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/op_graph_view.h"
#include <queue>
#include <utility>
namespace paddle {
namespace framework {
namespace details {
OpGraphView::OpGraphView(
const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
Build(ops);
}
void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
for (auto &op : ops) {
preceding_ops_[op.get()];
pending_ops_[op.get()];
for (auto &var : op->Outputs()) {
for (auto &pending_op : var->PendingOps()) {
preceding_ops_[pending_op].insert(op.get());
pending_ops_[op.get()].insert(pending_op);
}
}
}
PADDLE_ENFORCE(
preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(),
"There are duplicate ops in graph.");
}
size_t OpGraphView::OpNumber() const { return preceding_ops_.size(); }
std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
std::unordered_set<OpHandleBase *> ret;
for (auto &pair : preceding_ops_) {
ret.insert(pair.first);
}
return ret;
}
bool OpGraphView::HasOp(OpHandleBase *op) const {
return preceding_ops_.count(op) != 0;
}
void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpGraphView",
op == nullptr ? "nullptr" : op->DebugString());
}
const std::unordered_set<OpHandleBase *> &OpGraphView::PrecedingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return preceding_ops_.at(op);
}
const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return pending_ops_.at(op);
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
namespace paddle {
namespace framework {
namespace details {
class OpGraphView {
public:
explicit OpGraphView(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
size_t OpNumber() const;
std::unordered_set<OpHandleBase *> AllOps() const;
const std::unordered_set<OpHandleBase *> &PrecedingOps(
OpHandleBase *op) const;
const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;
bool HasOp(OpHandleBase *op) const;
private:
void Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
void EnforceHasOp(OpHandleBase *op) const;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
preceding_ops_;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
pending_ops_;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -51,7 +51,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
if (IsStreamGarabageCollector()) {
PADDLE_ENFORCE(cudaSetDevice(place.device));
platform::SetDeviceId(place.device);
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
}
......@@ -61,7 +61,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
~ReferenceCountOpHandle() {
if (IsStreamGarabageCollector()) {
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
PADDLE_ENFORCE(cudaSetDevice(gpu_place.device));
platform::SetDeviceId(gpu_place.device);
PADDLE_ENFORCE(cudaEventDestroy(event_));
}
}
......
......@@ -43,6 +43,23 @@ static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) {
return nullptr;
}
static void AddDependencyBetween(OpHandleBase *in, OpHandleBase *out,
ir::Graph *graph) {
auto it = std::find_if(
in->Outputs().begin(), in->Outputs().end(), [](VarHandleBase *var) {
return dynamic_cast<DummyVarHandle *>(var) != nullptr;
});
if (it != in->Outputs().end()) {
out->AddInput(*it);
} else {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
in->AddOutput(dep_var);
out->AddInput(dep_var);
}
}
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount);
......@@ -133,12 +150,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, next_compute_op->GetScope(), place, {var_name},
gcs[place.device].get(), cur_ref_cnts[place.device].get());
if (next_compute_op->Outputs().empty()) {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
next_compute_op->AddOutput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
}
ref_cnt_handle->AddInput(next_compute_op->Outputs().front());
AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get());
compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle);
}
}
......@@ -160,12 +172,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, compute_op->GetScope(), place, in_var_names,
gcs[place.device].get(), cur_ref_cnts[place.device].get());
if (compute_op->Outputs().empty()) {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
compute_op->AddOutput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
}
ref_cnt_handle->AddInput(compute_op->Outputs().front());
AddDependencyBetween(compute_op, ref_cnt_handle, graph.get());
compute_ref_cnt_map[compute_op].reset(ref_cnt_handle);
}
......
......@@ -160,6 +160,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv forward ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
for (int i = 0; i < groups; i++) {
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
......@@ -168,7 +169,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
&beta, cudnn_output_desc, output_data + i * group_offset_out));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
}
};
......@@ -314,6 +315,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv backward data ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.
......@@ -327,7 +329,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
data_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_input_desc, input_grad_data + i * group_offset_in));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
}
// ------------------- cudnn conv backward filter ---------------------
......@@ -343,7 +345,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
filter_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_filter_desc, filter_grad_data + i * group_offset_filter));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
}
}
......
......@@ -104,6 +104,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
int output_offset = output->numel() / output->dims()[0] / groups;
int filter_offset = filter->numel() / groups;
T alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
for (int g = 0; g < groups; g++) {
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
......@@ -112,7 +113,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_output_desc, output_data + output_offset * g));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
}
};
......@@ -208,6 +209,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
output_grad->numel() / output_grad->dims()[0] / groups;
int filter_offset = filter->numel() / groups;
T alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.
......@@ -220,7 +222,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
input_grad_data + input_offset * g));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
}
......@@ -238,7 +240,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_filter_desc, filter_grad_data + filter_offset * g));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
}
}
......
......@@ -153,55 +153,31 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
mutable unsigned int* semaphore_;
};
class CudnnHolder {
public:
CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
: workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
}
cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }
CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
: workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
}
void RunFunc(const std::function<void(void*)>& cudnn_func,
size_t required_workspace_len) {
std::lock_guard<std::mutex> lock(mtx_);
if (required_workspace_len > workspace_len_) {
ReallocateWorkspace(required_workspace_len);
}
cudnn_func(workspace_);
CudnnHolder::~CudnnHolder() {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
if (workspace_ != nullptr) {
paddle::memory::Free(place_, workspace_);
}
}
~CudnnHolder() {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
if (workspace_ != nullptr) {
paddle::memory::Free(place_, workspace_);
}
void CudnnHolder::ReallocateWorkspace(size_t required_workspace_len) {
if (required_workspace_len <= workspace_len_) {
return;
}
private:
void ReallocateWorkspace(size_t required_workspace_len) {
if (required_workspace_len <= workspace_len_) {
return;
}
if (workspace_ != nullptr) {
// Maybe someone is using the current workspace
PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
paddle::memory::Free(place_, workspace_);
}
workspace_ = paddle::memory::Alloc(place_, required_workspace_len);
workspace_len_ = required_workspace_len;
if (workspace_ != nullptr) {
// Maybe someone is using the current workspace
PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
paddle::memory::Free(place_, workspace_);
}
cudnnHandle_t cudnn_handle_;
void* workspace_;
size_t workspace_len_;
const cudaStream_t* stream_; // not owned;
const CUDAPlace place_;
std::mutex mtx_;
};
workspace_ = paddle::memory::Alloc(place_, required_workspace_len);
workspace_len_ = required_workspace_len;
}
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
: place_(place), cudnn_holder_(nullptr) {
......@@ -269,9 +245,8 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return cudnn_holder_->cudnn_handle();
}
void CUDADeviceContext::RunCudnnFuncWithWorkspace(
const std::function<void(void*)>& cudnn_func, size_t workspace_len) const {
cudnn_holder_->RunFunc(cudnn_func, workspace_len);
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
return CudnnWorkspaceHandle(cudnn_holder_.get());
}
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
......
......@@ -73,7 +73,60 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
#ifdef PADDLE_WITH_CUDA
class EigenCudaStreamDevice;
class CudnnHolder;
class CudnnHolder {
public:
CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place);
~CudnnHolder();
cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }
private:
friend class CudnnWorkspaceHandle;
void ReallocateWorkspace(size_t required_workspace_len);
template <typename Callback>
void RunFuncImpl(Callback&& cudnn_func, size_t required_workspace_len) {
if (required_workspace_len > workspace_len_) {
ReallocateWorkspace(required_workspace_len);
}
cudnn_func(workspace_);
}
std::mutex& Mutex() { return mtx_; }
cudnnHandle_t cudnn_handle_;
void* workspace_;
size_t workspace_len_;
const cudaStream_t* stream_; // not owned;
const CUDAPlace place_;
std::mutex mtx_;
};
class CudnnWorkspaceHandle {
public:
/*! \brief The lock would not be acquired when constructor calls.
* The lock would be acquired when RunFunc() is called first time. */
inline explicit CudnnWorkspaceHandle(CudnnHolder* holder) : holder_(holder) {}
/*! \brief Thread which call RunFunc() would acquire the lock first
* before invoking cudnn functions. */
template <typename Callback>
inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_len) {
if (!guard_) {
guard_.reset(new std::lock_guard<std::mutex>(holder_->Mutex()));
}
holder_->RunFuncImpl(std::forward<Callback>(cudnn_func),
required_workspace_len);
}
CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;
private:
CudnnHolder* holder_; // not own
std::unique_ptr<std::lock_guard<std::mutex>> guard_;
};
class CUDADeviceContext : public DeviceContext {
public:
......@@ -101,10 +154,14 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle() const;
/*! \brief Run a cudnn function with the workspace provided by
* CUDADeviceContext */
void RunCudnnFuncWithWorkspace(const std::function<void(void*)>& cudnn_func,
size_t workspace_len) const;
/*! \brief Return a cudnn workspace handle to call multiple cudnn
* functions without interrupting by other threads.
* Once the first cudnn function is called by the handle, a lock
* would be acquired to prevent other threads from accessing the
* workspace. Once the handle is destructed, the lock would be released.
* CudnnWorkspaceHandle is an RAII object to implement thread-safe
* sequential cudnn function calls. */
CudnnWorkspaceHandle cudnn_workspace_handle() const;
/*! \brief Return cuda stream in the device context. */
cudaStream_t stream() const;
......
......@@ -821,13 +821,24 @@ All parameter, weight, gradient are variables in Paddle.
[](BuildStrategy &self, bool b) {
self.enable_data_balance_ = b;
}) // FIXME(chengudo): enable_data_balance seems not important
.def_property("enable_sequential_execution",
[](const BuildStrategy &self) {
return self.enable_sequential_execution_;
},
[](BuildStrategy &self, bool b) {
self.enable_sequential_execution_ = b;
})
.def_property(
"enable_sequential_execution",
[](const BuildStrategy &self) {
return self.enable_sequential_execution_;
},
[](BuildStrategy &self, bool b) {
self.enable_sequential_execution_ = b;
},
R"DOC(The type is BOOL. If set True, the execution order of ops would be the same as what is in the program. Default False.)DOC")
.def_property(
"remove_unnecessary_lock",
[](const BuildStrategy &self) {
return self.remove_unnecessary_lock_;
},
[](BuildStrategy &self, bool b) {
self.remove_unnecessary_lock_ = b;
},
R"DOC(The type is BOOL. If set True, some locks in GPU ops would be released and ParallelExecutor would run faster. Default False.)DOC")
.def_property(
"fuse_elewise_add_act_ops",
[](const BuildStrategy &self) {
......
......@@ -18,6 +18,7 @@ import multiprocessing
import os
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
import time
import numpy as np
import math
......@@ -82,6 +83,8 @@ class TestParallelExecutorBase(unittest.TestCase):
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.enable_sequential_execution = enable_sequential_execution
if use_cuda and core.is_compiled_with_cuda():
build_strategy.remove_unnecessary_lock = True
if use_parallel_executor:
exe = fluid.ParallelExecutor(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册