Commit faac8a76 authored by sneaxiy

remove unnecessary code

test=develop
Parent 7ff320f8
cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
cc_library(op_handle_graph SRCS op_handle_graph.cc DEPS op_handle_base)
cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
@@ -31,9 +31,9 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
-cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_handle_graph multi_devices_helper)
+cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
-if(WITH_GPU)
+if (WITH_GPU)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
endif()
@@ -43,7 +43,7 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
-set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto modify_op_lock_and_record_event_pass sequential_execution_pass)
+set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass)
if (WITH_GPU)
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
endif()
......
@@ -69,6 +69,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Verify that the graph is correct for multi-device executor.
AppendPass("multi_devices_check_pass");
+    if (strategy_.remove_unnecessary_lock_) {
+      AppendPass("modify_op_lock_and_record_event_pass");
+    }
}
private:
@@ -136,3 +140,4 @@ USE_PASS(multi_devices_pass);
USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass);
USE_PASS(sequential_execution_pass);
+USE_PASS(modify_op_lock_and_record_event_pass);
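The `USE_PASS` declarations above force each pass's registration object to be linked in, so the pass can later be looked up in the registry by name (this is how `parallel_executor.cc` applied this pass manually before this commit; see the removal further down). A rough, self-contained sketch of such a name-to-factory registry; `Graph`, `Pass`, `PassRegistry`, and `NoopPass` below are illustrative stand-ins, not the actual `paddle::framework::ir` API:

```cpp
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Illustrative stand-in for a pass registry: passes register a factory
// under a name; executors later look the pass up by that name and apply it.
struct Graph {};  // stand-in for ir::Graph

class Pass {
 public:
  virtual ~Pass() = default;
  virtual void Apply(Graph* graph) const = 0;
};

class PassRegistry {
 public:
  static PassRegistry& Instance() {
    static PassRegistry instance;
    return instance;
  }
  void Register(const std::string& name,
                std::function<std::unique_ptr<Pass>()> factory) {
    factories_[name] = std::move(factory);
  }
  std::unique_ptr<Pass> Get(const std::string& name) const {
    return factories_.at(name)();
  }

 private:
  std::unordered_map<std::string, std::function<std::unique_ptr<Pass>()>>
      factories_;
};

class NoopPass : public Pass {
 public:
  void Apply(Graph*) const override { std::cout << "noop pass applied\n"; }
};

int main() {
  PassRegistry::Instance().Register(
      "noop_pass", [] { return std::unique_ptr<Pass>(new NoopPass); });
  Graph g;
  PassRegistry::Instance().Get("noop_pass")->Apply(&g);
  return 0;
}
```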
@@ -73,6 +73,8 @@ struct BuildStrategy {
bool fuse_broadcast_op_{false};
+  bool remove_unnecessary_lock_{false};
  // User normally doesn't need to call this API.
  // The PassBuilder allows more customized insertion and removal of passes
  // from the Python side.
......
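The comment above refers to `ir::PassBuilder`. As a hedged sketch of the idea, a builder is little more than an ordered list of pass names that callers (including the Python bindings) can append to, insert into, or remove from before the pipeline is materialized; `ToyPassBuilder` here is hypothetical, not the real class:

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for ir::PassBuilder: an ordered list of pass names
// supporting append/insert/remove, which callers customize before the
// pipeline is built.
class ToyPassBuilder {
 public:
  void AppendPass(const std::string& name) { passes_.push_back(name); }
  void InsertPass(std::size_t idx, const std::string& name) {
    passes_.insert(passes_.begin() + idx, name);
  }
  void RemovePass(std::size_t idx) { passes_.erase(passes_.begin() + idx); }
  const std::vector<std::string>& AllPasses() const { return passes_; }

 private:
  std::vector<std::string> passes_;
};

int main() {
  bool remove_unnecessary_lock = true;
  ToyPassBuilder builder;
  builder.AppendPass("multi_devices_check_pass");
  // Mirrors ParallelExecutorPassBuilder above: the lock pass is appended
  // only when the corresponding BuildStrategy flag is set.
  if (remove_unnecessary_lock) {
    builder.AppendPass("modify_op_lock_and_record_event_pass");
  }
  for (const auto& name : builder.AllPasses()) std::cout << name << "\n";
  return 0;
}
```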
@@ -20,13 +20,11 @@ namespace paddle {
namespace framework {
namespace details {
ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
-                                         platform::Place place,
-                                         size_t scope_idx)
+                                         platform::Place place)
    : OpHandleBase(node),
      op_(framework::OpRegistry::CreateOp(*node->Op())),
      scope_(scope),
-      place_(place),
-      scope_idx_(scope_idx) {}
+      place_(place) {}
void ComputationOpHandle::RunImpl() {
WaitInputVarGenerated(place_);
......
@@ -28,8 +28,7 @@ namespace framework {
namespace details {
struct ComputationOpHandle : public OpHandleBase {
public:
-  ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place,
-                      size_t scope_idx);
+  ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place);
std::string Name() const override;
@@ -37,12 +36,6 @@ struct ComputationOpHandle : public OpHandleBase {
const platform::Place &GetPlace() const { return place_; }
-  size_t GetScopeIdx() const { return scope_idx_; }
-  OperatorBase &GetOp() { return *op_; }
-  const OperatorBase &GetOp() const { return *op_; }
void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; }
protected:
@@ -54,7 +47,6 @@ struct ComputationOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_;
Scope *scope_;
platform::Place place_;
-  size_t scope_idx_;
bool is_lock_and_record_event_free_{false};
};
} // namespace details
......
@@ -15,20 +15,17 @@
#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_handle_graph.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
namespace paddle {
namespace framework {
namespace details {
-static ComputationOpHandle *ConvertToComputationOpHandle(OpHandleBase *op) {
-  return dynamic_cast<ComputationOpHandle *>(op);
-}
-
static bool IsLockAndRecordEventFreeComputationOpHandle(
-    ComputationOpHandle *op, const OpHandleGraph &graph) {
-  for (auto &pending_op : graph.PendingOps(op)) {
-    auto *tmp = ConvertToComputationOpHandle(pending_op);
+    ComputationOpHandle *op, const OpGraphView &graph_view) {
+  if (!platform::is_gpu_place(op->GetPlace())) return false;
+  for (auto &pending_op : graph_view.PendingOps(op)) {
+    auto *tmp = dynamic_cast<ComputationOpHandle *>(pending_op);
    if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
      return false;
    }
@@ -39,12 +36,12 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
std::unique_ptr<ir::Graph> ir_graph) const {
  auto &all_ops = ir_graph->Get<GraphOps>(kGraphOps);
-  OpHandleGraph graph(all_ops);
+  OpGraphView graph_view(all_ops);
  for (auto &op : all_ops) {
-    auto *compute_op = ConvertToComputationOpHandle(op.get());
+    auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
    if (compute_op == nullptr) continue;
    bool is_lock_and_record_event_free =
-        IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph);
+        IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
    compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free);
    if (is_lock_and_record_event_free) {
      VLOG(10) << "Set is_lock_and_record_event_free be true in op "
......
@@ -556,7 +556,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
int dev_id) const {
result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()),
-                              local_scopes_[dev_id], places_[dev_id], dev_id));
+                              local_scopes_[dev_id], places_[dev_id]));
CreateOpHandleIOs(result, node, dev_id);
}
@@ -672,8 +672,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx];
-    result->Get<GraphOps>(kGraphOps).emplace_back(new ComputationOpHandle(
-        result->CreateOpNode(node->Op()), s, p, scope_idx));
+    result->Get<GraphOps>(kGraphOps).emplace_back(
+        new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
CreateOpHandleIOs(result, node, scope_idx);
}
}
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/op_graph_view.h"
#include <queue>
#include <utility>
namespace paddle {
namespace framework {
namespace details {
OpGraphView::OpGraphView(
const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
Build(ops);
}
void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
for (auto &op : ops) {
preceding_ops_[op.get()];
pending_ops_[op.get()];
for (auto &var : op->Outputs()) {
for (auto &pending_op : var->PendingOps()) {
preceding_ops_[pending_op].insert(op.get());
pending_ops_[op.get()].insert(pending_op);
}
}
}
PADDLE_ENFORCE(
preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(),
"There are duplicate ops in graph.");
}
size_t OpGraphView::OpNumber() const { return preceding_ops_.size(); }
std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
std::unordered_set<OpHandleBase *> ret;
for (auto &pair : preceding_ops_) {
ret.insert(pair.first);
}
return ret;
}
bool OpGraphView::HasOp(OpHandleBase *op) const {
return preceding_ops_.count(op) != 0;
}
void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpGraphView",
op == nullptr ? "nullptr" : op->DebugString());
}
const std::unordered_set<OpHandleBase *> &OpGraphView::PrecedingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return preceding_ops_.at(op);
}
const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return pending_ops_.at(op);
}
} // namespace details
} // namespace framework
} // namespace paddle
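`op_graph_view.cc` above is the new, slimmed-down replacement for the removed `op_handle_graph.cc` (shown further down): it keeps only the preceding/pending adjacency queries. A minimal, self-contained sketch of what `Build()` computes, using plain strings in place of `OpHandleBase *` (all names here are illustrative):

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Toy analogue of OpGraphView::Build(): for each (producer -> consumer)
// edge, record the consumer as pending on the producer and the producer as
// preceding the consumer, so both directions can be queried in O(1).
using OpId = std::string;
using OpSet = std::unordered_set<OpId>;

struct ToyGraphView {
  std::unordered_map<OpId, OpSet> preceding, pending;

  ToyGraphView(const std::vector<OpId>& ops,
               const std::vector<std::pair<OpId, OpId>>& edges) {
    for (const auto& op : ops) {
      preceding[op];  // ensure isolated ops still get (empty) entries
      pending[op];
    }
    for (const auto& e : edges) {
      pending[e.first].insert(e.second);
      preceding[e.second].insert(e.first);
    }
    // Same invariant the PADDLE_ENFORCE checks: duplicate ops would make
    // the maps smaller than the op list.
    assert(preceding.size() == ops.size() && pending.size() == ops.size());
  }
};

int main() {
  ToyGraphView view({"conv", "relu", "fc"}, {{"conv", "relu"}, {"relu", "fc"}});
  assert(view.pending.at("conv").count("relu") == 1);
  assert(view.preceding.at("fc").count("relu") == 1);
  return 0;
}
```

The duplicate check mirrors the `PADDLE_ENFORCE` in `Build()`: inserting the same op twice leaves the maps smaller than the input list.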
@@ -24,11 +24,9 @@ namespace paddle {
namespace framework {
namespace details {
-class OpHandleGraph {
+class OpGraphView {
 public:
-  enum Relation { kSame = 0, kBefore = 1, kAfter = 2, kNoDeps = 3 };
-
-  explicit OpHandleGraph(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
+  explicit OpGraphView(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
size_t OpNumber() const;
@@ -39,42 +37,11 @@ class OpHandleGraph {
const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;
-  std::vector<std::unordered_set<OpHandleBase *>> AllPrecedingOps(
-      OpHandleBase *op) const;
-
-  std::vector<std::unordered_set<OpHandleBase *>> AllPendingOps(
-      OpHandleBase *op) const;
-
  bool HasOp(OpHandleBase *op) const;
-
-  Relation RelationBetween(OpHandleBase *op1, OpHandleBase *op2) const;
-  bool IsSame(OpHandleBase *op1, OpHandleBase *op2) const;
-  bool IsBeforeOrSame(OpHandleBase *op1, OpHandleBase *op2) const;
-  bool IsBefore(OpHandleBase *op1, OpHandleBase *op2) const;
-  bool IsAfterOrSame(OpHandleBase *op1, OpHandleBase *op2) const;
-  bool IsAfter(OpHandleBase *op1, OpHandleBase *op2) const;
-  bool IsNoDeps(OpHandleBase *op1, OpHandleBase *op2) const;
-
-  OpHandleBase *NearestCommonParent(OpHandleBase *op1, OpHandleBase *op2) const;
-
-  // Find an operator that is after op and before op1, op2
-  OpHandleBase *NearestCommonParentAfter(OpHandleBase *op, OpHandleBase *op1,
-                                         OpHandleBase *op2) const;
-
-  std::unordered_set<OpHandleBase *> NoPendingOpSet() const;
-  std::unordered_set<OpHandleBase *> NoPrecedingOpSet() const;

 private:
-  void BuildGraph(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
+  void Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
  void EnforceHasOp(OpHandleBase *op) const;
-  bool IsBeforeOrSameImpl(OpHandleBase *op1, OpHandleBase *op2) const;

  std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
      preceding_ops_;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/op_handle_graph.h"
#include <queue>
#include <utility>
namespace paddle {
namespace framework {
namespace details {
OpHandleGraph::OpHandleGraph(
const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
BuildGraph(ops);
}
void OpHandleGraph::BuildGraph(
const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
for (auto &op : ops) {
preceding_ops_[op.get()];
pending_ops_[op.get()];
for (auto &var : op->Outputs()) {
for (auto &pending_op : var->PendingOps()) {
preceding_ops_[pending_op].insert(op.get());
pending_ops_[op.get()].insert(pending_op);
}
}
}
PADDLE_ENFORCE(
preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(),
"There are duplicate ops in graph.");
}
size_t OpHandleGraph::OpNumber() const { return preceding_ops_.size(); }
std::unordered_set<OpHandleBase *> OpHandleGraph::AllOps() const {
std::unordered_set<OpHandleBase *> ret;
for (auto &pair : preceding_ops_) {
ret.insert(pair.first);
}
return ret;
}
bool OpHandleGraph::HasOp(OpHandleBase *op) const {
return preceding_ops_.count(op) != 0;
}
void OpHandleGraph::EnforceHasOp(OpHandleBase *op) const {
PADDLE_ENFORCE(HasOp(op), "Cannot found op %s in OpHandleGraph",
op == nullptr ? "nullptr" : op->DebugString());
}
const std::unordered_set<OpHandleBase *> &OpHandleGraph::PrecedingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return preceding_ops_.at(op);
}
const std::unordered_set<OpHandleBase *> &OpHandleGraph::PendingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return pending_ops_.at(op);
}
std::vector<std::unordered_set<OpHandleBase *>> OpHandleGraph::AllPrecedingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
std::queue<OpHandleBase *> queue[2];
int cur = 0;
std::unordered_set<OpHandleBase *> visited_ops;
std::vector<std::unordered_set<OpHandleBase *>> ret;
for (auto &tmp : preceding_ops_.at(op)) {
queue[cur].push(tmp);
visited_ops.insert(tmp);
}
  while (!queue[cur].empty()) {
    std::unordered_set<OpHandleBase *> cur_level_ops;
    // Drain the whole current frontier before swapping queues, so that each
    // entry of ret holds exactly one BFS level and the traversal does not
    // stop while the other queue is still non-empty.
    while (!queue[cur].empty()) {
      auto *tmp = queue[cur].front();
      queue[cur].pop();
      for (auto &preceding_op : preceding_ops_.at(tmp)) {
        if (visited_ops.count(preceding_op)) {
          continue;
        }
        queue[1 - cur].push(preceding_op);
        cur_level_ops.insert(preceding_op);
        visited_ops.insert(preceding_op);
      }
    }
    if (!cur_level_ops.empty()) {
      ret.emplace_back(std::move(cur_level_ops));
    }
    cur = 1 - cur;
  }
return ret;
}
std::vector<std::unordered_set<OpHandleBase *>> OpHandleGraph::AllPendingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
std::queue<OpHandleBase *> queue[2];
int cur = 0;
std::unordered_set<OpHandleBase *> visited_ops;
std::vector<std::unordered_set<OpHandleBase *>> ret;
  // Seed from pending_ops_: this function walks successors, not predecessors.
  for (auto &tmp : pending_ops_.at(op)) {
    queue[cur].push(tmp);
    visited_ops.insert(tmp);
  }
  while (!queue[cur].empty()) {
    std::unordered_set<OpHandleBase *> cur_level_ops;
    // Drain the whole current frontier before swapping queues, mirroring
    // AllPrecedingOps.
    while (!queue[cur].empty()) {
      auto *tmp = queue[cur].front();
      queue[cur].pop();
      for (auto &next_op : pending_ops_.at(tmp)) {
        if (visited_ops.count(next_op)) {
          continue;
        }
        queue[1 - cur].push(next_op);
        cur_level_ops.insert(next_op);
        visited_ops.insert(next_op);
      }
    }
    if (!cur_level_ops.empty()) {
      ret.emplace_back(std::move(cur_level_ops));
    }
    cur = 1 - cur;
  }
return ret;
}
OpHandleGraph::Relation OpHandleGraph::RelationBetween(
OpHandleBase *op1, OpHandleBase *op2) const {
EnforceHasOp(op1);
EnforceHasOp(op2);
if (op1 == op2) {
return kSame;
} else if (IsBeforeOrSameImpl(op1, op2)) {
return kBefore;
} else if (IsBeforeOrSameImpl(op2, op1)) {
return kAfter;
} else {
return kNoDeps;
}
}
bool OpHandleGraph::IsSame(OpHandleBase *op1, OpHandleBase *op2) const {
EnforceHasOp(op1);
EnforceHasOp(op2);
return op1 == op2;
}
bool OpHandleGraph::IsBeforeOrSame(OpHandleBase *op1, OpHandleBase *op2) const {
EnforceHasOp(op1);
EnforceHasOp(op2);
return IsBeforeOrSameImpl(op1, op2);
}
bool OpHandleGraph::IsBefore(OpHandleBase *op1, OpHandleBase *op2) const {
EnforceHasOp(op1);
EnforceHasOp(op2);
return op1 != op2 && IsBeforeOrSameImpl(op1, op2);
}
bool OpHandleGraph::IsBeforeOrSameImpl(OpHandleBase *op1,
OpHandleBase *op2) const {
std::queue<OpHandleBase *> queue;
// BFS
queue.push(op1);
do {
auto *op = queue.front();
queue.pop();
if (op == op2) return true;
for (auto &pending_op : pending_ops_.at(op)) {
queue.push(pending_op);
}
} while (!queue.empty());
return false;
}
bool OpHandleGraph::IsAfterOrSame(OpHandleBase *op1, OpHandleBase *op2) const {
EnforceHasOp(op1);
EnforceHasOp(op2);
return IsBeforeOrSameImpl(op2, op1);
}
bool OpHandleGraph::IsAfter(OpHandleBase *op1, OpHandleBase *op2) const {
return IsBefore(op2, op1);
}
bool OpHandleGraph::IsNoDeps(OpHandleBase *op1, OpHandleBase *op2) const {
return RelationBetween(op1, op2) == kNoDeps;
}
std::unordered_set<OpHandleBase *> OpHandleGraph::NoPendingOpSet() const {
std::unordered_set<OpHandleBase *> ret;
for (auto &pair : pending_ops_) {
if (pair.second.empty()) ret.insert(pair.first);
}
return ret;
}
std::unordered_set<OpHandleBase *> OpHandleGraph::NoPrecedingOpSet() const {
std::unordered_set<OpHandleBase *> ret;
for (auto &pair : preceding_ops_) {
if (pair.second.empty()) ret.insert(pair.first);
}
return ret;
}
OpHandleBase *OpHandleGraph::NearestCommonParent(OpHandleBase *op1,
OpHandleBase *op2) const {
EnforceHasOp(op1);
EnforceHasOp(op2);
  // FIXME(zjl): A brute-force O(2*n) algorithm here.
  // First, BFS all preceding ops of op1 and record them in set S.
  // Second, BFS all preceding ops of op2 and find whether any of them is in S.
std::unordered_set<OpHandleBase *> all_preceding_ops;
std::queue<OpHandleBase *> queue;
queue.push(op1);
do {
auto *op = queue.front();
queue.pop();
all_preceding_ops.insert(op);
for (auto &preceding_op : preceding_ops_.at(op)) {
queue.push(preceding_op);
}
} while (!queue.empty());
queue.push(op2);
do {
auto *op = queue.front();
queue.pop();
if (all_preceding_ops.count(op)) return op;
for (auto &preceding_op : preceding_ops_.at(op)) {
queue.push(preceding_op);
}
} while (!queue.empty());
return nullptr;
}
OpHandleBase *OpHandleGraph::NearestCommonParentAfter(OpHandleBase *op,
OpHandleBase *op1,
OpHandleBase *op2) const {
EnforceHasOp(op);
EnforceHasOp(op1);
EnforceHasOp(op2);
std::unordered_map<OpHandleBase *, int> all_preceding_ops;
int max_depth = -1;
std::queue<std::pair<OpHandleBase *, int>> queue;
queue.push(std::make_pair(op1, 0));
do {
auto tmp = queue.front();
queue.pop();
all_preceding_ops.insert(tmp);
    if (tmp.first == op) {  // stop once the upward BFS from op1 reaches op
max_depth = tmp.second;
break;
}
for (auto &preceding_op : preceding_ops_.at(tmp.first)) {
queue.push(std::make_pair(preceding_op, tmp.second + 1));
}
} while (!queue.empty());
if (max_depth == -1) {
return nullptr;
}
std::queue<OpHandleBase *> queue2;
queue2.push(op2);
  do {
    auto *tmp = queue2.front();
    queue2.pop();
    if (all_preceding_ops.count(tmp) &&
        (tmp == op || all_preceding_ops[tmp] < max_depth)) {
      return tmp;
    }
    // Keep walking upward from op2; without this the loop would only ever
    // examine op2 itself.
    for (auto &preceding_op : preceding_ops_.at(tmp)) {
      queue2.push(preceding_op);
    }
  } while (!queue2.empty());
return nullptr;
}
} // namespace details
} // namespace framework
} // namespace paddle
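The `AllPrecedingOps`/`AllPendingOps` helpers in the removed file use a two-queue, level-by-level BFS: the current frontier drains from `queue[cur]` while the next level accumulates in `queue[1 - cur]`, and the queues swap once a level is exhausted. A self-contained sketch of that pattern over a plain adjacency map (hypothetical names, not Paddle code):

```cpp
#include <iostream>
#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Level-by-level BFS with two alternating queues: queue[cur] holds the
// frontier being drained, queue[1 - cur] collects the next level; the roles
// swap when the current queue empties.
std::vector<std::unordered_set<std::string>> LeveledBfs(
    const std::unordered_map<std::string, std::unordered_set<std::string>>& adj,
    const std::string& start) {
  std::queue<std::string> queue[2];
  int cur = 0;
  std::unordered_set<std::string> visited{start};
  std::vector<std::unordered_set<std::string>> levels;
  queue[cur].push(start);
  while (!queue[cur].empty()) {
    std::unordered_set<std::string> next_level;
    while (!queue[cur].empty()) {
      auto node = queue[cur].front();
      queue[cur].pop();
      auto it = adj.find(node);
      if (it == adj.end()) continue;
      for (const auto& nb : it->second) {
        if (visited.insert(nb).second) {  // first time we see this node
          queue[1 - cur].push(nb);
          next_level.insert(nb);
        }
      }
    }
    if (!next_level.empty()) levels.emplace_back(std::move(next_level));
    cur = 1 - cur;
  }
  return levels;
}

int main() {
  std::unordered_map<std::string, std::unordered_set<std::string>> adj{
      {"a", {"b", "c"}}, {"b", {"d"}}, {"c", {"d"}}, {"d", {}}};
  for (const auto& level : LeveledBfs(adj, "a")) {
    for (const auto& n : level) std::cout << n << ' ';
    std::cout << '\n';  // prints "b c" then "d"
  }
  return 0;
}
```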
@@ -118,10 +118,6 @@ ParallelExecutor::ParallelExecutor(
main_program, member_->places_, loss_var_name, params,
member_->local_scopes_, member_->use_cuda_, member_->nccl_ctxs_.get());
-  graph = ir::PassRegistry::Instance()
-              .Get("modify_op_lock_and_record_event_pass")
-              ->Apply(std::move(graph));
auto max_memory_size = GetEagerDeletionThreshold();
if (max_memory_size >= 0) {
for (auto &place : member_->places_) {
@@ -149,10 +145,6 @@ ParallelExecutor::ParallelExecutor(
std::unique_ptr<ir::Graph> graph =
build_strategy.Apply(main_program, member_->places_, loss_var_name,
params, member_->local_scopes_, member_->use_cuda_);
-  graph = ir::PassRegistry::Instance()
-              .Get("modify_op_lock_and_record_event_pass")
-              ->Apply(std::move(graph));
#endif
// Step 3. Create vars in each scope. Passes may also create new vars.
@@ -331,8 +323,6 @@ ParallelExecutor::~ParallelExecutor() {
} // namespace framework
} // namespace paddle
-USE_PASS(modify_op_lock_and_record_event_pass);
#ifdef PADDLE_WITH_CUDA
USE_PASS(reference_count_pass);
#endif
@@ -153,83 +153,32 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
mutable unsigned int* semaphore_;
};
-class CudnnHolder {
- public:
-  CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
-      : workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
-    PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
-    PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
-  }
-
-  cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }
-
-  void RunFunc(const std::function<void(void*)>& cudnn_func,
-               size_t required_workspace_len) {
-    std::lock_guard<std::mutex> lock(mtx_);
-    RunFuncImpl(cudnn_func, required_workspace_len);
-  }
-
-  ~CudnnHolder() {
-    PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
-    if (workspace_ != nullptr) {
-      paddle::memory::Free(place_, workspace_);
-    }
-  }
-
- private:
-  std::mutex& Mutex() { return mtx_; }
-
-  void RunFuncImpl(const std::function<void(void*)>& cudnn_func,
-                   size_t required_workspace_len) {
-    if (required_workspace_len > workspace_len_) {
-      ReallocateWorkspace(required_workspace_len);
-    }
-    cudnn_func(workspace_);
-  }
-
-  void ReallocateWorkspace(size_t required_workspace_len) {
-    if (required_workspace_len <= workspace_len_) {
-      return;
-    }
-    if (workspace_ != nullptr) {
-      // Maybe someone is using the current workspace
-      PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
-      paddle::memory::Free(place_, workspace_);
-    }
-    workspace_ = paddle::memory::Alloc(place_, required_workspace_len);
-    workspace_len_ = required_workspace_len;
-  }
-
-  friend class CudnnWorkspaceHandle;
-
-  cudnnHandle_t cudnn_handle_;
-  void* workspace_;
-  size_t workspace_len_;
-  const cudaStream_t* stream_;  // not owned;
-  const CUDAPlace place_;
-  std::mutex mtx_;
-};
-
-CudnnWorkspaceHandle::CudnnWorkspaceHandle(CudnnHolder* holder)
-    : holder_(holder) {}
-
-void CudnnWorkspaceHandle::RunFunc(const std::function<void(void*)>& cudnn_func,
-                                   size_t required_workspace_len) {
-  // defer lock when the function is invoked first time
-  BeginCallGuard();
-  holder_->RunFuncImpl(cudnn_func, required_workspace_len);
-}
-
-void CudnnWorkspaceHandle::BeginCallGuard() {
-  if (!guard_) {
-    guard_.reset(new std::lock_guard<std::mutex>(holder_->Mutex()));
-  }
-}
-
-void CudnnWorkspaceHandle::EndCallGuard() { guard_.reset(); }
+CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
+    : workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
+  PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
+  PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
+}
+
+CudnnHolder::~CudnnHolder() {
+  PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
+  if (workspace_ != nullptr) {
+    paddle::memory::Free(place_, workspace_);
+  }
+}
+
+void CudnnHolder::ReallocateWorkspace(size_t required_workspace_len) {
+  if (required_workspace_len <= workspace_len_) {
+    return;
+  }
+  if (workspace_ != nullptr) {
+    // Maybe someone is using the current workspace
+    PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
+    paddle::memory::Free(place_, workspace_);
+  }
+  workspace_ = paddle::memory::Alloc(place_, required_workspace_len);
+  workspace_len_ = required_workspace_len;
+}
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
: place_(place), cudnn_holder_(nullptr) {
SetDeviceId(place_.device);
@@ -300,11 +249,6 @@ CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
return CudnnWorkspaceHandle(cudnn_holder_.get());
}
-void CUDADeviceContext::RunCudnnFuncWithWorkspace(
-    const std::function<void(void*)>& cudnn_func, size_t workspace_len) const {
-  cudnn_holder_->RunFunc(cudnn_func, workspace_len);
-}
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
......
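The reason `ReallocateWorkspace` synchronizes the stream before freeing: kernels already enqueued on the stream may still be reading the old workspace, so the free has to wait for them, while a request that fits the current buffer costs nothing. A standalone sketch of this grow-only idiom against the raw CUDA runtime API; `StreamWorkspace` is a hypothetical helper with error handling elided (Paddle itself allocates through `paddle::memory`):

```cpp
#include <cuda_runtime.h>
#include <cstddef>

// Grow-only device workspace: reallocation synchronizes the stream first so
// that in-flight kernels finish before their buffer is freed.
struct StreamWorkspace {
  void* ptr = nullptr;
  std::size_t len = 0;
  cudaStream_t stream;

  explicit StreamWorkspace(cudaStream_t s) : stream(s) {}

  void Reserve(std::size_t required) {
    if (required <= len) return;  // current buffer is already big enough
    if (ptr != nullptr) {
      // Kernels already enqueued on `stream` may still be using `ptr`.
      cudaStreamSynchronize(stream);
      cudaFree(ptr);
    }
    cudaMalloc(&ptr, required);
    len = required;
  }

  ~StreamWorkspace() {
    if (ptr != nullptr) {
      cudaStreamSynchronize(stream);
      cudaFree(ptr);
    }
  }
};

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  {
    StreamWorkspace ws(stream);
    ws.Reserve(1 << 20);  // grow to 1 MiB
    ws.Reserve(1 << 10);  // no-op: capacity never shrinks
  }
  cudaStreamDestroy(stream);
  return 0;
}
```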
@@ -73,29 +73,55 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
#ifdef PADDLE_WITH_CUDA
class EigenCudaStreamDevice;
-class CudnnHolder;
+class CudnnHolder {
+ public:
+  CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place);
+  ~CudnnHolder();
+
+  cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }
+
+ private:
+  friend class CudnnWorkspaceHandle;
+
+  void ReallocateWorkspace(size_t required_workspace_len);
+
+  template <typename Callback>
+  void RunFuncImpl(Callback&& cudnn_func, size_t required_workspace_len) {
+    if (required_workspace_len > workspace_len_) {
+      ReallocateWorkspace(required_workspace_len);
+    }
+    cudnn_func(workspace_);
+  }
+
+  std::mutex& Mutex() { return mtx_; }
+
+  cudnnHandle_t cudnn_handle_;
+  void* workspace_;
+  size_t workspace_len_;
+  const cudaStream_t* stream_;  // not owned;
+  const CUDAPlace place_;
+  std::mutex mtx_;
+};
class CudnnWorkspaceHandle {
public:
  /*! \brief The lock is not acquired when the constructor is called.
   * It is acquired the first time RunFunc() is called. */
-  explicit CudnnWorkspaceHandle(CudnnHolder* holder);
+  inline explicit CudnnWorkspaceHandle(CudnnHolder* holder) : holder_(holder) {}

  /*! \brief A thread calling RunFunc() acquires the lock before
   * invoking cudnn functions. */
-  void RunFunc(const std::function<void(void*)>& cudnn_func,
-               size_t required_workspace_len);
-
-  /*! \brief User can call this method to acquire the lock manually,
-   * but it is usually unnecessary, because RunFunc() acquires
-   * the lock itself before invoking cudnn functions. */
-  void BeginCallGuard();
+  template <typename Callback>
+  inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_len) {
+    if (!guard_) {
+      guard_.reset(new std::lock_guard<std::mutex>(holder_->Mutex()));
+    }
+    holder_->RunFuncImpl(std::forward<Callback>(cudnn_func),
+                         required_workspace_len);
+  }

-  /*! \brief User can call this method to release the lock manually,
-   * but it is usually unnecessary, because the lock is released
-   * once the handle is destructed. Calling it only releases the
-   * lock earlier. */
-  void EndCallGuard();
+  CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
+  CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;
private:
CudnnHolder* holder_; // not own
@@ -137,11 +163,6 @@ class CUDADeviceContext : public DeviceContext {
* sequential cudnn function calls. */
CudnnWorkspaceHandle cudnn_workspace_handle() const;
-  /*! \brief Run a cudnn function with the workspace provided by
-   * CUDADeviceContext */
-  void RunCudnnFuncWithWorkspace(const std::function<void(void*)>& cudnn_func,
-                                 size_t workspace_len) const;
/*! \brief Return cuda stream in the device context. */
cudaStream_t stream() const;
......
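The reworked `CudnnWorkspaceHandle` is a lazy scoped lock: construction takes no lock, the first `RunFunc()` acquires the holder's mutex through a heap-allocated `std::lock_guard`, and destroying the handle releases it, which is why the explicit `BeginCallGuard()`/`EndCallGuard()` pair could be dropped. A self-contained sketch of the pattern with generic names and no cuDNN dependency:

```cpp
#include <iostream>
#include <memory>
#include <mutex>
#include <utility>

// A resource whose users must hold a mutex across a *sequence* of calls.
struct Holder {
  std::mutex mtx;
  int workspace = 0;
};

// Lazy scoped lock: the mutex is taken on the first Run() call, not at
// construction, and released when the handle goes out of scope.
class WorkspaceHandle {
 public:
  explicit WorkspaceHandle(Holder* holder) : holder_(holder) {}

  template <typename Callback>
  void Run(Callback&& func) {
    if (!guard_) {  // first call: acquire the lock and keep holding it
      guard_.reset(new std::lock_guard<std::mutex>(holder_->mtx));
    }
    std::forward<Callback>(func)(&holder_->workspace);
  }

  WorkspaceHandle(WorkspaceHandle&&) = default;
  WorkspaceHandle& operator=(WorkspaceHandle&&) = delete;

 private:
  Holder* holder_;  // not owned
  std::unique_ptr<std::lock_guard<std::mutex>> guard_;
};

int main() {
  Holder holder;
  {
    WorkspaceHandle handle(&holder);        // no lock yet
    handle.Run([](int* ws) { *ws = 42; });  // lock acquired here
    handle.Run([](int* ws) { std::cout << *ws << "\n"; });  // still held
  }  // lock released when the handle is destroyed
  return 0;
}
```

The design keeps one thread's consecutive calls under a single lock acquisition, so a sequence of cudnn calls can share the workspace without re-locking, while other threads wait until the handle is destroyed.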
@@ -821,13 +821,24 @@ All parameter, weight, gradient are variables in Paddle.
[](BuildStrategy &self, bool b) {
self.enable_data_balance_ = b;
}) // FIXME(chengudo): enable_data_balance seems not important
.def_property("enable_sequential_execution",
[](const BuildStrategy &self) {
return self.enable_sequential_execution_;
},
[](BuildStrategy &self, bool b) {
self.enable_sequential_execution_ = b;
})
.def_property(
"enable_sequential_execution",
[](const BuildStrategy &self) {
return self.enable_sequential_execution_;
},
[](BuildStrategy &self, bool b) {
self.enable_sequential_execution_ = b;
},
R"DOC(The type is BOOL. If set True, the execution order of ops would be the same as what is in the program. Default False.)DOC")
.def_property(
"remove_unnecessary_lock",
[](const BuildStrategy &self) {
return self.remove_unnecessary_lock_;
},
[](BuildStrategy &self, bool b) {
self.remove_unnecessary_lock_ = b;
},
R"DOC(The type is BOOL. If set True, some locks in GPU ops would be released and ParallelExecutor would run faster. Default False.)DOC")
.def_property(
"fuse_elewise_add_act_ops",
[](const BuildStrategy &self) {
......
@@ -18,6 +18,7 @@ import multiprocessing
import os
import unittest
import paddle.fluid as fluid
+import paddle.fluid.core as core
import time
import numpy as np
import math
@@ -82,6 +83,8 @@ class TestParallelExecutorBase(unittest.TestCase):
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.enable_sequential_execution = enable_sequential_execution
+        if use_cuda and core.is_compiled_with_cuda():
+            build_strategy.remove_unnecessary_lock = True
if use_parallel_executor:
exe = fluid.ParallelExecutor(
......