Commit 45559d04 authored by sneaxiy

move to pass

test=develop
Parent a314a80c
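In brief: sequential execution moves out of multi_devices_graph_pass (where it was a pass attribute plus InsertSequenceDependenciesBetweenComputationOps) into a dedicated sequential_execution_pass. The new pass reorders the graph's op nodes to follow the program's OpDesc order and chains them with control-dependency variables; ComputationOpHandle consequently loses its now-unused place_id_/GetOp() plumbing. A minimal sketch of how a caller opts in, assuming the usual Paddle build tree — only the enable_sequential_execution_ member below comes from this diff; the surrounding setup is illustrative:

```cpp
#include "paddle/fluid/framework/details/build_strategy.h"

int main() {
  // BuildStrategy is a plain struct of flags; setting this member makes
  // ParallelExecutorPassBuilder append "sequential_execution_pass" first,
  // and BuildStrategy::Apply() hands the pass the program's op list
  // (kAllOpDescs) before running it.
  paddle::framework::details::BuildStrategy strategy;
  strategy.enable_sequential_execution_ = true;
  return 0;
}
```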
@@ -33,13 +33,15 @@ if(WITH_GPU)
         all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
 endif()
 
+cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
+
 cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
         scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle)
 
 if(WITH_GPU)
-  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass)
+  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass sequential_execution_pass)
 else()
-  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
+  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto sequential_execution_pass)
 endif()
 
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
...
@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
 #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
+#include "paddle/fluid/framework/details/sequential_execution_pass.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_viz_pass.h"
@@ -27,6 +28,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
  public:
   explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
       : ir::PassBuilder(), strategy_(strategy) {
+    if (strategy_.enable_sequential_execution_) {
+      AppendPass("sequential_execution_pass");
+    }
+
     // Add a graph viz pass to record a graph.
     if (!strategy_.debug_graphviz_path_.empty()) {
       auto viz_pass = AppendPass("graph_viz_pass");
@@ -95,11 +100,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
   for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
     if (pass->Type() == "multi_devices_pass") {
-      pass->Erase("enable_sequential_execution");
-      if (enable_sequential_execution_) {
-        pass->Set("enable_sequential_execution", new bool(true));
-      }
-
       pass->Erase("places");
       pass->SetNotOwned<const std::vector<platform::Place>>("places", &places);
       pass->Erase("loss_var_name");
@@ -115,6 +115,11 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
       pass->Erase("nccl_ctxs");
       pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
 #endif
+    } else if (pass->Type() == "sequential_execution_pass") {
+      pass->Erase(kAllOpDescs);
+      pass->Set<const std::vector<OpDesc *>>(
+          kAllOpDescs,
+          new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
     }
     graph = pass->Apply(std::move(graph));
   }
@@ -129,3 +134,4 @@ USE_PASS(graph_viz_pass);
 USE_PASS(multi_devices_pass);
 USE_PASS(multi_devices_check_pass);
 USE_PASS(multi_devices_print_pass);
+USE_PASS(sequential_execution_pass);
@@ -20,12 +20,11 @@ namespace paddle {
 namespace framework {
 namespace details {
 ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
-                                         platform::Place place, size_t place_id)
+                                         platform::Place place)
     : OpHandleBase(node),
       op_(framework::OpRegistry::CreateOp(*node->Op())),
       scope_(scope),
-      place_(place),
-      place_id_(place_id) {}
+      place_(place) {}
 
 void ComputationOpHandle::RunImpl() {
   WaitInputVarGenerated(place_);
...
@@ -28,8 +28,7 @@ namespace framework {
 namespace details {
 struct ComputationOpHandle : public OpHandleBase {
  public:
-  ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place,
-                      size_t place_id);
+  ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place);
 
   std::string Name() const override;
@@ -37,10 +36,6 @@ struct ComputationOpHandle : public OpHandleBase {
   const platform::Place &GetPlace() const { return place_; }
 
-  const OperatorBase &GetOp() const { return *op_; }
-
-  size_t GetPlaceId() const { return place_id_; }
-
  protected:
   void RunImpl() override;
@@ -50,7 +45,6 @@ struct ComputationOpHandle : public OpHandleBase {
   std::unique_ptr<OperatorBase> op_;
   Scope *scope_;
   platform::Place place_;
-  size_t place_id_;
 };
 }  // namespace details
 }  // namespace framework
...
@@ -13,7 +13,6 @@
 // limitations under the License.
 #include <algorithm>
 #include <fstream>
-#include <map>
 #include <string>
 #include <utility>
 #include <vector>
@@ -238,24 +237,8 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
 // some optimizer ops might not depend on any nodes), we manually move all
 // optimizer nodes after last backward nodes.
 // However, the assumption by SSAGraphBuilder should be relaxed in the future.
-std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(
-    const ir::Graph &graph, bool enable_sequential_execution = false) {
-  std::vector<ir::Node *> ret;
-  if (enable_sequential_execution) {
-    VLOG(10) << "sequential execution mode is enabled";
-    for (auto *node : graph.Nodes()) {
-      if (node->IsOp()) {
-        ret.push_back(node);
-      }
-    }
-    std::sort(ret.begin(), ret.end(),
-              [](const ir::Node *n1, const ir::Node *n2) {
-                return n1->id() < n2->id();
-              });
-  } else {
-    ret = ir::TopologySortOperations(graph);
-  }
+std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
+  std::vector<ir::Node *> ret = ir::TopologySortOperations(graph);
   size_t last_backward = 0;
   for (size_t i = 0; i < ret.size(); ++i) {
     if (boost::get<int>(
@@ -304,10 +287,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
   Init();
   // Give the topology sort order and rebuild the graph structure.
-  bool enable_sequential_execution = Has("enable_sequential_execution") &&
-                                     Get<bool>("enable_sequential_execution");
-  std::vector<ir::Node *> sorted_ops =
-      SortOpsAndDelayOptimizeOp(*graph, enable_sequential_execution);
+  std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
 
   auto nodes = graph->ReleaseNodes();
   ir::Graph &result = *graph;
@@ -463,12 +443,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
       }
     }
   }
-
-  // Insert dependencies between computation_ops
-  if (enable_sequential_execution) {
-    InsertSequenceDependenciesBetweenComputationOps(graph.get());
-  }
-
   /*
   Dependency graph has been constructed. However, there are still data
   hazards need to be handled.
@@ -483,34 +457,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   return graph;
 }
 
-void MultiDevSSAGraphBuilder::InsertSequenceDependenciesBetweenComputationOps(
-    ir::Graph *graph) const {
-  auto &ops = graph->Get<GraphOps>(kGraphOps);
-  // Use std::map instead of std::unordered_map for better log message
-  std::map<size_t, std::vector<ComputationOpHandle *>> compute_ops;
-  for (auto &op : ops) {
-    auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
-    if (compute_op == nullptr) continue;
-    compute_ops[compute_op->GetPlaceId()].push_back(compute_op);
-  }
-
-  for (auto &pair : compute_ops) {
-    auto &ops = pair.second;
-    for (size_t i = 1; i < ops.size(); ++i) {
-      if (ops[i - 1]->Outputs().empty()) {
-        auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
-        graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
-        ops[i - 1]->AddOutput(dep_var);
-      }
-      ops[i]->AddInput(ops[i - 1]->Outputs().front());
-      VLOG(10) << "sequential execution mode: device(" << pair.first
-               << ") insert dependency between "
-               << ops[i - 1]->GetOp().DebugString() << " -> "
-               << ops[i]->GetOp().DebugString();
-    }
-  }
-}
-
 bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
   PADDLE_ENFORCE(all_vars_.count(og) != 0);
   if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
@@ -567,7 +513,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
                                                     int dev_id) const {
   result->Get<GraphOps>(kGraphOps).emplace_back(
       new ComputationOpHandle(result->CreateOpNode(node->Op()),
-                              local_scopes_[dev_id], places_[dev_id], dev_id));
+                              local_scopes_[dev_id], places_[dev_id]));
   CreateOpHandleIOs(result, node, dev_id);
 }
@@ -684,8 +630,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
   for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
     auto p = places_[scope_idx];
     auto s = local_scopes_[scope_idx];
-    result->Get<GraphOps>(kGraphOps).emplace_back(new ComputationOpHandle(
-        result->CreateOpNode(node->Op()), s, p, scope_idx));
+    result->Get<GraphOps>(kGraphOps).emplace_back(
+        new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
     CreateOpHandleIOs(result, node, scope_idx);
   }
 }
...
@@ -86,8 +86,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
   void SetCommunicationContext(OpHandleBase *op_handle,
                                const platform::Place &p) const;
 
-  void InsertSequenceDependenciesBetweenComputationOps(ir::Graph *graph) const;
-
   mutable std::string loss_var_name_;
   mutable std::vector<platform::Place> places_;
   mutable std::vector<Scope *> local_scopes_;
...
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace paddle {
namespace framework {
namespace details {
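// Op nodes in the graph may not share OpDesc pointers with the ProgramDesc,
// so matching is structural: same type and identical named inputs/outputs.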
static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
op1->Outputs() == op2->Outputs();
}
std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
auto ops = this->Get<const std::vector<OpDesc *>>(kAllOpDescs);
std::vector<ir::Node *> op_node_list;
op_node_list.reserve(ops.size());
std::unordered_map<ir::Node *, size_t> op_deps;
std::unordered_map<ir::Node *, std::unordered_set<ir::Node *>> pending_ops;
std::unordered_set<ir::Node *> ready_ops;
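  // First pass: for every op node, count how many distinct ops produce its
  // inputs (op_deps) and record the reverse edges (pending_ops); ops with no
  // producing op start out in ready_ops.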
for (ir::Node *node : graph->Nodes()) {
if (!node->IsOp()) continue;
std::unordered_set<ir::Node *> preceding_ops;
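    // Touch pending_ops[node] so every op has an entry, even one with no
    // downstream consumers (pending_ops.at(found_node) below must not throw).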
pending_ops[node];
for (auto *in : node->inputs) {
PADDLE_ENFORCE(in->IsVar(),
"Preceding Node of Op Nodes must be Var Node");
if (in->inputs.empty()) continue;
PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(),
"Preceding Op Node of Var Node must be unique");
preceding_ops.insert(in->inputs[0]);
pending_ops[in->inputs[0]].insert(node);
}
op_deps[node] = preceding_ops.size();
if (preceding_ops.empty()) {
ready_ops.insert(node);
}
}
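  // Second pass: replay the program's OpDesc order; at each step exactly one
  // ready node must match the current OpDesc, which pins a total order on the
  // graph's op nodes.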
for (auto *op_desc : ops) {
ir::Node *found_node = nullptr;
for (auto *node : ready_ops) {
if (IsSameOpDesc(op_desc, node->Op())) {
PADDLE_ENFORCE(found_node == nullptr,
"Found multiple op_desc in graph: %s", op_desc->Type());
found_node = node;
}
}
    PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s",
                            op_desc->Type());
for (auto *pending_op : pending_ops.at(found_node)) {
if (--op_deps.at(pending_op) == 0) {
ready_ops.insert(pending_op);
}
}
ready_ops.erase(found_node);
op_node_list.push_back(found_node);
}
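  // Third pass: chain consecutive ops with control-dependency variables so
  // the executor can only run them in the order fixed above.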
for (size_t i = 1; i < op_node_list.size(); ++i) {
auto *dep_var = graph->CreateControlDepVar();
op_node_list[i]->inputs.push_back(dep_var);
op_node_list[i - 1]->outputs.push_back(dep_var);
dep_var->outputs.push_back(op_node_list[i]);
dep_var->inputs.push_back(op_node_list[i - 1]);
VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
<< " and " << op_node_list[i]->Name();
}
return graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(sequential_execution_pass,
paddle::framework::details::SequentialExecutionPass)
.RequirePassAttr(paddle::framework::details::kAllOpDescs);
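Because the pass is registered with .RequirePassAttr(kAllOpDescs), every caller must attach the program's op list before Apply — which is exactly what BuildStrategy::Apply does above. Below is a minimal sketch of driving the pass by hand, assuming an existing ProgramDesc `program` and a `graph` built from it; the PassRegistry lookup is the stock ir::Pass machinery, not something added by this commit:

```cpp
// Hypothetical driver: `program` (ProgramDesc) and `graph`
// (std::unique_ptr<ir::Graph>) are assumed to exist already.
auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
    "sequential_execution_pass");
pass->Set<const std::vector<paddle::framework::OpDesc *>>(
    paddle::framework::details::kAllOpDescs,
    new std::vector<paddle::framework::OpDesc *>(program.Block(0).AllOps()));
graph = pass->Apply(std::move(graph));  // ops now chained in program order
```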
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <memory>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
constexpr char kAllOpDescs[] = "all_op_descs";
class SequentialExecutionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle