From 45559d042cd99ae2a328a826f8d4d674f7c29e44 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Mon, 29 Oct 2018 05:32:49 +0000 Subject: [PATCH] move to pass test=develop --- paddle/fluid/framework/details/CMakeLists.txt | 6 +- .../fluid/framework/details/build_strategy.cc | 16 ++- .../details/computation_op_handle.cc | 5 +- .../framework/details/computation_op_handle.h | 8 +- .../details/multi_devices_graph_pass.cc | 66 ++----------- .../details/multi_devices_graph_pass.h | 2 - .../details/sequential_execution_pass.cc | 97 +++++++++++++++++++ .../details/sequential_execution_pass.h | 34 +++++++ 8 files changed, 155 insertions(+), 79 deletions(-) create mode 100644 paddle/fluid/framework/details/sequential_execution_pass.cc create mode 100644 paddle/fluid/framework/details/sequential_execution_pass.h diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index e0a3ef5a9..b832bc50a 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -33,13 +33,15 @@ if(WITH_GPU) all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass) endif() +cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass) + cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) if(WITH_GPU) - cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass) + cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass sequential_execution_pass) else() - cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto) + cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto sequential_execution_pass) endif() cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 469d2b25c..c6150465c 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" +#include "paddle/fluid/framework/details/sequential_execution_pass.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h" @@ -27,6 +28,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { public: explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy) : ir::PassBuilder(), strategy_(strategy) { + if (strategy_.enable_sequential_execution_) { + AppendPass("sequential_execution_pass"); + } + // Add a graph viz pass to record a graph. if (!strategy_.debug_graphviz_path_.empty()) { auto viz_pass = AppendPass("graph_viz_pass"); @@ -95,11 +100,6 @@ std::unique_ptr BuildStrategy::Apply( for (std::shared_ptr &pass : pass_builder_->AllPasses()) { if (pass->Type() == "multi_devices_pass") { - pass->Erase("enable_sequential_execution"); - if (enable_sequential_execution_) { - pass->Set("enable_sequential_execution", new bool(true)); - } - pass->Erase("places"); pass->SetNotOwned>("places", &places); pass->Erase("loss_var_name"); @@ -115,6 +115,11 @@ std::unique_ptr BuildStrategy::Apply( pass->Erase("nccl_ctxs"); pass->SetNotOwned("nccl_ctxs", nctx); #endif + } else if (pass->Type() == "sequential_execution_pass") { + pass->Erase(kAllOpDescs); + pass->Set>( + kAllOpDescs, + new std::vector(main_program.Block(0).AllOps())); } graph = pass->Apply(std::move(graph)); } @@ -129,3 +134,4 @@ USE_PASS(graph_viz_pass); USE_PASS(multi_devices_pass); USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_print_pass); +USE_PASS(sequential_execution_pass); diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index 95f114056..b6282debd 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -20,12 +20,11 @@ namespace paddle { namespace framework { namespace details { ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope, - platform::Place place, size_t place_id) + platform::Place place) : OpHandleBase(node), op_(framework::OpRegistry::CreateOp(*node->Op())), scope_(scope), - place_(place), - place_id_(place_id) {} + place_(place) {} void ComputationOpHandle::RunImpl() { WaitInputVarGenerated(place_); diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index 0cf112bc4..e98f1ab14 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -28,8 +28,7 @@ namespace framework { namespace details { struct ComputationOpHandle : public OpHandleBase { public: - ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place, - size_t place_id); + ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place); std::string Name() const override; @@ -37,10 +36,6 @@ struct ComputationOpHandle : public OpHandleBase { const platform::Place &GetPlace() const { return place_; } - const OperatorBase &GetOp() const { return *op_; } - - size_t GetPlaceId() const { return place_id_; } - protected: void RunImpl() override; @@ -50,7 +45,6 @@ struct ComputationOpHandle : public OpHandleBase { std::unique_ptr op_; Scope *scope_; platform::Place place_; - size_t place_id_; }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index bccd91566..ebd1d644b 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -13,7 +13,6 @@ // limitations under the License. #include #include -#include #include #include #include @@ -238,24 +237,8 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( // some optimizer ops might not depend on any nodes), we manually move all // optimizer nodes after last backward nodes. // However, the assumption by SSAGraphBuilder should be relaxed in the future. -std::vector SortOpsAndDelayOptimizeOp( - const ir::Graph &graph, bool enable_sequential_execution = false) { - std::vector ret; - if (enable_sequential_execution) { - VLOG(10) << "sequential execution mode is enabled"; - for (auto *node : graph.Nodes()) { - if (node->IsOp()) { - ret.push_back(node); - } - } - std::sort(ret.begin(), ret.end(), - [](const ir::Node *n1, const ir::Node *n2) { - return n1->id() < n2->id(); - }); - } else { - ret = ir::TopologySortOperations(graph); - } - +std::vector SortOpsAndDelayOptimizeOp(const ir::Graph &graph) { + std::vector ret = ir::TopologySortOperations(graph); size_t last_backward = 0; for (size_t i = 0; i < ret.size(); ++i) { if (boost::get( @@ -304,10 +287,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( std::unique_ptr graph) const { Init(); // Give the topology sort order and rebuild the graph structure. - bool enable_sequential_execution = Has("enable_sequential_execution") && - Get("enable_sequential_execution"); - std::vector sorted_ops = - SortOpsAndDelayOptimizeOp(*graph, enable_sequential_execution); + std::vector sorted_ops = SortOpsAndDelayOptimizeOp(*graph); auto nodes = graph->ReleaseNodes(); ir::Graph &result = *graph; @@ -463,12 +443,6 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( } } } - - // Insert dependencies between computation_ops - if (enable_sequential_execution) { - InsertSequenceDependenciesBetweenComputationOps(graph.get()); - } - /* Dependency graph has been constructed. However, there are still data hazards need to be handled. @@ -483,34 +457,6 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( return graph; } -void MultiDevSSAGraphBuilder::InsertSequenceDependenciesBetweenComputationOps( - ir::Graph *graph) const { - auto &ops = graph->Get(kGraphOps); - // Use std::map instead of std::unordered_map for better log message - std::map> compute_ops; - for (auto &op : ops) { - auto *compute_op = dynamic_cast(op.get()); - if (compute_op == nullptr) continue; - compute_ops[compute_op->GetPlaceId()].push_back(compute_op); - } - - for (auto &pair : compute_ops) { - auto &ops = pair.second; - for (size_t i = 1; i < ops.size(); ++i) { - if (ops[i - 1]->Outputs().empty()) { - auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); - graph->Get(kGraphDepVars).emplace(dep_var); - ops[i - 1]->AddOutput(dep_var); - } - ops[i]->AddInput(ops[i - 1]->Outputs().front()); - VLOG(10) << "sequential execution mode: device(" << pair.first - << ") insert dependency between " - << ops[i - 1]->GetOp().DebugString() << " -> " - << ops[i]->GetOp().DebugString(); - } - } -} - bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { PADDLE_ENFORCE(all_vars_.count(og) != 0); if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) { @@ -567,7 +513,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, int dev_id) const { result->Get(kGraphOps).emplace_back( new ComputationOpHandle(result->CreateOpNode(node->Op()), - local_scopes_[dev_id], places_[dev_id], dev_id)); + local_scopes_[dev_id], places_[dev_id])); CreateOpHandleIOs(result, node, dev_id); } @@ -684,8 +630,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result, for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; - result->Get(kGraphOps).emplace_back(new ComputationOpHandle( - result->CreateOpNode(node->Op()), s, p, scope_idx)); + result->Get(kGraphOps).emplace_back( + new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p)); CreateOpHandleIOs(result, node, scope_idx); } } diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index 6476a45d5..cdf9f13cd 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -86,8 +86,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass { void SetCommunicationContext(OpHandleBase *op_handle, const platform::Place &p) const; - void InsertSequenceDependenciesBetweenComputationOps(ir::Graph *graph) const; - mutable std::string loss_var_name_; mutable std::vector places_; mutable std::vector local_scopes_; diff --git a/paddle/fluid/framework/details/sequential_execution_pass.cc b/paddle/fluid/framework/details/sequential_execution_pass.cc new file mode 100644 index 000000000..6725cdfb2 --- /dev/null +++ b/paddle/fluid/framework/details/sequential_execution_pass.cc @@ -0,0 +1,97 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/sequential_execution_pass.h" +#include +#include +#include + +namespace paddle { +namespace framework { +namespace details { + +static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) { + return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() && + op1->Outputs() == op2->Outputs(); +} + +std::unique_ptr SequentialExecutionPass::ApplyImpl( + std::unique_ptr graph) const { + auto ops = this->Get>(kAllOpDescs); + std::vector op_node_list; + op_node_list.reserve(ops.size()); + + std::unordered_map op_deps; + std::unordered_map> pending_ops; + std::unordered_set ready_ops; + + for (ir::Node *node : graph->Nodes()) { + if (!node->IsOp()) continue; + std::unordered_set preceding_ops; + pending_ops[node]; + for (auto *in : node->inputs) { + PADDLE_ENFORCE(in->IsVar(), + "Preceding Node of Op Nodes must be Var Node"); + if (in->inputs.empty()) continue; + PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(), + "Preceding Op Node of Var Node must be unique"); + preceding_ops.insert(in->inputs[0]); + pending_ops[in->inputs[0]].insert(node); + } + op_deps[node] = preceding_ops.size(); + if (preceding_ops.empty()) { + ready_ops.insert(node); + } + } + + for (auto *op_desc : ops) { + ir::Node *found_node = nullptr; + for (auto *node : ready_ops) { + if (IsSameOpDesc(op_desc, node->Op())) { + PADDLE_ENFORCE(found_node == nullptr, + "Found multiple op_desc in graph: %s", op_desc->Type()); + found_node = node; + } + } + + PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s", + found_node->Op()->Type()); + for (auto *pending_op : pending_ops.at(found_node)) { + if (--op_deps.at(pending_op) == 0) { + ready_ops.insert(pending_op); + } + } + ready_ops.erase(found_node); + op_node_list.push_back(found_node); + } + + for (size_t i = 1; i < op_node_list.size(); ++i) { + auto *dep_var = graph->CreateControlDepVar(); + op_node_list[i]->inputs.push_back(dep_var); + op_node_list[i - 1]->outputs.push_back(dep_var); + dep_var->outputs.push_back(op_node_list[i]); + dep_var->inputs.push_back(op_node_list[i - 1]); + VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name() + << " and " << op_node_list[i]->Name(); + } + return graph; +} + +} // namespace details +} // namespace framework +} // namespace paddle + +REGISTER_PASS(sequential_execution_pass, + paddle::framework::details::SequentialExecutionPass) + .RequirePassAttr(paddle::framework::details::kAllOpDescs); diff --git a/paddle/fluid/framework/details/sequential_execution_pass.h b/paddle/fluid/framework/details/sequential_execution_pass.h new file mode 100644 index 000000000..a04c08bc2 --- /dev/null +++ b/paddle/fluid/framework/details/sequential_execution_pass.h @@ -0,0 +1,34 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace details { + +constexpr char kAllOpDescs[] = "all_op_descs"; + +class SequentialExecutionPass : public ir::Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; +}; + +} // namespace details +} // namespace framework +} // namespace paddle -- GitLab