diff --git a/doc/fluid/design/ir/draft.md b/doc/fluid/design/ir/overview.md similarity index 97% rename from doc/fluid/design/ir/draft.md rename to doc/fluid/design/ir/overview.md index c29337cba1fe859e4968cb800e4e7d9ff6a54d31..83ef97c99efeaf27a27f93f0cd3857c0f1bc812e 100644 --- a/doc/fluid/design/ir/draft.md +++ b/doc/fluid/design/ir/overview.md @@ -177,8 +177,8 @@ graph = PassRegistry::Instance().Get("op_fuse_pass").Apply(std::move(grah)); auto mem_opt_pass = PassRegistry::Instance().Get("memory_optimization_pass"); mem_opt_pass.SetNotOwned("optimize_level", 1); mem_opt_pass->Apply(std::move(graph)); -graph = PassRegistry::Instance().Get("multi_device_pass").Apply(std::move(grah)); -graph = PassRegistry::Instance().Get("multi_device_check_pass").Apply(std::move(grah)); +graph = PassRegistry::Instance().Get("multi_devices_pass").Apply(std::move(grah)); +graph = PassRegistry::Instance().Get("multi_devices_check_pass").Apply(std::move(grah)); Executor exe; exe.Run(graph); diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 6440607dbe4666ff3ff91dc526465706b3b9c1f0..1d62792b80dd002b894da28be9162fc7d3ce054e 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -100,7 +100,7 @@ else() endif() -cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) +cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_pass multi_devices_graph_print_pass multi_devices_graph_check_pass) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 5d652d37307d0a55ffee14930ae180dcd3e27841..8f6c4163d6ee11fbe83f603f6148c2ac6175324d 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -5,9 +5,9 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry) -cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph graph_helper) -cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder) -cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder) +cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper) +cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper) +cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper) cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) @@ -28,7 +28,7 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) -cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle +cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto) diff --git a/paddle/fluid/framework/details/ssa_graph_checker.cc b/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc similarity index 95% rename from paddle/fluid/framework/details/ssa_graph_checker.cc rename to paddle/fluid/framework/details/multi_devices_graph_check_pass.cc index b9e1cda1f24810009bc74a7abdf0156f723a1755..c9c255864a2477ed29873f8521acce37fa928c06 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/ssa_graph_checker.h" +#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h" #include #include "paddle/fluid/framework/ir/graph.h" @@ -86,7 +86,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { } // namespace framework } // namespace paddle -REGISTER_PASS(multi_device_check_pass, +REGISTER_PASS(multi_devices_check_pass, paddle::framework::details::SSAGraghBuilderWithChecker) .RequireGraphAttr(paddle::framework::details::kGraphVars) .RequireGraphAttr(paddle::framework::details::kGraphDepVars) diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/multi_devices_graph_check_pass.h similarity index 89% rename from paddle/fluid/framework/details/ssa_graph_checker.h rename to paddle/fluid/framework/details/multi_devices_graph_check_pass.h index 0e861ecb236361992d9883e3bd0e679f7563b539..1e2b1867c376956d7d2dac465c13e2f3f64ba7eb 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.h +++ b/paddle/fluid/framework/details/multi_devices_graph_check_pass.h @@ -14,7 +14,7 @@ #pragma once -#include "paddle/fluid/framework/details/ssa_graph_builder.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" #include @@ -22,7 +22,7 @@ namespace paddle { namespace framework { namespace details { -class SSAGraghBuilderWithChecker : public SSAGraphBuilder { +class SSAGraghBuilderWithChecker : public ir::Pass { protected: std::unique_ptr ApplyImpl( std::unique_ptr graph) const override { diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc similarity index 90% rename from paddle/fluid/framework/details/multi_devices_graph_builder.cc rename to paddle/fluid/framework/details/multi_devices_graph_pass.cc index a4fdbcb26d1d0cfb05edebff5419d9559c336b3a..c5a13e7e1f45e1eb9b4271880630c52d30022f4b 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -21,7 +21,7 @@ #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/data_balance_op_handle.h" -#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include "paddle/fluid/framework/details/multi_devices_graph_pass.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" @@ -33,6 +33,92 @@ namespace paddle { namespace framework { namespace details { +namespace { +void PolishGraphToSupportDataHazards(ir::Graph *graph) { + for (auto &var_map : graph->Get(kGraphVars)) { + for (auto &name_pair : var_map) { + if (name_pair.second.size() <= 1) { + continue; + } + auto it_new = name_pair.second.rbegin(); + auto it_old = name_pair.second.rbegin(); + ++it_old; + for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { + OpHandleBase *write_op = (*it_new)->GeneratedOp(); + const auto &read_ops = (*it_old)->PendingOps(); + + for (auto *read_op : read_ops) { + // Manually add a dependency var from read_op to write_op; + if (read_op == write_op) { + // Read Write is the same op. + continue; + } + bool has_dep = false; + for (auto *r_out : read_op->Outputs()) { + for (auto *w_in : write_op->Inputs()) { + if (r_out->Node() == w_in->Node()) { + has_dep = true; + break; + } + } + } + if (has_dep) continue; + + auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); + read_op->AddOutput(dep_var); + write_op->AddInput(dep_var); + graph->Get(kGraphDepVars).emplace(dep_var); + } + } + } + } +} + +VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node, + const platform::Place &place, + size_t place_offset) { + auto &var_holders = graph->Get(kGraphVars)[place_offset]; + auto &var_holder = var_holders[node->Name()]; + VarHandle *var = nullptr; + if (var_holder.empty()) { + if (node->Var()) { + var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset, + node->Name(), place); + } else { + var = new VarHandle( + graph->CreateEmptyNode(node->Name(), ir::Node::Type::kVariable), 0, + place_offset, node->Name(), place); + } + var_holder.emplace_back(var); + } else { + var = var_holder.rbegin()->get(); + } + return var; +} + +void CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle, + ir::Node *new_node, const platform::Place &place, + size_t place_offset) { + auto &vars = + graph->Get(kGraphVars)[place_offset][new_node->Name()]; + size_t version = vars.size(); + auto var = + new VarHandle(new_node, version, place_offset, new_node->Name(), place); + vars.emplace_back(var); + op_handle->AddOutput(var); +} + +void AddOutputToLeafOps(ir::Graph *graph) { + for (auto &op : graph->Get(kGraphOps)) { + if (!op->Outputs().empty()) { + continue; + } + auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); + graph->Get(kGraphDepVars).emplace(dummy_leaf); + op->AddOutput(dummy_leaf); + } +} +} // namespace static const char kLossVarName[] = "loss_var_name"; static const char kPlaces[] = "places"; @@ -751,7 +837,7 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const { } // namespace framework } // namespace paddle -REGISTER_PASS(multi_device_pass, +REGISTER_PASS(multi_devices_pass, paddle::framework::details::MultiDevSSAGraphBuilder) .RequirePassAttr(paddle::framework::details::kLossVarName) .RequirePassAttr(paddle::framework::details::kPlaces) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h similarity index 96% rename from paddle/fluid/framework/details/multi_devices_graph_builder.h rename to paddle/fluid/framework/details/multi_devices_graph_pass.h index f2cb6bb1c861e07f1034f1742ad4f3cfbb0d8837..7a6f238f9cf7af18cb10ea271e453fec1902c833 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -18,7 +18,7 @@ #include #include "paddle/fluid/framework/details/build_strategy.h" -#include "paddle/fluid/framework/details/ssa_graph_builder.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/ir/graph.h" namespace paddle { @@ -30,7 +30,7 @@ namespace framework { class Scope; namespace details { -class MultiDevSSAGraphBuilder : public SSAGraphBuilder { +class MultiDevSSAGraphBuilder : public ir::Pass { protected: std::unique_ptr ApplyImpl( std::unique_ptr graph) const override; diff --git a/paddle/fluid/framework/details/ssa_graph_printer.cc b/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc similarity index 95% rename from paddle/fluid/framework/details/ssa_graph_printer.cc rename to paddle/fluid/framework/details/multi_devices_graph_print_pass.cc index ec3f31ab8d135efd2c77018e90cec46b25ca5e66..69944a42b688a9ea5ff29f75f18dd4b156848a27 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/ssa_graph_printer.h" +#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" #include #include "paddle/fluid/framework/ir/graph.h" @@ -82,5 +82,5 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, } // namespace framework } // namespace paddle -REGISTER_PASS(multi_device_print_pass, +REGISTER_PASS(multi_devices_print_pass, paddle::framework::details::SSAGraghBuilderWithPrinter); diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/multi_devices_graph_print_pass.h similarity index 92% rename from paddle/fluid/framework/details/ssa_graph_printer.h rename to paddle/fluid/framework/details/multi_devices_graph_print_pass.h index 5eafd1805c3102dbd3cdfa68ee1495631c182b51..c00685fa1629c0722c315c726053c2cba8bf17e7 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.h +++ b/paddle/fluid/framework/details/multi_devices_graph_print_pass.h @@ -18,7 +18,7 @@ #include #include #include -#include "paddle/fluid/framework/details/ssa_graph_builder.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" namespace paddle { namespace framework { @@ -35,7 +35,7 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter { void Print(const ir::Graph& graph, std::ostream& sout) const override; }; -class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { +class SSAGraghBuilderWithPrinter : public ir::Pass { protected: std::unique_ptr ApplyImpl( std::unique_ptr graph) const override { diff --git a/paddle/fluid/framework/details/multi_devices_helper.cc b/paddle/fluid/framework/details/multi_devices_helper.cc new file mode 100644 index 0000000000000000000000000000000000000000..0242274a16c50508f2c0294264c175515c7293ef --- /dev/null +++ b/paddle/fluid/framework/details/multi_devices_helper.cc @@ -0,0 +1,20 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/fluid/framework/details/multi_devices_helper.h" + +namespace paddle { +namespace framework { +namespace details {} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/multi_devices_helper.h similarity index 68% rename from paddle/fluid/framework/details/ssa_graph_builder.h rename to paddle/fluid/framework/details/multi_devices_helper.h index 53a4ad003d51a27a044d7a142434545eca0d5965..175c5a9950be69d7bf6ae9e386af762007a18a51 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_helper.h @@ -52,33 +52,6 @@ const char kGraphOps[] = "ops"; typedef std::unordered_map ShardedVarDevice; const char kShardedVarDevice[] = "sharded_var_device"; - -class SSAGraphBuilder : public ir::Pass { - public: - SSAGraphBuilder() {} - virtual ~SSAGraphBuilder() {} - - DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); - - protected: - /* - Dependency graph has been constructed. However, there are still data - hazards need to be handled. - */ - static void PolishGraphToSupportDataHazards(ir::Graph *graph); - - static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node, - const platform::Place &place, - size_t place_offset); - - // Add an output variable (each_var_name, place, place_offset) to op_handle, - // which belongs to graph - static void CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle, - ir::Node *new_node, const platform::Place &place, - size_t place_offset); - - static void AddOutputToLeafOps(ir::Graph *graph); -}; } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc deleted file mode 100644 index 575532540a624afde5f6dab25b11e9eac93c6448..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "paddle/fluid/framework/details/ssa_graph_builder.h" -#include - -namespace paddle { -namespace framework { -namespace details { -void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) { - for (auto &var_map : graph->Get(kGraphVars)) { - for (auto &name_pair : var_map) { - if (name_pair.second.size() <= 1) { - continue; - } - auto it_new = name_pair.second.rbegin(); - auto it_old = name_pair.second.rbegin(); - ++it_old; - for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { - OpHandleBase *write_op = (*it_new)->GeneratedOp(); - const auto &read_ops = (*it_old)->PendingOps(); - - for (auto *read_op : read_ops) { - // Manually add a dependency var from read_op to write_op; - if (read_op == write_op) { - // Read Write is the same op. - continue; - } - bool has_dep = false; - for (auto *r_out : read_op->Outputs()) { - for (auto *w_in : write_op->Inputs()) { - if (r_out->Node() == w_in->Node()) { - has_dep = true; - break; - } - } - } - if (has_dep) continue; - - auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); - read_op->AddOutput(dep_var); - write_op->AddInput(dep_var); - graph->Get(kGraphDepVars).emplace(dep_var); - } - } - } - } -} - -VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( - ir::Graph *graph, ir::Node *node, const platform::Place &place, - size_t place_offset) { - auto &var_holders = graph->Get(kGraphVars)[place_offset]; - auto &var_holder = var_holders[node->Name()]; - VarHandle *var = nullptr; - if (var_holder.empty()) { - if (node->Var()) { - var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset, - node->Name(), place); - } else { - var = new VarHandle( - graph->CreateEmptyNode(node->Name(), ir::Node::Type::kVariable), 0, - place_offset, node->Name(), place); - } - var_holder.emplace_back(var); - } else { - var = var_holder.rbegin()->get(); - } - return var; -} - -void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle, - ir::Node *new_node, - const platform::Place &place, - size_t place_offset) { - auto &vars = - graph->Get(kGraphVars)[place_offset][new_node->Name()]; - size_t version = vars.size(); - auto var = - new VarHandle(new_node, version, place_offset, new_node->Name(), place); - vars.emplace_back(var); - op_handle->AddOutput(var); -} - -void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) { - for (auto &op : graph->Get(kGraphOps)) { - if (!op->Outputs().empty()) { - continue; - } - auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); - graph->Get(kGraphDepVars).emplace(dummy_leaf); - op->AddOutput(dummy_leaf); - } -} -} // namespace details -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 0eaf9a9c951991a5775604eb8d0e7535f81a4ae2..994bb6492f685138d02971a6caf12572aecd6d6f 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -14,7 +14,7 @@ #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" -#include "paddle/fluid/framework/details/ssa_graph_builder.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index b5f01a9a2b76472063658f1a051a2ee3c65559b7..275cb8c592c3c0b153d31149570cd6596b9e1a7f 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -25,9 +25,9 @@ limitations under the License. */ #include "paddle/fluid/platform/nccl_helper.h" #endif +#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h" +#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" -#include "paddle/fluid/framework/details/ssa_graph_checker.h" -#include "paddle/fluid/framework/details/ssa_graph_printer.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/platform/profiler.h" @@ -57,39 +57,39 @@ std::unique_ptr ApplyParallelExecutorPass( } // Convert graph to run on multi-devices. - auto multi_device_pass = - ir::PassRegistry::Instance().Get("multi_device_pass"); - multi_device_pass->SetNotOwned>("places", - &places); - multi_device_pass->SetNotOwned("loss_var_name", - &loss_var_name); - multi_device_pass->SetNotOwned>( + auto multi_devices_pass = + ir::PassRegistry::Instance().Get("multi_devices_pass"); + multi_devices_pass->SetNotOwned>("places", + &places); + multi_devices_pass->SetNotOwned("loss_var_name", + &loss_var_name); + multi_devices_pass->SetNotOwned>( "params", ¶m_names); - multi_device_pass->SetNotOwned>("local_scopes", - &local_scopes); - multi_device_pass->SetNotOwned("strategy", &strategy); + multi_devices_pass->SetNotOwned>("local_scopes", + &local_scopes); + multi_devices_pass->SetNotOwned("strategy", &strategy); #ifdef PADDLE_WITH_CUDA platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr; - multi_device_pass->SetNotOwned("nccl_ctxs", nctx); + multi_devices_pass->SetNotOwned("nccl_ctxs", nctx); #endif - graph = multi_device_pass->Apply(std::move(graph)); + graph = multi_devices_pass->Apply(std::move(graph)); // Apply a graph print pass to record a graph with device info. if (!strategy.debug_graphviz_path_.empty()) { - auto multi_device_print_pass = - ir::PassRegistry::Instance().Get("multi_device_print_pass"); - multi_device_print_pass->SetNotOwned( + auto multi_devices_print_pass = + ir::PassRegistry::Instance().Get("multi_devices_print_pass"); + multi_devices_print_pass->SetNotOwned( "debug_graphviz_path", &strategy.debug_graphviz_path_); - multi_device_print_pass->Set( + multi_devices_print_pass->Set( "graph_printer", new details::GraphvizSSAGraphPrinter); - graph = multi_device_print_pass->Apply(std::move(graph)); + graph = multi_devices_print_pass->Apply(std::move(graph)); } // Verify that the graph is correct for multi-device executor. - auto multi_device_check_pass = - ir::PassRegistry::Instance().Get("multi_device_check_pass"); - graph = multi_device_check_pass->Apply(std::move(graph)); + auto multi_devices_check_pass = + ir::PassRegistry::Instance().Get("multi_devices_check_pass"); + graph = multi_devices_check_pass->Apply(std::move(graph)); return graph; } @@ -354,6 +354,6 @@ ParallelExecutor::~ParallelExecutor() { } // namespace paddle USE_PASS(graph_viz_pass); -USE_PASS(multi_device_pass); -USE_PASS(multi_device_check_pass); -USE_PASS(multi_device_print_pass); +USE_PASS(multi_devices_pass); +USE_PASS(multi_devices_check_pass); +USE_PASS(multi_devices_print_pass); diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index d624956acde86cefc4ec1dec80df3738bcf1d8be..5fb748fa205d5e9dbd2943b615c69aedd0e7a26f 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -19,7 +19,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/details/execution_strategy.h" -#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include "paddle/fluid/framework/details/multi_devices_graph_pass.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h"