提交 faac8a76 编写于 作者: S sneaxiy

remove unnecessary codes

test=develop
上级 7ff320f8
cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node) cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor) cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
cc_library(op_handle_graph SRCS op_handle_graph.cc DEPS op_handle_base) cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
...@@ -31,9 +31,9 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_ ...@@ -31,9 +31,9 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_handle_graph multi_devices_helper) cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
if(WITH_GPU) if (WITH_GPU)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass) all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
endif() endif()
...@@ -43,7 +43,7 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap ...@@ -43,7 +43,7 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle) scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto modify_op_lock_and_record_event_pass sequential_execution_pass) set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass)
if (WITH_GPU) if (WITH_GPU)
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass) list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
endif() endif()
......
...@@ -69,6 +69,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -69,6 +69,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Verify that the graph is correct for multi-device executor. // Verify that the graph is correct for multi-device executor.
AppendPass("multi_devices_check_pass"); AppendPass("multi_devices_check_pass");
if (strategy_.remove_unnecessary_lock_) {
AppendPass("modify_op_lock_and_record_event_pass");
}
} }
private: private:
...@@ -136,3 +140,4 @@ USE_PASS(multi_devices_pass); ...@@ -136,3 +140,4 @@ USE_PASS(multi_devices_pass);
USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass); USE_PASS(multi_devices_print_pass);
USE_PASS(sequential_execution_pass); USE_PASS(sequential_execution_pass);
USE_PASS(modify_op_lock_and_record_event_pass);
...@@ -73,6 +73,8 @@ struct BuildStrategy { ...@@ -73,6 +73,8 @@ struct BuildStrategy {
bool fuse_broadcast_op_{false}; bool fuse_broadcast_op_{false};
bool remove_unnecessary_lock_{false};
// User normally doesn't need to call this API. // User normally doesn't need to call this API.
// The PassBuilder allows for more customized insert, remove of passes // The PassBuilder allows for more customized insert, remove of passes
// from python side. // from python side.
......
...@@ -20,13 +20,11 @@ namespace paddle { ...@@ -20,13 +20,11 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope, ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
platform::Place place, platform::Place place)
size_t scope_idx)
: OpHandleBase(node), : OpHandleBase(node),
op_(framework::OpRegistry::CreateOp(*node->Op())), op_(framework::OpRegistry::CreateOp(*node->Op())),
scope_(scope), scope_(scope),
place_(place), place_(place) {}
scope_idx_(scope_idx) {}
void ComputationOpHandle::RunImpl() { void ComputationOpHandle::RunImpl() {
WaitInputVarGenerated(place_); WaitInputVarGenerated(place_);
......
...@@ -28,8 +28,7 @@ namespace framework { ...@@ -28,8 +28,7 @@ namespace framework {
namespace details { namespace details {
struct ComputationOpHandle : public OpHandleBase { struct ComputationOpHandle : public OpHandleBase {
public: public:
ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place, ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place);
size_t scope_idx);
std::string Name() const override; std::string Name() const override;
...@@ -37,12 +36,6 @@ struct ComputationOpHandle : public OpHandleBase { ...@@ -37,12 +36,6 @@ struct ComputationOpHandle : public OpHandleBase {
const platform::Place &GetPlace() const { return place_; } const platform::Place &GetPlace() const { return place_; }
size_t GetScopeIdx() const { return scope_idx_; }
OperatorBase &GetOp() { return *op_; }
const OperatorBase &GetOp() const { return *op_; }
void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; } void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; }
protected: protected:
...@@ -54,7 +47,6 @@ struct ComputationOpHandle : public OpHandleBase { ...@@ -54,7 +47,6 @@ struct ComputationOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_; std::unique_ptr<OperatorBase> op_;
Scope *scope_; Scope *scope_;
platform::Place place_; platform::Place place_;
size_t scope_idx_;
bool is_lock_and_record_event_free_{false}; bool is_lock_and_record_event_free_{false};
}; };
} // namespace details } // namespace details
......
...@@ -15,20 +15,17 @@ ...@@ -15,20 +15,17 @@
#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h" #include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h"
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_handle_graph.h" #include "paddle/fluid/framework/details/op_graph_view.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
static ComputationOpHandle *ConvertToComputationOpHandle(OpHandleBase *op) {
return dynamic_cast<ComputationOpHandle *>(op);
}
static bool IsLockAndRecordEventFreeComputationOpHandle( static bool IsLockAndRecordEventFreeComputationOpHandle(
ComputationOpHandle *op, const OpHandleGraph &graph) { ComputationOpHandle *op, const OpGraphView &graph_view) {
for (auto &pending_op : graph.PendingOps(op)) { if (!platform::is_gpu_place(op->GetPlace())) return false;
auto *tmp = ConvertToComputationOpHandle(pending_op); for (auto &pending_op : graph_view.PendingOps(op)) {
auto *tmp = dynamic_cast<ComputationOpHandle *>(pending_op);
if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) { if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
return false; return false;
} }
...@@ -39,12 +36,12 @@ static bool IsLockAndRecordEventFreeComputationOpHandle( ...@@ -39,12 +36,12 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl( std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
std::unique_ptr<ir::Graph> ir_graph) const { std::unique_ptr<ir::Graph> ir_graph) const {
auto &all_ops = ir_graph->Get<GraphOps>(kGraphOps); auto &all_ops = ir_graph->Get<GraphOps>(kGraphOps);
OpHandleGraph graph(all_ops); OpGraphView graph_view(all_ops);
for (auto &op : all_ops) { for (auto &op : all_ops) {
auto *compute_op = ConvertToComputationOpHandle(op.get()); auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
if (compute_op == nullptr) continue; if (compute_op == nullptr) continue;
bool is_lock_and_record_event_free = bool is_lock_and_record_event_free =
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph); IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free); compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free);
if (is_lock_and_record_event_free) { if (is_lock_and_record_event_free) {
VLOG(10) << "Set is_lock_and_record_event_free be true in op " VLOG(10) << "Set is_lock_and_record_event_free be true in op "
......
...@@ -556,7 +556,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, ...@@ -556,7 +556,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
int dev_id) const { int dev_id) const {
result->Get<GraphOps>(kGraphOps).emplace_back( result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()), new ComputationOpHandle(result->CreateOpNode(node->Op()),
local_scopes_[dev_id], places_[dev_id], dev_id)); local_scopes_[dev_id], places_[dev_id]));
CreateOpHandleIOs(result, node, dev_id); CreateOpHandleIOs(result, node, dev_id);
} }
...@@ -672,8 +672,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result, ...@@ -672,8 +672,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx]; auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx]; auto s = local_scopes_[scope_idx];
result->Get<GraphOps>(kGraphOps).emplace_back(new ComputationOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(
result->CreateOpNode(node->Op()), s, p, scope_idx)); new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
CreateOpHandleIOs(result, node, scope_idx); CreateOpHandleIOs(result, node, scope_idx);
} }
} }
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/op_graph_view.h"
#include <queue>
#include <utility>
namespace paddle {
namespace framework {
namespace details {
OpGraphView::OpGraphView(
const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
Build(ops);
}
void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
for (auto &op : ops) {
preceding_ops_[op.get()];
pending_ops_[op.get()];
for (auto &var : op->Outputs()) {
for (auto &pending_op : var->PendingOps()) {
preceding_ops_[pending_op].insert(op.get());
pending_ops_[op.get()].insert(pending_op);
}
}
}
PADDLE_ENFORCE(
preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(),
"There are duplicate ops in graph.");
}
size_t OpGraphView::OpNumber() const { return preceding_ops_.size(); }
std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
std::unordered_set<OpHandleBase *> ret;
for (auto &pair : preceding_ops_) {
ret.insert(pair.first);
}
return ret;
}
bool OpGraphView::HasOp(OpHandleBase *op) const {
return preceding_ops_.count(op) != 0;
}
void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpGraphView",
op == nullptr ? "nullptr" : op->DebugString());
}
const std::unordered_set<OpHandleBase *> &OpGraphView::PrecedingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return preceding_ops_.at(op);
}
const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return pending_ops_.at(op);
}
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -24,11 +24,9 @@ namespace paddle { ...@@ -24,11 +24,9 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
class OpHandleGraph { class OpGraphView {
public: public:
enum Relation { kSame = 0, kBefore = 1, kAfter = 2, kNoDeps = 3 }; explicit OpGraphView(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
explicit OpHandleGraph(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
size_t OpNumber() const; size_t OpNumber() const;
...@@ -39,42 +37,11 @@ class OpHandleGraph { ...@@ -39,42 +37,11 @@ class OpHandleGraph {
const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const; const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;
std::vector<std::unordered_set<OpHandleBase *>> AllPrecedingOps(
OpHandleBase *op) const;
std::vector<std::unordered_set<OpHandleBase *>> AllPendingOps(
OpHandleBase *op) const;
bool HasOp(OpHandleBase *op) const; bool HasOp(OpHandleBase *op) const;
Relation RelationBetween(OpHandleBase *op1, OpHandleBase *op2) const;
bool IsSame(OpHandleBase *op1, OpHandleBase *op2) const;
bool IsBeforeOrSame(OpHandleBase *op1, OpHandleBase *op2) const;
bool IsBefore(OpHandleBase *op1, OpHandleBase *op2) const;
bool IsAfterOrSame(OpHandleBase *op1, OpHandleBase *op2) const;
bool IsAfter(OpHandleBase *op1, OpHandleBase *op2) const;
bool IsNoDeps(OpHandleBase *op1, OpHandleBase *op2) const;
OpHandleBase *NearestCommonParent(OpHandleBase *op1, OpHandleBase *op2) const;
// Find an operator that is after op and before op1, op2
OpHandleBase *NearestCommonParentAfter(OpHandleBase *op, OpHandleBase *op1,
OpHandleBase *op2) const;
std::unordered_set<OpHandleBase *> NoPendingOpSet() const;
std::unordered_set<OpHandleBase *> NoPrecedingOpSet() const;
private: private:
void BuildGraph(const std::vector<std::unique_ptr<OpHandleBase>> &ops); void Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
void EnforceHasOp(OpHandleBase *op) const; void EnforceHasOp(OpHandleBase *op) const;
bool IsBeforeOrSameImpl(OpHandleBase *op1, OpHandleBase *op2) const;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>> std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
preceding_ops_; preceding_ops_;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/op_handle_graph.h"
#include <queue>
#include <utility>
namespace paddle {
namespace framework {
namespace details {
OpHandleGraph::OpHandleGraph(
const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
BuildGraph(ops);
}
void OpHandleGraph::BuildGraph(
const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
for (auto &op : ops) {
preceding_ops_[op.get()];
pending_ops_[op.get()];
for (auto &var : op->Outputs()) {
for (auto &pending_op : var->PendingOps()) {
preceding_ops_[pending_op].insert(op.get());
pending_ops_[op.get()].insert(pending_op);
}
}
}
PADDLE_ENFORCE(
preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(),
"There are duplicate ops in graph.");
}
size_t OpHandleGraph::OpNumber() const { return preceding_ops_.size(); }
std::unordered_set<OpHandleBase *> OpHandleGraph::AllOps() const {
std::unordered_set<OpHandleBase *> ret;
for (auto &pair : preceding_ops_) {
ret.insert(pair.first);
}
return ret;
}
bool OpHandleGraph::HasOp(OpHandleBase *op) const {
return preceding_ops_.count(op) != 0;
}
void OpHandleGraph::EnforceHasOp(OpHandleBase *op) const {
PADDLE_ENFORCE(HasOp(op), "Cannot found op %s in OpHandleGraph",
op == nullptr ? "nullptr" : op->DebugString());
}
const std::unordered_set<OpHandleBase *> &OpHandleGraph::PrecedingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return preceding_ops_.at(op);
}
const std::unordered_set<OpHandleBase *> &OpHandleGraph::PendingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return pending_ops_.at(op);
}
std::vector<std::unordered_set<OpHandleBase *>> OpHandleGraph::AllPrecedingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
std::queue<OpHandleBase *> queue[2];
int cur = 0;
std::unordered_set<OpHandleBase *> visited_ops;
std::vector<std::unordered_set<OpHandleBase *>> ret;
for (auto &tmp : preceding_ops_.at(op)) {
queue[cur].push(tmp);
visited_ops.insert(tmp);
}
while (!queue[cur].empty()) {
std::unordered_set<OpHandleBase *> cur_level_ops;
auto *tmp = queue[cur].front();
queue[cur].pop();
for (auto &preceding_op : preceding_ops_.at(tmp)) {
if (visited_ops.count(preceding_op)) {
continue;
} else {
queue[1 - cur].push(preceding_op);
cur_level_ops.insert(preceding_op);
visited_ops.insert(preceding_op);
}
}
if (!cur_level_ops.empty()) {
ret.emplace_back(std::move(cur_level_ops));
}
cur = 1 - cur;
}
return ret;
}
std::vector<std::unordered_set<OpHandleBase *>> OpHandleGraph::AllPendingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
std::queue<OpHandleBase *> queue[2];
int cur = 0;
std::unordered_set<OpHandleBase *> visited_ops;
std::vector<std::unordered_set<OpHandleBase *>> ret;
for (auto &tmp : preceding_ops_.at(op)) {
queue[cur].push(tmp);
visited_ops.insert(tmp);
}
while (!queue[cur].empty()) {
std::unordered_set<OpHandleBase *> cur_level_ops;
auto *tmp = queue[cur].front();
queue[cur].pop();
for (auto &next_op : pending_ops_.at(tmp)) {
if (visited_ops.count(next_op)) {
continue;
} else {
queue[1 - cur].push(next_op);
cur_level_ops.insert(next_op);
visited_ops.insert(next_op);
}
}
if (!cur_level_ops.empty()) {
ret.emplace_back(std::move(cur_level_ops));
}
cur = 1 - cur;
}
return ret;
}
OpHandleGraph::Relation OpHandleGraph::RelationBetween(
OpHandleBase *op1, OpHandleBase *op2) const {
EnforceHasOp(op1);
EnforceHasOp(op2);
if (op1 == op2) {
return kSame;
} else if (IsBeforeOrSameImpl(op1, op2)) {
return kBefore;
} else if (IsBeforeOrSameImpl(op2, op1)) {
return kAfter;
} else {
return kNoDeps;
}
}
bool OpHandleGraph::IsSame(OpHandleBase *op1, OpHandleBase *op2) const {
EnforceHasOp(op1);
EnforceHasOp(op2);
return op1 == op2;
}
bool OpHandleGraph::IsBeforeOrSame(OpHandleBase *op1, OpHandleBase *op2) const {
EnforceHasOp(op1);
EnforceHasOp(op2);
return IsBeforeOrSameImpl(op1, op2);
}
bool OpHandleGraph::IsBefore(OpHandleBase *op1, OpHandleBase *op2) const {
EnforceHasOp(op1);
EnforceHasOp(op2);
return op1 != op2 && IsBeforeOrSameImpl(op1, op2);
}
bool OpHandleGraph::IsBeforeOrSameImpl(OpHandleBase *op1,
OpHandleBase *op2) const {
std::queue<OpHandleBase *> queue;
// BFS
queue.push(op1);
do {
auto *op = queue.front();
queue.pop();
if (op == op2) return true;
for (auto &pending_op : pending_ops_.at(op)) {
queue.push(pending_op);
}
} while (!queue.empty());
return false;
}
bool OpHandleGraph::IsAfterOrSame(OpHandleBase *op1, OpHandleBase *op2) const {
EnforceHasOp(op1);
EnforceHasOp(op2);
return IsBeforeOrSameImpl(op2, op1);
}
bool OpHandleGraph::IsAfter(OpHandleBase *op1, OpHandleBase *op2) const {
return IsBefore(op2, op1);
}
bool OpHandleGraph::IsNoDeps(OpHandleBase *op1, OpHandleBase *op2) const {
return RelationBetween(op1, op2) == kNoDeps;
}
std::unordered_set<OpHandleBase *> OpHandleGraph::NoPendingOpSet() const {
std::unordered_set<OpHandleBase *> ret;
for (auto &pair : pending_ops_) {
if (pair.second.empty()) ret.insert(pair.first);
}
return ret;
}
std::unordered_set<OpHandleBase *> OpHandleGraph::NoPrecedingOpSet() const {
std::unordered_set<OpHandleBase *> ret;
for (auto &pair : preceding_ops_) {
if (pair.second.empty()) ret.insert(pair.first);
}
return ret;
}
OpHandleBase *OpHandleGraph::NearestCommonParent(OpHandleBase *op1,
OpHandleBase *op2) const {
EnforceHasOp(op1);
EnforceHasOp(op2);
// FIXME(zjl): A brute-force O(2*n) algorithm here
// First, BFS all preceding_ops of op1 and record them in set S
// Second, BFS all preceding_ops of op2 and found whether it is in set S
std::unordered_set<OpHandleBase *> all_preceding_ops;
std::queue<OpHandleBase *> queue;
queue.push(op1);
do {
auto *op = queue.front();
queue.pop();
all_preceding_ops.insert(op);
for (auto &preceding_op : preceding_ops_.at(op)) {
queue.push(preceding_op);
}
} while (!queue.empty());
queue.push(op2);
do {
auto *op = queue.front();
queue.pop();
if (all_preceding_ops.count(op)) return op;
for (auto &preceding_op : preceding_ops_.at(op)) {
queue.push(preceding_op);
}
} while (!queue.empty());
return nullptr;
}
OpHandleBase *OpHandleGraph::NearestCommonParentAfter(OpHandleBase *op,
OpHandleBase *op1,
OpHandleBase *op2) const {
EnforceHasOp(op);
EnforceHasOp(op1);
EnforceHasOp(op2);
std::unordered_map<OpHandleBase *, int> all_preceding_ops;
int max_depth = -1;
std::queue<std::pair<OpHandleBase *, int>> queue;
queue.push(std::make_pair(op1, 0));
do {
auto tmp = queue.front();
queue.pop();
all_preceding_ops.insert(tmp);
if (tmp.first == op1) {
max_depth = tmp.second;
break;
}
for (auto &preceding_op : preceding_ops_.at(tmp.first)) {
queue.push(std::make_pair(preceding_op, tmp.second + 1));
}
} while (!queue.empty());
if (max_depth == -1) {
return nullptr;
}
std::queue<OpHandleBase *> queue2;
queue2.push(op2);
do {
auto *tmp = queue2.front();
queue2.pop();
if (all_preceding_ops.count(tmp) &&
(tmp == op || all_preceding_ops[tmp] < max_depth)) {
return tmp;
}
} while (!queue2.empty());
return nullptr;
}
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -118,10 +118,6 @@ ParallelExecutor::ParallelExecutor( ...@@ -118,10 +118,6 @@ ParallelExecutor::ParallelExecutor(
main_program, member_->places_, loss_var_name, params, main_program, member_->places_, loss_var_name, params,
member_->local_scopes_, member_->use_cuda_, member_->nccl_ctxs_.get()); member_->local_scopes_, member_->use_cuda_, member_->nccl_ctxs_.get());
graph = ir::PassRegistry::Instance()
.Get("modify_op_lock_and_record_event_pass")
->Apply(std::move(graph));
auto max_memory_size = GetEagerDeletionThreshold(); auto max_memory_size = GetEagerDeletionThreshold();
if (max_memory_size >= 0) { if (max_memory_size >= 0) {
for (auto &place : member_->places_) { for (auto &place : member_->places_) {
...@@ -149,10 +145,6 @@ ParallelExecutor::ParallelExecutor( ...@@ -149,10 +145,6 @@ ParallelExecutor::ParallelExecutor(
std::unique_ptr<ir::Graph> graph = std::unique_ptr<ir::Graph> graph =
build_strategy.Apply(main_program, member_->places_, loss_var_name, build_strategy.Apply(main_program, member_->places_, loss_var_name,
params, member_->local_scopes_, member_->use_cuda_); params, member_->local_scopes_, member_->use_cuda_);
graph = ir::PassRegistry::Instance()
.Get("modify_op_lock_and_record_event_pass")
->Apply(std::move(graph));
#endif #endif
// Step 3. Create vars in each scope. Passes may also create new vars. // Step 3. Create vars in each scope. Passes may also create new vars.
...@@ -331,8 +323,6 @@ ParallelExecutor::~ParallelExecutor() { ...@@ -331,8 +323,6 @@ ParallelExecutor::~ParallelExecutor() {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
USE_PASS(modify_op_lock_and_record_event_pass);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
USE_PASS(reference_count_pass); USE_PASS(reference_count_pass);
#endif #endif
...@@ -153,83 +153,32 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { ...@@ -153,83 +153,32 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
mutable unsigned int* semaphore_; mutable unsigned int* semaphore_;
}; };
class CudnnHolder { CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
public: : workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place) PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
: workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) { PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); }
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
}
cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }
void RunFunc(const std::function<void(void*)>& cudnn_func,
size_t required_workspace_len) {
std::lock_guard<std::mutex> lock(mtx_);
RunFuncImpl(cudnn_func, required_workspace_len);
}
~CudnnHolder() {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
if (workspace_ != nullptr) {
paddle::memory::Free(place_, workspace_);
}
}
private:
std::mutex& Mutex() { return mtx_; }
void RunFuncImpl(const std::function<void(void*)>& cudnn_func, CudnnHolder::~CudnnHolder() {
size_t required_workspace_len) { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
if (required_workspace_len > workspace_len_) { if (workspace_ != nullptr) {
ReallocateWorkspace(required_workspace_len); paddle::memory::Free(place_, workspace_);
}
cudnn_func(workspace_);
} }
void ReallocateWorkspace(size_t required_workspace_len) {
if (required_workspace_len <= workspace_len_) {
return;
}
if (workspace_ != nullptr) {
// Maybe someone is using the current workspace
PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
paddle::memory::Free(place_, workspace_);
}
workspace_ = paddle::memory::Alloc(place_, required_workspace_len);
workspace_len_ = required_workspace_len;
}
friend class CudnnWorkspaceHandle;
cudnnHandle_t cudnn_handle_;
void* workspace_;
size_t workspace_len_;
const cudaStream_t* stream_; // not owned;
const CUDAPlace place_;
std::mutex mtx_;
};
CudnnWorkspaceHandle::CudnnWorkspaceHandle(CudnnHolder* holder)
: holder_(holder) {}
void CudnnWorkspaceHandle::RunFunc(const std::function<void(void*)>& cudnn_func,
size_t required_workspace_len) {
// defer lock when the function is invoked first time
BeginCallGuard();
holder_->RunFuncImpl(cudnn_func, required_workspace_len);
} }
void CudnnWorkspaceHandle::BeginCallGuard() { void CudnnHolder::ReallocateWorkspace(size_t required_workspace_len) {
if (!guard_) { if (required_workspace_len <= workspace_len_) {
guard_.reset(new std::lock_guard<std::mutex>(holder_->Mutex())); return;
}
if (workspace_ != nullptr) {
// Maybe someone is using the current workspace
PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
paddle::memory::Free(place_, workspace_);
} }
workspace_ = paddle::memory::Alloc(place_, required_workspace_len);
workspace_len_ = required_workspace_len;
} }
void CudnnWorkspaceHandle::EndCallGuard() { guard_.reset(); }
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
: place_(place), cudnn_holder_(nullptr) { : place_(place), cudnn_holder_(nullptr) {
SetDeviceId(place_.device); SetDeviceId(place_.device);
...@@ -300,11 +249,6 @@ CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { ...@@ -300,11 +249,6 @@ CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
return CudnnWorkspaceHandle(cudnn_holder_.get()); return CudnnWorkspaceHandle(cudnn_holder_.get());
} }
void CUDADeviceContext::RunCudnnFuncWithWorkspace(
const std::function<void(void*)>& cudnn_func, size_t workspace_len) const {
cudnn_holder_->RunFunc(cudnn_func, workspace_len);
}
cudaStream_t CUDADeviceContext::stream() const { return stream_; } cudaStream_t CUDADeviceContext::stream() const { return stream_; }
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() { CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
......
...@@ -73,29 +73,55 @@ struct DefaultDeviceContextType<platform::CPUPlace> { ...@@ -73,29 +73,55 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
class EigenCudaStreamDevice; class EigenCudaStreamDevice;
class CudnnHolder; class CudnnHolder {
public:
CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place);
~CudnnHolder();
cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }
private:
friend class CudnnWorkspaceHandle;
void ReallocateWorkspace(size_t required_workspace_len);
template <typename Callback>
void RunFuncImpl(Callback&& cudnn_func, size_t required_workspace_len) {
if (required_workspace_len > workspace_len_) {
ReallocateWorkspace(required_workspace_len);
}
cudnn_func(workspace_);
}
std::mutex& Mutex() { return mtx_; }
cudnnHandle_t cudnn_handle_;
void* workspace_;
size_t workspace_len_;
const cudaStream_t* stream_; // not owned;
const CUDAPlace place_;
std::mutex mtx_;
};
class CudnnWorkspaceHandle { class CudnnWorkspaceHandle {
public: public:
/*! \brief The lock would not be acquired when constructor calls. /*! \brief The lock would not be acquired when constructor calls.
* The lock would be acquired when RunFunc() is called first time. */ * The lock would be acquired when RunFunc() is called first time. */
explicit CudnnWorkspaceHandle(CudnnHolder* holder); inline explicit CudnnWorkspaceHandle(CudnnHolder* holder) : holder_(holder) {}
/*! \brief Thread which call RunFunc() would acquire the lock first /*! \brief Thread which call RunFunc() would acquire the lock first
* before invoking cudnn functions. */ * before invoking cudnn functions. */
void RunFunc(const std::function<void(void*)>& cudnn_func, template <typename Callback>
size_t required_workspace_len); inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_len) {
if (!guard_) {
/*! \brief User can call this method to acquire the lock manually, guard_.reset(new std::lock_guard<std::mutex>(holder_->Mutex()));
* But it is usually unnecessary, because RunFunc() would }
* acquire the lock first before invoking cudnn functions. */ holder_->RunFuncImpl(std::forward<Callback>(cudnn_func),
void BeginCallGuard(); required_workspace_len);
}
/*! \brief User can call this method to release the lock manually, CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
* But it is usually unnecssary, because the lock would be CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;
* release once the handle is destructed. But it can be used
* to manually release the lock as soon as possible. */
void EndCallGuard();
private: private:
CudnnHolder* holder_; // not own CudnnHolder* holder_; // not own
...@@ -137,11 +163,6 @@ class CUDADeviceContext : public DeviceContext { ...@@ -137,11 +163,6 @@ class CUDADeviceContext : public DeviceContext {
* sequential cudnn function calls. */ * sequential cudnn function calls. */
CudnnWorkspaceHandle cudnn_workspace_handle() const; CudnnWorkspaceHandle cudnn_workspace_handle() const;
/*! \brief Run a cudnn function with the workspace provided by
* CUDADeviceContext */
void RunCudnnFuncWithWorkspace(const std::function<void(void*)>& cudnn_func,
size_t workspace_len) const;
/*! \brief Return cuda stream in the device context. */ /*! \brief Return cuda stream in the device context. */
cudaStream_t stream() const; cudaStream_t stream() const;
......
...@@ -821,13 +821,24 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -821,13 +821,24 @@ All parameter, weight, gradient are variables in Paddle.
[](BuildStrategy &self, bool b) { [](BuildStrategy &self, bool b) {
self.enable_data_balance_ = b; self.enable_data_balance_ = b;
}) // FIXME(chengudo): enable_data_balance seems not important }) // FIXME(chengudo): enable_data_balance seems not important
.def_property("enable_sequential_execution", .def_property(
[](const BuildStrategy &self) { "enable_sequential_execution",
return self.enable_sequential_execution_; [](const BuildStrategy &self) {
}, return self.enable_sequential_execution_;
[](BuildStrategy &self, bool b) { },
self.enable_sequential_execution_ = b; [](BuildStrategy &self, bool b) {
}) self.enable_sequential_execution_ = b;
},
R"DOC(The type is BOOL. If set True, the execution order of ops would be the same as what is in the program. Default False.)DOC")
.def_property(
"remove_unnecessary_lock",
[](const BuildStrategy &self) {
return self.remove_unnecessary_lock_;
},
[](BuildStrategy &self, bool b) {
self.remove_unnecessary_lock_ = b;
},
R"DOC(The type is BOOL. If set True, some locks in GPU ops would be released and ParallelExecutor would run faster. Default False.)DOC")
.def_property( .def_property(
"fuse_elewise_add_act_ops", "fuse_elewise_add_act_ops",
[](const BuildStrategy &self) { [](const BuildStrategy &self) {
......
...@@ -18,6 +18,7 @@ import multiprocessing ...@@ -18,6 +18,7 @@ import multiprocessing
import os import os
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core
import time import time
import numpy as np import numpy as np
import math import math
...@@ -82,6 +83,8 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -82,6 +83,8 @@ class TestParallelExecutorBase(unittest.TestCase):
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.enable_sequential_execution = enable_sequential_execution build_strategy.enable_sequential_execution = enable_sequential_execution
if use_cuda and core.is_compiled_with_cuda():
build_strategy.remove_unnecessary_lock = True
if use_parallel_executor: if use_parallel_executor:
exe = fluid.ParallelExecutor( exe = fluid.ParallelExecutor(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册