Commit 5be6f762 authored by sneaxiy

remove_lock_in_some_ops

test=develop
Parent 88376697
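In short, this commit lets a ComputationOpHandle skip the device-context lock and the CUDA event record when that is safe: a new modify_op_lock_and_record_event_pass flags an op as "lock and record-event free" when every op consuming its outputs is another ComputationOpHandle on the same place. A condensed sketch of the resulting dispatch, with names taken from the computation_op_handle.cc hunk below:

  // Condensed from ComputationOpHandle::RunImpl() as changed in this commit.
  auto run_func = [this]() {
    op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
  };
  if (is_lock_and_record_event_free_) {
    run_func();  // no lock, no CUDA event record
  } else {
    this->RunAndRecordEvent(run_func);  // original locked path
  }

The commit also replaces CUDADeviceContext::RunCudnnFuncWithWorkspace calls in the cudnn conv kernels with a CudnnWorkspaceHandle, so each kernel invocation acquires the workspace lock once rather than once per cudnn call.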
cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
+cc_library(op_handle_graph SRCS op_handle_graph.cc DEPS op_handle_base)
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
@@ -28,6 +29,8 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
+cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_handle_graph multi_devices_helper)
if(WITH_GPU)
  cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
    all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
@@ -37,9 +40,9 @@ cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_
  scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle)
if(WITH_GPU)
-  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass)
+  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass modify_op_lock_and_record_event_pass)
else()
-  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
+  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto modify_op_lock_and_record_event_pass)
endif()
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
...
@@ -20,18 +20,26 @@ namespace paddle {
namespace framework {
namespace details {
ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
-                                         platform::Place place)
+                                         platform::Place place,
+                                         size_t scope_idx)
    : OpHandleBase(node),
      op_(framework::OpRegistry::CreateOp(*node->Op())),
      scope_(scope),
-      place_(place) {}
+      place_(place),
+      scope_idx_(scope_idx) {}
void ComputationOpHandle::RunImpl() {
  WaitInputVarGenerated(place_);
-  this->RunAndRecordEvent([this] {
+  auto run_func = [this]() {
    op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
-  });
+  };
+
+  if (is_lock_and_record_event_free_) {
+    run_func();
+  } else {
+    this->RunAndRecordEvent(run_func);
+  }
}
bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
...
@@ -28,7 +28,8 @@ namespace framework {
namespace details {
struct ComputationOpHandle : public OpHandleBase {
 public:
-  ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place);
+  ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place,
+                      size_t scope_idx);
  std::string Name() const override;
@@ -36,6 +37,14 @@ struct ComputationOpHandle : public OpHandleBase {
  const platform::Place &GetPlace() const { return place_; }
+
+  size_t GetScopeIdx() const { return scope_idx_; }
+
+  OperatorBase &GetOp() { return *op_; }
+
+  const OperatorBase &GetOp() const { return *op_; }
+
+  void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; }
 protected:
  void RunImpl() override;
@@ -45,6 +54,8 @@ struct ComputationOpHandle : public OpHandleBase {
  std::unique_ptr<OperatorBase> op_;
  Scope *scope_;
  platform::Place place_;
+  size_t scope_idx_{0};
+  bool is_lock_and_record_event_free_{false};
};
}  // namespace details
}  // namespace framework
...
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_handle_graph.h"
namespace paddle {
namespace framework {
namespace details {
static ComputationOpHandle *ConvertToComputationOpHandle(OpHandleBase *op) {
return dynamic_cast<ComputationOpHandle *>(op);
}
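// An op is treated as "lock and record-event free" when every op that
// consumes one of its outputs is also a ComputationOpHandle on the same
// place: such a consumer is dispatched to the same device context and
// stream, so it needs no CUDA event to wait on.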
static bool IsLockAndRecordEventFreeComputationOpHandle(
ComputationOpHandle *op, const OpHandleGraph &graph) {
for (auto &pending_op : graph.PendingOps(op)) {
auto *tmp = ConvertToComputationOpHandle(pending_op);
if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
return false;
}
}
return true;
}
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
std::unique_ptr<ir::Graph> ir_graph) const {
auto &all_ops = ir_graph->Get<GraphOps>(kGraphOps);
OpHandleGraph graph(all_ops);
for (auto &op : all_ops) {
auto *compute_op = ConvertToComputationOpHandle(op.get());
if (compute_op == nullptr) continue;
bool is_lock_and_record_event_free =
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph);
compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free);
if (is_lock_and_record_event_free) {
VLOG(10) << "Set is_lock_and_record_event_free be true in op "
<< compute_op->DebugString();
}
}
return ir_graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(modify_op_lock_and_record_event_pass,
paddle::framework::details::ModifyOpLockAndRecordEventPass);
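For reference, a registered pass is looked up by name and applied through ir::PassRegistry; a minimal sketch of the call site, mirroring the parallel_executor.cc hunk later in this commit:

  // Sketch: resolve the pass registered above by name and run it on an
  // existing std::unique_ptr<ir::Graph>.
  graph = ir::PassRegistry::Instance()
              .Get("modify_op_lock_and_record_event_pass")
              ->Apply(std::move(graph));
  // A USE_PASS(modify_op_lock_and_record_event_pass) declaration in the
  // calling translation unit keeps this registration from being stripped.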
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
class ModifyOpLockAndRecordEventPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
@@ -513,7 +513,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
                                                    int dev_id) const {
  result->Get<GraphOps>(kGraphOps).emplace_back(
      new ComputationOpHandle(result->CreateOpNode(node->Op()),
-                              local_scopes_[dev_id], places_[dev_id]));
+                              local_scopes_[dev_id], places_[dev_id], dev_id));
  CreateOpHandleIOs(result, node, dev_id);
}
@@ -630,8 +630,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
  for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
    auto p = places_[scope_idx];
    auto s = local_scopes_[scope_idx];
-    result->Get<GraphOps>(kGraphOps).emplace_back(
-        new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
+    result->Get<GraphOps>(kGraphOps).emplace_back(new ComputationOpHandle(
+        result->CreateOpNode(node->Op()), s, p, scope_idx));
    CreateOpHandleIOs(result, node, scope_idx);
  }
}
...
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/op_handle_graph.h"
#include <queue>
#include <utility>
namespace paddle {
namespace framework {
namespace details {
OpHandleGraph::OpHandleGraph(
const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
BuildGraph(ops);
}
void OpHandleGraph::BuildGraph(
const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
for (auto &op : ops) {
preceding_ops_[op.get()];
pending_ops_[op.get()];
for (auto &var : op->Outputs()) {
for (auto &pending_op : var->PendingOps()) {
preceding_ops_[pending_op].insert(op.get());
pending_ops_[op.get()].insert(pending_op);
}
}
}
PADDLE_ENFORCE(
preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(),
"There are duplicate ops in graph.");
}
size_t OpHandleGraph::OpNumber() const { return preceding_ops_.size(); }
std::unordered_set<OpHandleBase *> OpHandleGraph::AllOps() const {
std::unordered_set<OpHandleBase *> ret;
for (auto &pair : preceding_ops_) {
ret.insert(pair.first);
}
return ret;
}
bool OpHandleGraph::HasOp(OpHandleBase *op) const {
return preceding_ops_.count(op) != 0;
}
void OpHandleGraph::EnforceHasOp(OpHandleBase *op) const {
  PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpHandleGraph",
op == nullptr ? "nullptr" : op->DebugString());
}
const std::unordered_set<OpHandleBase *> &OpHandleGraph::PrecedingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return preceding_ops_.at(op);
}
const std::unordered_set<OpHandleBase *> &OpHandleGraph::PendingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return pending_ops_.at(op);
}
std::vector<std::unordered_set<OpHandleBase *>> OpHandleGraph::AllPrecedingOps(
    OpHandleBase *op) const {
  EnforceHasOp(op);
  // Level-order BFS over preceding_ops_: queue[cur] holds the current level
  // and queue[1 - cur] collects the next one.
  std::queue<OpHandleBase *> queue[2];
  int cur = 0;
  std::unordered_set<OpHandleBase *> visited_ops;
  std::vector<std::unordered_set<OpHandleBase *>> ret;
  for (auto &tmp : preceding_ops_.at(op)) {
    queue[cur].push(tmp);
    visited_ops.insert(tmp);
  }
  while (!queue[cur].empty()) {
    std::unordered_set<OpHandleBase *> cur_level_ops;
    // Drain the whole current level before swapping queues.
    while (!queue[cur].empty()) {
      auto *tmp = queue[cur].front();
      queue[cur].pop();
      cur_level_ops.insert(tmp);
      for (auto &preceding_op : preceding_ops_.at(tmp)) {
        if (visited_ops.count(preceding_op) == 0) {
          queue[1 - cur].push(preceding_op);
          visited_ops.insert(preceding_op);
        }
      }
    }
    if (!cur_level_ops.empty()) {
      ret.emplace_back(std::move(cur_level_ops));
    }
    cur = 1 - cur;
  }
  return ret;
}
std::vector<std::unordered_set<OpHandleBase *>> OpHandleGraph::AllPendingOps(
    OpHandleBase *op) const {
  EnforceHasOp(op);
  // Same level-order BFS, but following pending_ops_ (downstream edges),
  // so the traversal is seeded with the direct pending ops of op.
  std::queue<OpHandleBase *> queue[2];
  int cur = 0;
  std::unordered_set<OpHandleBase *> visited_ops;
  std::vector<std::unordered_set<OpHandleBase *>> ret;
  for (auto &tmp : pending_ops_.at(op)) {
    queue[cur].push(tmp);
    visited_ops.insert(tmp);
  }
  while (!queue[cur].empty()) {
    std::unordered_set<OpHandleBase *> cur_level_ops;
    while (!queue[cur].empty()) {
      auto *tmp = queue[cur].front();
      queue[cur].pop();
      cur_level_ops.insert(tmp);
      for (auto &pending_op : pending_ops_.at(tmp)) {
        if (visited_ops.count(pending_op) == 0) {
          queue[1 - cur].push(pending_op);
          visited_ops.insert(pending_op);
        }
      }
    }
    if (!cur_level_ops.empty()) {
      ret.emplace_back(std::move(cur_level_ops));
    }
    cur = 1 - cur;
  }
  return ret;
}
OpHandleGraph::Relation OpHandleGraph::RelationBetween(
OpHandleBase *op1, OpHandleBase *op2) const {
EnforceHasOp(op1);
EnforceHasOp(op2);
if (op1 == op2) {
return kSame;
} else if (IsBeforeOrSameImpl(op1, op2)) {
return kBefore;
} else if (IsBeforeOrSameImpl(op2, op1)) {
return kAfter;
} else {
return kNoDeps;
}
}
bool OpHandleGraph::IsSame(OpHandleBase *op1, OpHandleBase *op2) const {
EnforceHasOp(op1);
EnforceHasOp(op2);
return op1 == op2;
}
bool OpHandleGraph::IsBeforeOrSame(OpHandleBase *op1, OpHandleBase *op2) const {
EnforceHasOp(op1);
EnforceHasOp(op2);
return IsBeforeOrSameImpl(op1, op2);
}
bool OpHandleGraph::IsBefore(OpHandleBase *op1, OpHandleBase *op2) const {
EnforceHasOp(op1);
EnforceHasOp(op2);
return op1 != op2 && IsBeforeOrSameImpl(op1, op2);
}
bool OpHandleGraph::IsBeforeOrSameImpl(OpHandleBase *op1,
OpHandleBase *op2) const {
std::queue<OpHandleBase *> queue;
// BFS
queue.push(op1);
do {
auto *op = queue.front();
queue.pop();
if (op == op2) return true;
for (auto &pending_op : pending_ops_.at(op)) {
queue.push(pending_op);
}
} while (!queue.empty());
return false;
}
bool OpHandleGraph::IsAfterOrSame(OpHandleBase *op1, OpHandleBase *op2) const {
EnforceHasOp(op1);
EnforceHasOp(op2);
return IsBeforeOrSameImpl(op2, op1);
}
bool OpHandleGraph::IsAfter(OpHandleBase *op1, OpHandleBase *op2) const {
return IsBefore(op2, op1);
}
bool OpHandleGraph::IsNoDeps(OpHandleBase *op1, OpHandleBase *op2) const {
return RelationBetween(op1, op2) == kNoDeps;
}
std::unordered_set<OpHandleBase *> OpHandleGraph::NoPendingOpSet() const {
std::unordered_set<OpHandleBase *> ret;
for (auto &pair : pending_ops_) {
if (pair.second.empty()) ret.insert(pair.first);
}
return ret;
}
std::unordered_set<OpHandleBase *> OpHandleGraph::NoPrecedingOpSet() const {
std::unordered_set<OpHandleBase *> ret;
for (auto &pair : preceding_ops_) {
if (pair.second.empty()) ret.insert(pair.first);
}
return ret;
}
OpHandleBase *OpHandleGraph::NearestCommonParent(OpHandleBase *op1,
OpHandleBase *op2) const {
EnforceHasOp(op1);
EnforceHasOp(op2);
// FIXME(zjl): A brute-force O(2*n) algorithm here
// First, BFS all preceding_ops of op1 and record them in set S
// Second, BFS all preceding_ops of op2 and found whether it is in set S
std::unordered_set<OpHandleBase *> all_preceding_ops;
std::queue<OpHandleBase *> queue;
queue.push(op1);
do {
auto *op = queue.front();
queue.pop();
all_preceding_ops.insert(op);
for (auto &preceding_op : preceding_ops_.at(op)) {
queue.push(preceding_op);
}
} while (!queue.empty());
queue.push(op2);
do {
auto *op = queue.front();
queue.pop();
if (all_preceding_ops.count(op)) return op;
for (auto &preceding_op : preceding_ops_.at(op)) {
queue.push(preceding_op);
}
} while (!queue.empty());
return nullptr;
}
OpHandleBase *OpHandleGraph::NearestCommonParentAfter(OpHandleBase *op,
                                                      OpHandleBase *op1,
                                                      OpHandleBase *op2) const {
  EnforceHasOp(op);
  EnforceHasOp(op1);
  EnforceHasOp(op2);
  // First, BFS upwards from op1, recording each visited op with its depth,
  // until op itself is reached.
  std::unordered_map<OpHandleBase *, int> all_preceding_ops;
  int max_depth = -1;
  std::queue<std::pair<OpHandleBase *, int>> queue;
  queue.push(std::make_pair(op1, 0));
  do {
    auto tmp = queue.front();
    queue.pop();
    all_preceding_ops.insert(tmp);
    if (tmp.first == op) {
      max_depth = tmp.second;
      break;
    }
    for (auto &preceding_op : preceding_ops_.at(tmp.first)) {
      queue.push(std::make_pair(preceding_op, tmp.second + 1));
    }
  } while (!queue.empty());
  if (max_depth == -1) {
    return nullptr;
  }
  // Second, BFS upwards from op2 and return the first op that is also an
  // ancestor of op1 no deeper than op.
  std::queue<OpHandleBase *> queue2;
  queue2.push(op2);
  do {
    auto *tmp = queue2.front();
    queue2.pop();
    if (all_preceding_ops.count(tmp) &&
        (tmp == op || all_preceding_ops[tmp] < max_depth)) {
      return tmp;
    }
    for (auto &preceding_op : preceding_ops_.at(tmp)) {
      queue2.push(preceding_op);
    }
  } while (!queue2.empty());
  return nullptr;
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
namespace paddle {
namespace framework {
namespace details {
class OpHandleGraph {
public:
enum Relation { kSame = 0, kBefore = 1, kAfter = 2, kNoDeps = 3 };
explicit OpHandleGraph(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
size_t OpNumber() const;
std::unordered_set<OpHandleBase *> AllOps() const;
const std::unordered_set<OpHandleBase *> &PrecedingOps(
OpHandleBase *op) const;
const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;
std::vector<std::unordered_set<OpHandleBase *>> AllPrecedingOps(
OpHandleBase *op) const;
std::vector<std::unordered_set<OpHandleBase *>> AllPendingOps(
OpHandleBase *op) const;
bool HasOp(OpHandleBase *op) const;
Relation RelationBetween(OpHandleBase *op1, OpHandleBase *op2) const;
bool IsSame(OpHandleBase *op1, OpHandleBase *op2) const;
bool IsBeforeOrSame(OpHandleBase *op1, OpHandleBase *op2) const;
bool IsBefore(OpHandleBase *op1, OpHandleBase *op2) const;
bool IsAfterOrSame(OpHandleBase *op1, OpHandleBase *op2) const;
bool IsAfter(OpHandleBase *op1, OpHandleBase *op2) const;
bool IsNoDeps(OpHandleBase *op1, OpHandleBase *op2) const;
OpHandleBase *NearestCommonParent(OpHandleBase *op1, OpHandleBase *op2) const;
// Find an operator that is after op and before op1, op2
OpHandleBase *NearestCommonParentAfter(OpHandleBase *op, OpHandleBase *op1,
OpHandleBase *op2) const;
std::unordered_set<OpHandleBase *> NoPendingOpSet() const;
std::unordered_set<OpHandleBase *> NoPrecedingOpSet() const;
private:
void BuildGraph(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
void EnforceHasOp(OpHandleBase *op) const;
bool IsBeforeOrSameImpl(OpHandleBase *op1, OpHandleBase *op2) const;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
preceding_ops_;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
pending_ops_;
};
} // namespace details
} // namespace framework
} // namespace paddle
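A brief usage sketch of OpHandleGraph, assuming `ops` is the graph's GraphOps attribute as in modify_op_lock_and_record_event_pass.cc above:

  // Sketch: build the dependency view once, then query it.
  OpHandleGraph dep_graph(ops);
  for (OpHandleBase *op : dep_graph.AllOps()) {
    for (OpHandleBase *consumer : dep_graph.PendingOps(op)) {
      // Each consumer reads one of op's outputs, so it always runs after op.
      PADDLE_ENFORCE(dep_graph.IsBefore(op, consumer));
    }
  }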
@@ -51,7 +51,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
    dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
        platform::DeviceContextPool::Instance().Get(place));
    if (IsStreamGarabageCollector()) {
-      PADDLE_ENFORCE(cudaSetDevice(place.device));
+      platform::SetDeviceId(place.device);
      PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
    }
@@ -61,7 +61,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
  ~ReferenceCountOpHandle() {
    if (IsStreamGarabageCollector()) {
      auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
-      PADDLE_ENFORCE(cudaSetDevice(gpu_place.device));
+      platform::SetDeviceId(gpu_place.device);
      PADDLE_ENFORCE(cudaEventDestroy(event_));
    }
  }
...
@@ -43,6 +43,23 @@ static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) {
  return nullptr;
}
+
+static void AddDependencyBetween(OpHandleBase *in, OpHandleBase *out,
+                                 ir::Graph *graph) {
+  auto it = std::find_if(
+      in->Outputs().begin(), in->Outputs().end(), [](VarHandleBase *var) {
+        return dynamic_cast<DummyVarHandle *>(var) != nullptr;
+      });
+
+  if (it != in->Outputs().end()) {
+    out->AddInput(*it);
+  } else {
+    auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
+    graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
+    in->AddOutput(dep_var);
+    out->AddInput(dep_var);
+  }
+}
+
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount);
@@ -133,12 +150,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
        auto *ref_cnt_handle = new ReferenceCountOpHandle(
            ref_cnt_node, next_compute_op->GetScope(), place, {var_name},
            gcs[place.device].get(), cur_ref_cnts[place.device].get());
-        if (next_compute_op->Outputs().empty()) {
-          auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
-          next_compute_op->AddOutput(dep_var);
-          graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
-        }
-        ref_cnt_handle->AddInput(next_compute_op->Outputs().front());
+        AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get());
        compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle);
      }
    }
@@ -160,12 +172,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
      auto *ref_cnt_handle = new ReferenceCountOpHandle(
          ref_cnt_node, compute_op->GetScope(), place, in_var_names,
          gcs[place.device].get(), cur_ref_cnts[place.device].get());
-      if (compute_op->Outputs().empty()) {
-        auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
-        compute_op->AddOutput(dep_var);
-        graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
-      }
-      ref_cnt_handle->AddInput(compute_op->Outputs().front());
+      AddDependencyBetween(compute_op, ref_cnt_handle, graph.get());
      compute_ref_cnt_map[compute_op].reset(ref_cnt_handle);
    }
...
@@ -156,6 +156,10 @@ ParallelExecutor::ParallelExecutor(
        params, member_->local_scopes_, member_->use_cuda_);
#endif
+
+  graph = ir::PassRegistry::Instance()
+              .Get("modify_op_lock_and_record_event_pass")
+              ->Apply(std::move(graph));
  // If the loss_var_name is given, the number of graph should be only one.
  if (loss_var_name.size()) {
    PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1,
@@ -319,6 +323,8 @@ ParallelExecutor::~ParallelExecutor() {
}  // namespace framework
}  // namespace paddle
+
+USE_PASS(modify_op_lock_and_record_event_pass);
#ifdef PADDLE_WITH_CUDA
USE_PASS(reference_count_pass);
#endif
@@ -160,6 +160,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
    // ------------------- cudnn conv forward ---------------------
    ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
+    auto workspace_handle = dev_ctx.cudnn_workspace_handle();
    for (int i = 0; i < groups; i++) {
      auto cudnn_func = [&](void* cudnn_workspace) {
        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
@@ -168,7 +169,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
            cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
            &beta, cudnn_output_desc, output_data + i * group_offset_out));
      };
-      dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
+      workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
    }
  }
};
@@ -314,6 +315,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
    // ------------------- cudnn conv backward data ---------------------
    ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
+    auto workspace_handle = dev_ctx.cudnn_workspace_handle();
    if (input_grad) {
      T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
      // Because beta is zero, it is unnecessary to reset input_grad.
@@ -327,7 +329,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
            data_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
            cudnn_input_desc, input_grad_data + i * group_offset_in));
        };
-        dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
+        workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
      }
    }
    // ------------------- cudnn conv backward filter ---------------------
@@ -343,7 +345,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
            filter_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
            cudnn_filter_desc, filter_grad_data + i * group_offset_filter));
        };
-        dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
+        workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
      }
    }
  }
...
@@ -104,6 +104,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
    int output_offset = output->numel() / output->dims()[0] / groups;
    int filter_offset = filter->numel() / groups;
    T alpha = 1.0f, beta = 0.0f;
+    auto workspace_handle = dev_ctx.cudnn_workspace_handle();
    for (int g = 0; g < groups; g++) {
      auto cudnn_func = [&](void* cudnn_workspace) {
        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
@@ -112,7 +113,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
            algo, cudnn_workspace, workspace_size_in_bytes, &beta,
            cudnn_output_desc, output_data + output_offset * g));
      };
-      dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
+      workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
    }
  }
};
@@ -208,6 +209,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
        output_grad->numel() / output_grad->dims()[0] / groups;
    int filter_offset = filter->numel() / groups;
    T alpha = 1.0f, beta = 0.0f;
+    auto workspace_handle = dev_ctx.cudnn_workspace_handle();
    if (input_grad) {
      T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
      // Because beta is zero, it is unnecessary to reset input_grad.
@@ -220,7 +222,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
            cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
            input_grad_data + input_offset * g));
      };
-      dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
+      workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
    }
  }
@@ -238,7 +240,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
            cudnn_workspace, workspace_size_in_bytes, &beta,
            cudnn_filter_desc, filter_grad_data + filter_offset * g));
      };
-      dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
+      workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
    }
  }
}
...
@@ -168,10 +168,7 @@ class CudnnHolder {
  void RunFunc(const std::function<void(void*)>& cudnn_func,
               size_t required_workspace_len) {
    std::lock_guard<std::mutex> lock(mtx_);
-    if (required_workspace_len > workspace_len_) {
-      ReallocateWorkspace(required_workspace_len);
-    }
-    cudnn_func(workspace_);
+    RunFuncImpl(cudnn_func, required_workspace_len);
  }
  ~CudnnHolder() {
@@ -182,6 +179,16 @@ class CudnnHolder {
  }
 private:
+  std::mutex& Mutex() { return mtx_; }
+
+  void RunFuncImpl(const std::function<void(void*)>& cudnn_func,
+                   size_t required_workspace_len) {
+    if (required_workspace_len > workspace_len_) {
+      ReallocateWorkspace(required_workspace_len);
+    }
+    cudnn_func(workspace_);
+  }
+
  void ReallocateWorkspace(size_t required_workspace_len) {
    if (required_workspace_len <= workspace_len_) {
      return;
@@ -195,6 +202,8 @@ class CudnnHolder {
    workspace_len_ = required_workspace_len;
  }
+  friend class CudnnWorkspaceHandle;
+
  cudnnHandle_t cudnn_handle_;
  void* workspace_;
  size_t workspace_len_;
@@ -205,6 +214,24 @@ class CudnnHolder {
  std::mutex mtx_;
};
+
+CudnnWorkspaceHandle::CudnnWorkspaceHandle(CudnnHolder* holder)
+    : holder_(holder) {}
+
+void CudnnWorkspaceHandle::RunFunc(const std::function<void(void*)>& cudnn_func,
+                                   size_t required_workspace_len) {
+  // Defer locking: the lock is acquired the first time RunFunc is invoked.
+  BeginCallGuard();
+  holder_->RunFuncImpl(cudnn_func, required_workspace_len);
+}
+
+void CudnnWorkspaceHandle::BeginCallGuard() {
+  if (!guard_) {
+    guard_.reset(new std::lock_guard<std::mutex>(holder_->Mutex()));
+  }
+}
+
+void CudnnWorkspaceHandle::EndCallGuard() { guard_.reset(); }
+
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
    : place_(place), cudnn_holder_(nullptr) {
  SetDeviceId(place_.device);
@@ -271,6 +298,10 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
  return cudnn_holder_->cudnn_handle();
}
+CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
+  return CudnnWorkspaceHandle(cudnn_holder_.get());
+}
+
void CUDADeviceContext::RunCudnnFuncWithWorkspace(
    const std::function<void(void*)>& cudnn_func, size_t workspace_len) const {
  cudnn_holder_->RunFunc(cudnn_func, workspace_len);
...
@@ -74,6 +74,33 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
class EigenCudaStreamDevice;
class CudnnHolder;
+
+class CudnnWorkspaceHandle {
+ public:
+  /*! \brief The lock is not acquired when the constructor is called.
+   * It is acquired the first time RunFunc() is called. */
+  explicit CudnnWorkspaceHandle(CudnnHolder* holder);
+
+  /*! \brief A thread calling RunFunc() acquires the lock before
+   * invoking any cudnn function. */
+  void RunFunc(const std::function<void(void*)>& cudnn_func,
+               size_t required_workspace_len);
+
+  /*! \brief The user can call this method to acquire the lock manually,
+   * but it is usually unnecessary, because RunFunc() acquires the lock
+   * itself before invoking cudnn functions. */
+  void BeginCallGuard();
+
+  /*! \brief The user can call this method to release the lock manually,
+   * but it is usually unnecessary, because the lock is released once
+   * the handle is destructed. It can still be used to release the lock
+   * as early as possible. */
+  void EndCallGuard();
+
+ private:
+  CudnnHolder* holder_;  // not owned
+  std::unique_ptr<std::lock_guard<std::mutex>> guard_;
+};
+
class CUDADeviceContext : public DeviceContext {
 public:
  explicit CUDADeviceContext(CUDAPlace place);
@@ -100,6 +127,15 @@ class CUDADeviceContext : public DeviceContext {
  /*! \brief Return cudnn handle in the device context. */
  cudnnHandle_t cudnn_handle() const;
+  /*! \brief Return a cudnn workspace handle so that multiple cudnn
+   * functions can be called without interruption by other threads.
+   * Once the first cudnn function is called through the handle, a lock
+   * is acquired to prevent other threads from touching the workspace;
+   * the lock is released when the handle is destructed.
+   * CudnnWorkspaceHandle is thus an RAII object for thread-safe
+   * sequential cudnn calls. */
+  CudnnWorkspaceHandle cudnn_workspace_handle() const;
+
  /*! \brief Run a cudnn function with the workspace provided by
   * CUDADeviceContext */
  void RunCudnnFuncWithWorkspace(const std::function<void(void*)>& cudnn_func,
...
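A hedged usage sketch of CudnnWorkspaceHandle, following the pattern of the conv kernels above; `dev_ctx`, `groups`, and `workspace_size_in_bytes` stand in for the kernel-local values:

  // The first RunFunc() acquires the holder's mutex; later calls through the
  // same handle reuse it, and the lock is released when the handle dies.
  auto workspace_handle = dev_ctx.cudnn_workspace_handle();
  for (int g = 0; g < groups; ++g) {
    auto cudnn_func = [&](void* cudnn_workspace) {
      // ... invoke a cudnn API that consumes cudnn_workspace ...
    };
    workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
  }
  // EndCallGuard() could release the lock earlier if needed.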