From 21a45420f0dc4da9413b002d380e8b31bff88e05 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Sun, 22 Jul 2018 21:10:46 +0800 Subject: [PATCH] polish and test --- .../details/multi_devices_graph_builder.cc | 8 ++ .../fluid/framework/details/rpc_op_handle.cc | 4 +- .../framework/details/ssa_graph_builder.cc | 30 ----- .../framework/details/ssa_graph_builder.h | 9 -- paddle/fluid/framework/ir/CMakeLists.txt | 3 +- paddle/fluid/framework/ir/graph.cc | 4 + paddle/fluid/framework/ir/graph.h | 13 +- paddle/fluid/framework/ir/graph_helper.cc | 13 +- paddle/fluid/framework/ir/graph_helper.h | 4 + .../fluid/framework/ir/graph_helper_test.cc | 125 ++++++++++++++++++ paddle/fluid/operators/send_recv_util.h | 1 + 11 files changed, 166 insertions(+), 48 deletions(-) create mode 100644 paddle/fluid/framework/ir/graph_helper_test.cc diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 150f1534c8..de0cac0f1a 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -170,6 +170,7 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( const std::vector &var_names) const { int64_t numel_sum = 0; for (auto var_name : var_names) { + if (all_vars_.find(var_name) == all_vars_.end()) continue; auto var_desc = all_vars_.at(var_name); PADDLE_ENFORCE_NOT_NULL(var_desc); auto dim = framework::make_ddim(var_desc->GetShape()); @@ -271,6 +272,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( // user can customize loss@grad if not use_default_grad_scale_ if (strategy_.gradient_scale_ != BuildStrategy::GradientScaleStrategy::kCustomized) { + // TODO(paddle-dev): Why is there no input for this op_handle? CreateScaleLossGradOp(&result); } // This assumes the backward generating code will ensure IsScaleLossOp @@ -288,6 +290,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( } else { // This op runs on all devices, and its output may have parameter's // gradients. + // TODO(paddle-dev): Why is so special about "read" op? if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) { node->Op()->SetAttr("throw_eof_exp", false); CreateComputationalOps(&result, node, places_.size()); @@ -363,6 +366,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( * Only variables should be the leaves of graph. */ AddOutputToLeafOps(&result); + PADDLE_ENFORCE(!ir::HasCircle(result)); return graph; } @@ -620,6 +624,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, if (node->Op()->Type() == "split_byref" || node->Op()->Type() == "split_selected_rows") { + // TODO(paddle-dev): getting the first var is not safe. op_dev_id = GetVarDeviceID(input_var_names[0]); if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { op_dev_id = GetAppropriateDeviceID(input_var_names); @@ -657,7 +662,10 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { int op_dev_id = -1; if (node->Op()->Type() == "send") { + // TODO(paddle-dev): getting the first var is not safe. op_dev_id = GetVarDeviceID(node->inputs[0]->Name()); + PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]), + "This hack no longer holds, please fix."); // the variable name which contains .block means it was splited by // split_byref op // so that we can balance the variable blocks to all the pserver diff --git a/paddle/fluid/framework/details/rpc_op_handle.cc b/paddle/fluid/framework/details/rpc_op_handle.cc index c52a07530e..f44b374edb 100644 --- a/paddle/fluid/framework/details/rpc_op_handle.cc +++ b/paddle/fluid/framework/details/rpc_op_handle.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/details/rpc_op_handle.h" +#include "paddle/fluid/framework/ir/graph.h" namespace paddle { namespace framework { @@ -33,8 +34,7 @@ void RPCOpHandle::RunImpl() { for (auto *in : inputs_) { auto &p = static_cast(in)->place_; // FIXME(Yancey1989): need a better solution instead of use DebugString() - if (in->Node()->Name().find(ir::Node::kControlDepVarName) != - std::string::npos) { // HACK + if (ir::IsControlDepVar(*in->Node())) { // HACK continue; } if (in->GeneratedOp()) { diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index e7db8729cf..203d5fbbc1 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -17,36 +17,6 @@ namespace paddle { namespace framework { namespace details { -void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) { - for (auto &var_map : graph->Get("vars")) { - for (auto &name_pair : var_map) { - if (name_pair.second.size() <= 1) { - continue; - } - auto it_new = name_pair.second.rbegin(); - auto it_old = name_pair.second.rbegin(); - ++it_old; - for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { - OpHandleBase *write_op = (*it_new)->GeneratedOp(); - const auto &read_ops = (*it_old)->PendingOps(); - - for (auto *read_op : read_ops) { - // Manually add a dependency var from read_op to write_op; - if (read_op == write_op) { - // Read Write is the same op. - continue; - } - - auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); - read_op->AddOutput(dep_var); - write_op->AddInput(dep_var); - graph->Get("dep_vars").emplace(dep_var); - } - } - } - } -} - VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( ir::Graph *graph, ir::Node *node, const platform::Place &place, size_t place_offset) { diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index f64445b470..e99bab518e 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -57,15 +57,6 @@ class SSAGraphBuilder : public ir::Pass { DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); protected: - /** - * We only handle write after read(WAR), since it should not have a write - * after write in program. If there are write after write operators, we need - * prune them. - * - * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) - */ - static void PolishGraphToSupportDataHazards(ir::Graph *graph); - static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node, const platform::Place &place, size_t place_offset); diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 744696ebb0..6447452ae5 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -2,4 +2,5 @@ cc_library(node SRCS node.cc DEPS proto_desc) cc_library(graph SRCS graph.cc DEPS node) cc_library(graph_helper SRCS graph_helper.cc DEPS graph) cc_library(pass SRCS pass.cc DEPS graph node) -cc_test(graph_test SRCS graph_test.cc DEPS graph proto_desc op_registry) +cc_test(graph_test SRCS graph_test.cc DEPS graph op_registry) +cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph_helper op_registry) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 6ad8773567..c5b4851477 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -107,6 +107,10 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { } } } + +bool IsControlDepVar(const ir::Node &var) { + return var.Name().find(ir::Node::kControlDepVarName) != std::string::npos; +} } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index fccad9fd4f..4f59ec82a7 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -57,25 +57,34 @@ class Graph { const std::unordered_set &Nodes() const { return node_set_; } + // Create a normal variable with non-null VarDesc. ir::Node *CreateVarNode(VarDesc *var_desc) { return AddNode(new ir::Node(var_desc)); } + // Create a normal runnable operator with OpDesc. ir::Node *CreateOpNode(OpDesc *op_desc) { return AddNode(new ir::Node(op_desc)); } + // Create a control dependency var that connects 2 operations. The + // var doesn't hold any data. Other than that, it's no different from + // other var, considering dependency analysis. ir::Node *CreateControlDepVar() { - // TODO(panyx0718): control var name should be unique. + // TODO(panyx0718): control var name should be really unique. const std::string name = string::Sprintf( "%s@%llu", ir::Node::kControlDepVarName, node_set_.size()); return AddNode(new ir::Node(name, ir::Node::Type::kVariable)); } + // A more free style way of creating a graph node. Mostly use for test + // or "copy" from another node. Avoid using it if possible. ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) { return AddNode(new ir::Node(name, type)); } + // Clear all node information of the graph and return the ownership of the + // nodes. std::vector> ReleaseNodes() { std::vector> ret; for (auto &n : nodes_) { @@ -108,6 +117,8 @@ class Graph { std::map> nodes_; std::unordered_set node_set_; }; + +bool IsControlDepVar(const ir::Node &var); } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index b829cf204d..b1c19e6535 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -59,12 +59,9 @@ bool HasCircleHelper( in_trace->erase(node); return false; } -} // namespace - -bool HasCircle(const Graph &graph) { - std::map> adj_list = - BuildOperationAdjList(graph); +bool HasCircleInternal( + const std::map> &adj_list) { std::unordered_set visited; std::unordered_set in_trace; for (auto &adj : adj_list) { @@ -74,10 +71,16 @@ bool HasCircle(const Graph &graph) { } return false; } +} // namespace + +bool HasCircle(const Graph &graph) { + return HasCircleInternal(BuildOperationAdjList(graph)); +} std::vector TopologySortOperations(const Graph &graph) { std::map> adj_list = BuildOperationAdjList(graph); + PADDLE_ENFORCE(!HasCircleInternal(adj_list)); std::unordered_set visited; std::vector ret; for (auto adj : adj_list) { diff --git a/paddle/fluid/framework/ir/graph_helper.h b/paddle/fluid/framework/ir/graph_helper.h index 118c16ab7a..cd6c53a07f 100644 --- a/paddle/fluid/framework/ir/graph_helper.h +++ b/paddle/fluid/framework/ir/graph_helper.h @@ -24,10 +24,14 @@ limitations under the License. */ namespace paddle { namespace framework { namespace ir { +// Test if the graph contains circle. bool HasCircle(const Graph &graph); +// Topology Sort the operations in the graph from inputs to outputs. +// `graph` cannot contain circle. std::vector TopologySortOperations(const Graph &graph); +// Build an adjacency list of operations for the `graph`. std::map> BuildOperationAdjList( const Graph &graph); diff --git a/paddle/fluid/framework/ir/graph_helper_test.cc b/paddle/fluid/framework/ir/graph_helper_test.cc new file mode 100644 index 0000000000..b517442bb7 --- /dev/null +++ b/paddle/fluid/framework/ir/graph_helper_test.cc @@ -0,0 +1,125 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/graph.h" +#include +#include "gtest/gtest.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/program_desc.h" + +namespace paddle { +namespace framework { +namespace ir { + +void BuildCircleGraph(Graph* g) { + ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation); + ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable); + + o1->outputs.push_back(v1); + o1->inputs.push_back(v1); + v1->inputs.push_back(o1); + v1->outputs.push_back(o1); +} + +void BuildCircleGraph2(Graph* g) { + ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation); + ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation); + ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable); + ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable); + + o1->outputs.push_back(v1); + o2->inputs.push_back(v1); + v1->inputs.push_back(o1); + v1->outputs.push_back(o2); + + o2->outputs.push_back(v2); + o1->inputs.push_back(v2); + v2->inputs.push_back(o2); + v2->outputs.push_back(o1); +} + +void BuildNoCircleGraph(Graph* g) { + ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation); + ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation); + ir::Node* o3 = g->CreateEmptyNode("op3", Node::Type::kOperation); + ir::Node* o4 = g->CreateEmptyNode("op4", Node::Type::kOperation); + ir::Node* o5 = g->CreateEmptyNode("op5", Node::Type::kOperation); + ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable); + ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable); + ir::Node* v3 = g->CreateEmptyNode("var3", Node::Type::kVariable); + ir::Node* v4 = g->CreateEmptyNode("var4", Node::Type::kVariable); + + // o1->v1->o2 + o1->outputs.push_back(v1); + o2->inputs.push_back(v1); + v1->inputs.push_back(o1); + v1->outputs.push_back(o2); + // o2->v2->o3 + // o2->v2->o4 + o2->outputs.push_back(v2); + o3->inputs.push_back(v2); + o4->inputs.push_back(v2); + v2->inputs.push_back(o2); + v2->outputs.push_back(o3); + v2->outputs.push_back(o4); + // o2->v3->o5 + o2->outputs.push_back(v3); + o5->inputs.push_back(v3); + v3->inputs.push_back(o2); + v3->outputs.push_back(o5); + // o3-v4->o5 + o3->outputs.push_back(v4); + o5->inputs.push_back(v4); + v4->inputs.push_back(o3); + v4->outputs.push_back(o5); +} + +TEST(GraphHelperTest, Basic) { + ProgramDesc prog; + + Graph g(prog); + BuildCircleGraph(&g); + ASSERT_TRUE(HasCircle(g)); + + Graph g2(prog); + BuildCircleGraph2(&g2); + ASSERT_TRUE(HasCircle(g2)); + + auto adj_list = BuildOperationAdjList(g2); + for (auto& adj : adj_list) { + auto& adj_set = adj.second; + if (adj.first->Name() == "op1") { + ASSERT_EQ((*adj_set.begin())->Name(), "op2"); + } else if (adj.first->Name() == "op2") { + ASSERT_EQ((*adj_set.begin())->Name(), "op1"); + } else { + ASSERT_TRUE(false); + } + } + + Graph g3(prog); + BuildNoCircleGraph(&g3); + ASSERT_FALSE(HasCircle(g3)); + auto sorted = TopologySortOperations(g3); + std::map node_map; + for (size_t i = 0; i < sorted.size(); ++i) { + node_map[sorted[i]->Name()] = i; + } + ASSERT_EQ(node_map.at("op1"), 0); + ASSERT_EQ(node_map.at("op2"), 1); + ASSERT_TRUE(node_map.at("op3") < node_map.at("op5")); +} +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/operators/send_recv_util.h b/paddle/fluid/operators/send_recv_util.h index 500230d553..dc26c53c64 100644 --- a/paddle/fluid/operators/send_recv_util.h +++ b/paddle/fluid/operators/send_recv_util.h @@ -23,6 +23,7 @@ inline bool NeedSend(const framework::Scope& scope, const std::string& varname) { // dummy variable is only used in parallel executor to represent // some dependency relationship, we don't need to send/recv it. + // TODO(paddle-dev): Why would parallel executor logic leaked into here? if (varname.find(framework::ir::Node::kControlDepVarName) != std::string::npos) return false; -- GitLab