Commit 45559d04 authored by sneaxiy

move to pass

test=develop
Parent a314a80c

paddle/fluid/framework/details/CMakeLists.txt
@@ -33,13 +33,15 @@ if(WITH_GPU)
         all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
 endif()
+cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
 cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
            scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle)
 if(WITH_GPU)
-  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass)
+  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass sequential_execution_pass)
 else()
-  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
+  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto sequential_execution_pass)
 endif()
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope

paddle/fluid/framework/details/build_strategy.cc
@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
 #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
+#include "paddle/fluid/framework/details/sequential_execution_pass.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_viz_pass.h"

@@ -27,6 +28,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
  public:
   explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
       : ir::PassBuilder(), strategy_(strategy) {
+    if (strategy_.enable_sequential_execution_) {
+      AppendPass("sequential_execution_pass");
+    }
+
     // Add a graph viz pass to record a graph.
     if (!strategy_.debug_graphviz_path_.empty()) {
       auto viz_pass = AppendPass("graph_viz_pass");

@@ -95,11 +100,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
   for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
     if (pass->Type() == "multi_devices_pass") {
-      pass->Erase("enable_sequential_execution");
-      if (enable_sequential_execution_) {
-        pass->Set("enable_sequential_execution", new bool(true));
-      }
-
       pass->Erase("places");
       pass->SetNotOwned<const std::vector<platform::Place>>("places", &places);
       pass->Erase("loss_var_name");

@@ -115,6 +115,11 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
       pass->Erase("nccl_ctxs");
       pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
 #endif
+    } else if (pass->Type() == "sequential_execution_pass") {
+      pass->Erase(kAllOpDescs);
+      pass->Set<const std::vector<OpDesc *>>(
+          kAllOpDescs,
+          new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
     }
     graph = pass->Apply(std::move(graph));
   }

@@ -129,3 +134,4 @@ USE_PASS(graph_viz_pass);
 USE_PASS(multi_devices_pass);
 USE_PASS(multi_devices_check_pass);
 USE_PASS(multi_devices_print_pass);
+USE_PASS(sequential_execution_pass);
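
The Apply() hunks above rely on the pass-attribute handoff pattern: Erase the slot first (Set refuses to overwrite an existing attribute), then Set a freshly allocated value that the pass takes ownership of. Below is a minimal standalone model of that handoff; PassAttrModel and its members are hypothetical stand-ins, not the Paddle ir::Pass API.

#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct OpDesc {};  // stand-in for paddle::framework::OpDesc

class PassAttrModel {
 public:
  void Erase(const std::string &name) { attrs_.erase(name); }
  // Setting an already-set attribute is an error, which is why
  // BuildStrategy::Apply calls Erase before every Set.
  void Set(const std::string &name, std::vector<OpDesc *> *owned) {
    assert(attrs_.count(name) == 0);
    attrs_[name].reset(owned);  // the pass now owns the vector
  }
  const std::vector<OpDesc *> &Get(const std::string &name) const {
    return *attrs_.at(name);
  }

 private:
  std::map<std::string, std::unique_ptr<std::vector<OpDesc *>>> attrs_;
};

int main() {
  PassAttrModel pass;
  std::vector<OpDesc *> all_ops = {new OpDesc, new OpDesc};  // like Block(0).AllOps()
  pass.Erase("all_op_descs");  // safe even if the attribute was never set
  pass.Set("all_op_descs", new std::vector<OpDesc *>(all_ops));
  assert(pass.Get("all_op_descs").size() == 2);
  for (OpDesc *op : all_ops) delete op;
  return 0;
}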

paddle/fluid/framework/details/computation_op_handle.cc
@@ -20,12 +20,11 @@ namespace paddle {
 namespace framework {
 namespace details {
 ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
-                                         platform::Place place, size_t place_id)
+                                         platform::Place place)
     : OpHandleBase(node),
       op_(framework::OpRegistry::CreateOp(*node->Op())),
       scope_(scope),
-      place_(place),
-      place_id_(place_id) {}
+      place_(place) {}
 
 void ComputationOpHandle::RunImpl() {
   WaitInputVarGenerated(place_);

paddle/fluid/framework/details/computation_op_handle.h
@@ -28,8 +28,7 @@ namespace framework {
 namespace details {
 struct ComputationOpHandle : public OpHandleBase {
  public:
-  ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place,
-                      size_t place_id);
+  ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place);
 
   std::string Name() const override;

@@ -37,10 +36,6 @@ struct ComputationOpHandle : public OpHandleBase {
   const platform::Place &GetPlace() const { return place_; }
   const OperatorBase &GetOp() const { return *op_; }
-  size_t GetPlaceId() const { return place_id_; }
-
  protected:
   void RunImpl() override;

@@ -50,7 +45,6 @@ struct ComputationOpHandle : public OpHandleBase {
   std::unique_ptr<OperatorBase> op_;
   Scope *scope_;
   platform::Place place_;
-  size_t place_id_;
 };
 }  // namespace details
 }  // namespace framework

paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -13,7 +13,6 @@
 // limitations under the License.
 #include <algorithm>
 #include <fstream>
-#include <map>
 #include <string>
 #include <utility>
 #include <vector>

@@ -238,24 +237,8 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
 // some optimizer ops might not depend on any nodes), we manually move all
 // optimizer nodes after last backward nodes.
 // However, the assumption by SSAGraphBuilder should be relaxed in the future.
-std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(
-    const ir::Graph &graph, bool enable_sequential_execution = false) {
-  std::vector<ir::Node *> ret;
-  if (enable_sequential_execution) {
-    VLOG(10) << "sequential execution mode is enabled";
-    for (auto *node : graph.Nodes()) {
-      if (node->IsOp()) {
-        ret.push_back(node);
-      }
-    }
-    std::sort(ret.begin(), ret.end(),
-              [](const ir::Node *n1, const ir::Node *n2) {
-                return n1->id() < n2->id();
-              });
-  } else {
-    ret = ir::TopologySortOperations(graph);
-  }
+std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
+  std::vector<ir::Node *> ret = ir::TopologySortOperations(graph);
   size_t last_backward = 0;
   for (size_t i = 0; i < ret.size(); ++i) {
     if (boost::get<int>(

@@ -304,10 +287,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
   Init();
   // Give the topology sort order and rebuild the graph structure.
-  bool enable_sequential_execution = Has("enable_sequential_execution") &&
-                                     Get<bool>("enable_sequential_execution");
-  std::vector<ir::Node *> sorted_ops =
-      SortOpsAndDelayOptimizeOp(*graph, enable_sequential_execution);
+  std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
 
   auto nodes = graph->ReleaseNodes();
   ir::Graph &result = *graph;

@@ -463,12 +443,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
       }
     }
   }
 
-  // Insert dependencies between computation_ops
-  if (enable_sequential_execution) {
-    InsertSequenceDependenciesBetweenComputationOps(graph.get());
-  }
-
   /*
   Dependency graph has been constructed. However, there are still data
   hazards need to be handled.

@@ -483,34 +457,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   return graph;
 }
 
-void MultiDevSSAGraphBuilder::InsertSequenceDependenciesBetweenComputationOps(
-    ir::Graph *graph) const {
-  auto &ops = graph->Get<GraphOps>(kGraphOps);
-  // Use std::map instead of std::unordered_map for better log message
-  std::map<size_t, std::vector<ComputationOpHandle *>> compute_ops;
-  for (auto &op : ops) {
-    auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
-    if (compute_op == nullptr) continue;
-    compute_ops[compute_op->GetPlaceId()].push_back(compute_op);
-  }
-
-  for (auto &pair : compute_ops) {
-    auto &ops = pair.second;
-    for (size_t i = 1; i < ops.size(); ++i) {
-      if (ops[i - 1]->Outputs().empty()) {
-        auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
-        graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
-        ops[i - 1]->AddOutput(dep_var);
-      }
-      ops[i]->AddInput(ops[i - 1]->Outputs().front());
-      VLOG(10) << "sequential execution mode: device(" << pair.first
-               << ") insert dependency between "
-               << ops[i - 1]->GetOp().DebugString() << " -> "
-               << ops[i]->GetOp().DebugString();
-    }
-  }
-}
-
 bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
   PADDLE_ENFORCE(all_vars_.count(og) != 0);
   if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {

@@ -567,7 +513,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
                                                     int dev_id) const {
   result->Get<GraphOps>(kGraphOps).emplace_back(
       new ComputationOpHandle(result->CreateOpNode(node->Op()),
-                              local_scopes_[dev_id], places_[dev_id], dev_id));
+                              local_scopes_[dev_id], places_[dev_id]));
   CreateOpHandleIOs(result, node, dev_id);
 }

@@ -684,8 +630,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
   for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
     auto p = places_[scope_idx];
     auto s = local_scopes_[scope_idx];
-    result->Get<GraphOps>(kGraphOps).emplace_back(new ComputationOpHandle(
-        result->CreateOpNode(node->Op()), s, p, scope_idx));
+    result->Get<GraphOps>(kGraphOps).emplace_back(
+        new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
     CreateOpHandleIOs(result, node, scope_idx);
   }
 }
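
The builder-side mechanism deleted above chained ComputationOpHandles per device, which is also why this commit can drop place_id_ from ComputationOpHandle. For reference, here is a standalone model of what InsertSequenceDependenciesBetweenComputationOps did; Op and Var are hypothetical minimal types, not Paddle's op/var handles.

#include <cstddef>
#include <map>
#include <memory>
#include <vector>

struct Var {};
struct Op {
  std::size_t place_id;
  std::vector<Var *> inputs, outputs;
};

// Chain the ops assigned to each device in their current order by threading
// a (possibly dummy) output variable from op[i-1] into op[i].
void ChainPerDevice(const std::vector<Op *> &ops,
                    std::vector<std::unique_ptr<Var>> *dep_vars) {
  // std::map keeps device ids in ascending order, echoing the removed code's
  // remark about preferring std::map for readable log messages.
  std::map<std::size_t, std::vector<Op *>> by_dev;
  for (Op *op : ops) by_dev[op->place_id].push_back(op);
  for (auto &kv : by_dev) {
    auto &dev_ops = kv.second;
    for (std::size_t i = 1; i < dev_ops.size(); ++i) {
      if (dev_ops[i - 1]->outputs.empty()) {
        dep_vars->emplace_back(new Var);  // dummy control-dependency var
        dev_ops[i - 1]->outputs.push_back(dep_vars->back().get());
      }
      dev_ops[i]->inputs.push_back(dev_ops[i - 1]->outputs.front());
    }
  }
}

int main() {
  Op a{0}, b{0}, c{1};  // a and b share device 0; c is alone on device 1
  std::vector<std::unique_ptr<Var>> deps;
  ChainPerDevice({&a, &b, &c}, &deps);
  return (b.inputs.size() == 1 && deps.size() == 1) ? 0 : 1;
}

Moving this logic into its own pass removes the coupling: the new sequential_execution_pass is appended before the multi-device build and works on ir::Node control-dependency edges of the single program graph, so it needs no device ids at all.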

paddle/fluid/framework/details/multi_devices_graph_pass.h
@@ -86,8 +86,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
   void SetCommunicationContext(OpHandleBase *op_handle,
                                const platform::Place &p) const;
 
-  void InsertSequenceDependenciesBetweenComputationOps(ir::Graph *graph) const;
-
   mutable std::string loss_var_name_;
   mutable std::vector<platform::Place> places_;
   mutable std::vector<Scope *> local_scopes_;

paddle/fluid/framework/details/sequential_execution_pass.cc (new file)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace paddle {
namespace framework {
namespace details {
static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
op1->Outputs() == op2->Outputs();
}
std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
auto &ops = this->Get<const std::vector<OpDesc *>>(kAllOpDescs);
std::vector<ir::Node *> op_node_list;
op_node_list.reserve(ops.size());
std::unordered_map<ir::Node *, size_t> op_deps;
std::unordered_map<ir::Node *, std::unordered_set<ir::Node *>> pending_ops;
std::unordered_set<ir::Node *> ready_ops;
for (ir::Node *node : graph->Nodes()) {
if (!node->IsOp()) continue;
std::unordered_set<ir::Node *> preceding_ops;
pending_ops[node];  // ensure the entry exists even if this op has no successors
for (auto *in : node->inputs) {
PADDLE_ENFORCE(in->IsVar(),
"Preceding Node of Op Nodes must be Var Node");
if (in->inputs.empty()) continue;
PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(),
"Preceding Op Node of Var Node must be unique");
preceding_ops.insert(in->inputs[0]);
pending_ops[in->inputs[0]].insert(node);
}
op_deps[node] = preceding_ops.size();
if (preceding_ops.empty()) {
ready_ops.insert(node);
}
}
for (auto *op_desc : ops) {
ir::Node *found_node = nullptr;
for (auto *node : ready_ops) {
if (IsSameOpDesc(op_desc, node->Op())) {
PADDLE_ENFORCE(found_node == nullptr,
"Found multiple op_desc in graph: %s", op_desc->Type());
found_node = node;
}
}
PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s",
                        op_desc->Type());
for (auto *pending_op : pending_ops.at(found_node)) {
if (--op_deps.at(pending_op) == 0) {
ready_ops.insert(pending_op);
}
}
ready_ops.erase(found_node);
op_node_list.push_back(found_node);
}
for (size_t i = 1; i < op_node_list.size(); ++i) {
auto *dep_var = graph->CreateControlDepVar();
op_node_list[i]->inputs.push_back(dep_var);
op_node_list[i - 1]->outputs.push_back(dep_var);
dep_var->outputs.push_back(op_node_list[i]);
dep_var->inputs.push_back(op_node_list[i - 1]);
VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
<< " and " << op_node_list[i]->Name();
}
return graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(sequential_execution_pass,
paddle::framework::details::SequentialExecutionPass)
.RequirePassAttr(paddle::framework::details::kAllOpDescs);
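
The loop over kAllOpDescs in ApplyImpl above replays the program's op order against the graph: it keeps a Kahn-style ready set, requires each OpDesc to match exactly one ready node, and finally chains the matched nodes with control-dependency vars. A standalone sketch of the matching step follows; the minimal Node type is hypothetical, with type standing in for the IsSameOpDesc comparison.

#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_set>
#include <vector>

struct Node {
  std::string type;           // stands in for the IsSameOpDesc comparison
  std::vector<Node *> succs;  // plays the role of pending_ops
  std::size_t deps = 0;       // plays the role of op_deps
};

std::vector<Node *> SequenceByProgramOrder(
    const std::vector<Node *> &graph_ops,
    const std::vector<std::string> &program_order) {
  std::unordered_set<Node *> ready;
  for (Node *n : graph_ops) {
    if (n->deps == 0) ready.insert(n);
  }
  std::vector<Node *> order;
  for (const std::string &want : program_order) {
    Node *found = nullptr;
    for (Node *n : ready) {
      if (n->type == want) {
        assert(found == nullptr);  // mirrors "Found multiple op_desc in graph"
        found = n;
      }
    }
    assert(found != nullptr);  // mirrors the PADDLE_ENFORCE_NOT_NULL check
    for (Node *s : found->succs) {
      if (--s->deps == 0) ready.insert(s);
    }
    ready.erase(found);
    order.push_back(found);  // order[i-1] -> order[i] is later chained
  }
  return order;
}

int main() {
  // Diamond DAG a -> {b, c} -> d; the program order decides b runs before c.
  Node a{"a"}, b{"b"}, c{"c"}, d{"d"};
  a.succs = {&b, &c};
  b.succs = {&d};
  c.succs = {&d};
  b.deps = c.deps = 1;
  d.deps = 2;
  auto order = SequenceByProgramOrder({&a, &b, &c, &d}, {"a", "b", "c", "d"});
  assert(order[1] == &b && order[2] == &c);
  return 0;
}

Because a match must be unique within the ready set, the pass fails fast when the graph and the program disagree instead of silently producing a different order.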

paddle/fluid/framework/details/sequential_execution_pass.h (new file)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
constexpr char kAllOpDescs[] = "all_op_descs";
class SequentialExecutionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle