diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 4fb4ec38ee965a2790d11378a1ce6befa0ef5a00..e8057c35e8b957cb43e66937a5073a085c6e7708 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -5,8 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry) -cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) -cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph) +cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph) cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder) cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder) @@ -35,7 +34,7 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) -cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) +cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 368a1537115d0dbb59df95aa76201857d4cbc85d..2d7f18942890245249dd0619a40bb43833c9a2ee 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -27,6 +27,8 @@ namespace details { constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@"; +// Wraps ir::Node and provide helper utilities. +// It's responsible for populating necessary fields of ir::Node. class OpHandleBase { public: explicit OpHandleBase(ir::Node *node) : node_(node) {} diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h index 20df7a4722d589ffd168f842e927cff8411096bb..cbfbcb1c0cd24f16773f9633310166371600790c 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h @@ -17,6 +17,9 @@ #include #include #include +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/details/var_handle.h" + #include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/ssa_graph_executor.h" #include "paddle/fluid/framework/scope.h" diff --git a/paddle/fluid/framework/details/ssa_graph.cc b/paddle/fluid/framework/details/ssa_graph.cc deleted file mode 100644 index 1b8c889449059c563ea39f86250075ac2537cdbe..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/details/ssa_graph.cc +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/details/ssa_graph.h" diff --git a/paddle/fluid/framework/details/ssa_graph.h b/paddle/fluid/framework/details/ssa_graph.h deleted file mode 100644 index e996a00c162186e47e77d007503ac67caa9f8024..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/details/ssa_graph.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include - -#include "paddle/fluid/framework/details/op_handle_base.h" -#include "paddle/fluid/framework/details/var_handle.h" - -namespace paddle { -namespace framework { -namespace details { - -// A SSA graph used by parallel executor. -struct SSAGraph { - // all variable in each devices. - // The outside vector is the device vector. Each element of this vector is a - // map from variable name to variables. The variables, who have the same name, - // will have a different version. The offset in the - // `std::vector>` is the version of varaibles. - std::vector< - std::unordered_map>>> - vars_; - - // aux variables to represent dependency. Useful to resolve data hazard. - std::unordered_set> dep_vars_; - - // all operators. NOTE that even we use a vector here, the operators is - // unordered. - std::vector> ops_; -}; - -} // namespace details -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 884fc6455559c92c43dfc546e3de24204db75156..7de4426de8814551e779959f2bdc563c52bc51ed 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -81,9 +81,7 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, } void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) { - GraphOps &all_ops = graph->Get("ops"); - - for (auto &op : all_ops) { + for (auto &op : graph->Get("ops")) { if (!op->Outputs().empty()) { continue; } diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 9933bf32b7a1faa6841620662ba0781c19e47f54..87749009efe0c540a74b792cf233d10a7bf4ea69 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -18,7 +18,9 @@ #include #include -#include "paddle/fluid/framework/details/ssa_graph.h" +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/details/var_handle.h" + #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/platform/place.h" @@ -29,10 +31,20 @@ namespace paddle { namespace framework { namespace details { +// all variable in each devices. +// The outside vector is the device vector. Each element of this vector is a +// map from variable name to variables. The variables, who have the same name, +// will have a differsent version. The offset in the +// `std::vector>` is the version of varaibles. typedef std::vector< std::unordered_map>>> GraphVars; + +// aux variables to represent dependency. Useful to resolve data hazard. typedef std::unordered_set> GraphDepVars; + +// all operators. NOTE that even we use a vector here, the operators is +// unordered. typedef std::vector> GraphOps; class SSAGraphBuilder : public ir::Pass { diff --git a/paddle/fluid/framework/details/ssa_graph_checker.cc b/paddle/fluid/framework/details/ssa_graph_checker.cc index 6a211f52bb05f6ff705f43740a94543fe28d8fca..7c79d7f1e881c67514634d56caa715c41927dbce 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.cc +++ b/paddle/fluid/framework/details/ssa_graph_checker.cc @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/ssa_graph.h" -#include #include "paddle/fluid/framework/details/ssa_graph_checker.h" +#include +#include "paddle/fluid/framework/ir/graph.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/details/ssa_graph_executor.h b/paddle/fluid/framework/details/ssa_graph_executor.h index 958086033607a4ed8fb840f5b14fe5779625bd82..8815ec89b23bc874471eefde5fa855cd2a4bde1f 100644 --- a/paddle/fluid/framework/details/ssa_graph_executor.h +++ b/paddle/fluid/framework/details/ssa_graph_executor.h @@ -18,8 +18,8 @@ #include #include -#include "paddle/fluid/framework/details/ssa_graph.h" #include "paddle/fluid/framework/feed_fetch_type.h" +#include "paddle/fluid/framework/ir/graph.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/details/ssa_graph_printer.cc b/paddle/fluid/framework/details/ssa_graph_printer.cc index 412b0a6ff2f982fbf609d14f24eed2d7c70750f9..6dd6fd262e35a192ba85eb3aa16660526d2ebca2 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.cc +++ b/paddle/fluid/framework/details/ssa_graph_printer.cc @@ -14,7 +14,7 @@ #include "paddle/fluid/framework/details/ssa_graph_printer.h" #include -#include "paddle/fluid/framework/details/ssa_graph.h" +#include "paddle/fluid/framework/ir/graph.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 8c9cb7cabb85ffa8ed59249303f883c0499afd5a..ac777883657e19cdb34f287f4a2d2dd1385fbf8a 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -173,9 +173,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( auto &var_name = fetch_tensors[i]; auto &vars = fetched_vars.at(var_name); - ir::Node *fetch_n = new ir::Node("fetch"); - auto *op = new FetchOpHandle(fetch_n, fetch_data, i, &local_scopes_); - temp_nodes->emplace_back(fetch_n); + temp_nodes->emplace_back(new ir::Node("fetch")); + auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i, + &local_scopes_); fetch_ops->emplace_back(op); for (auto &p : places_) { @@ -186,11 +186,10 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( op->AddInput(var); } - ir::Node *dummy_n = new ir::Node("fetch"); - auto *fetch_dummy = new DummyVarHandle(dummy_n); + temp_nodes->emplace_back(new ir::Node("fetch")); + auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get()); op->AddOutput(fetch_dummy); fetch_dependencies->emplace(fetch_dummy); - temp_nodes->emplace_back(dummy_n); this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy); this->InsertPendingOp(pending_ops, op); } diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h index 8bd3db9203ba1db1522eaa714023a74944f9549b..ae23e3b1f8d04d40f7ccffd22b6b3ac2a1462138 100644 --- a/paddle/fluid/framework/details/var_handle.h +++ b/paddle/fluid/framework/details/var_handle.h @@ -28,6 +28,9 @@ namespace framework { namespace details { class OpHandleBase; +// Wraps ir::Node and provide helper utilities. +// It's responsible for populating necessary fields of ir::Node. +// // VarHandleBase is the var node in the dependency graph. // A variable can only be generated by a single operator. i.e. // This is a single assignment graph. diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 14d697c509d9ae7d4d06373b712871300690dc03..1f6937658f99f62be80e8ea671eba4d69d8da189 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -19,37 +19,35 @@ limitations under the License. */ namespace paddle { namespace framework { -std::unique_ptr ProgramToGraph(const ProgramDesc &program) { - std::unique_ptr graph(new Graph(program)); - +Graph::Graph(const ProgramDesc &program) : program_(program) { std::unordered_map all_vars; for (auto *var : program.Block(0).AllVars()) { all_vars.emplace(var->Name(), var); } for (auto *op : program.Block(0).AllOps()) { - ir::Node *node = graph->CreateOpNode(op); + ir::Node *node = CreateOpNode(op); for (auto &each_var_name : op->InputArgumentNames()) { ir::Node *var = nullptr; if (all_vars.count(each_var_name) != 0) { - var = graph->CreateVarNode(all_vars.at(each_var_name)); + var = CreateVarNode(all_vars.at(each_var_name)); } else { - LOG(ERROR) << "input var not in all_var list: " << each_var_name; - var = graph->CreateEmptyNode(each_var_name); + // TODO(paddle-dev): Seems some assumption doesn't hold? + LOG(ERROR) << op->Type() + << " input var not in all_var list: " << each_var_name; + var = CreateEmptyNode(each_var_name); } node->inputs.push_back(var); var->outputs.push_back(node); } for (auto &each_var_name : op->OutputArgumentNames()) { - ir::Node *var = graph->CreateVarNode(all_vars.at(each_var_name)); + ir::Node *var = CreateVarNode(all_vars.at(each_var_name)); node->outputs.push_back(var); var->inputs.push_back(node); } } - return std::move(graph); } - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 8b185f96254a27851365461d3f16405d90f77083..93db573559799faadb3c6f245e5ead133735f510 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -29,7 +29,7 @@ namespace framework { class Graph { public: - explicit Graph(const ProgramDesc& program) : program_(program) {} + explicit Graph(const ProgramDesc& program); virtual ~Graph() { for (auto& attr : attrs_) { @@ -46,6 +46,7 @@ class Graph { template void Set(const std::string& attr_name, AttrType* attr) { + PADDLE_ENFORCE(attrs_.count(attr_name) == 0); attrs_[attr_name] = attr; attr_dels_[attr_name] = [attr, attr_name]() { VLOG(3) << "deleting " << attr_name; @@ -53,14 +54,6 @@ class Graph { }; } - template - AttrType* Erase(const std::string& attr_name) { - AttrType* attr = boost::any_cast(attrs_[attr_name]); - attrs_.erase(attr_name); - attr_dels_.erase(attr_name); - return attr; - } - ir::Node* CreateVarNode(VarDesc* var_desc) { nodes.emplace_back(new ir::Node(var_desc)); return nodes.back().get(); @@ -71,14 +64,14 @@ class Graph { return nodes.back().get(); } - // TODO(panyx0718): Need to handle CreateOpNode(nullptr). + // TODO(paddle-dev): There shouldn't be kNone nodes in the ir::Graph. + // node should either be a executable kOperation or a kVariable. kNone + // node is a temporary solution. ir::Node* CreateEmptyNode(const std::string& name) { nodes.emplace_back(new ir::Node(name)); return nodes.back().get(); } - std::vector inputs; - std::vector outputs; std::vector> nodes; private: @@ -88,7 +81,5 @@ class Graph { std::map> attr_dels_; }; -std::unique_ptr ProgramToGraph(const ProgramDesc& program); - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index d2d08bc461a5719ca57e14f5e9e366860cdb7f79..cb1d524c34cf6a9263e47826d29eebe6f983f32b 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -14,17 +14,11 @@ limitations under the License. */ #pragma once -#include -#include -#include -#include #include -#include #include #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/platform/macros.h" -#include "paddle/fluid/platform/variant.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index c9014ffdf500e13158bd6e34fce9aa9f46b10904..1e5bba62b53025dacdbf2d74b35f266cf4e422c2 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -18,7 +18,7 @@ limitations under the License. */ #include #include -#include "paddle/fluid/framework/details/ssa_graph.h" +#include "paddle/fluid/framework/ir/graph.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/nccl_helper.h" @@ -132,7 +132,8 @@ ParallelExecutor::ParallelExecutor( #endif } builder_ = builder_factory.Create(); - std::unique_ptr graph = builder_->Apply(ProgramToGraph(main_program)); + std::unique_ptr graph(new Graph(main_program)); + graph = builder_->Apply(std::move(graph)); member_->executor_.reset(new details::ThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, places, std::move(graph))); member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(