From 93355cc0d222b558689f963805d21b366ee8699d Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Sat, 21 Jul 2018 17:27:25 +0800 Subject: [PATCH] fix control deps --- paddle/fluid/framework/details/CMakeLists.txt | 2 +- .../details/multi_devices_graph_builder.cc | 34 ++++++++------ .../details/multi_devices_graph_builder.h | 4 +- .../fluid/framework/details/rpc_op_handle.cc | 3 +- .../framework/details/ssa_graph_builder.cc | 33 +++++++++++++- .../framework/details/ssa_graph_builder.h | 9 ++++ paddle/fluid/framework/details/var_handle.cc | 2 +- paddle/fluid/framework/ir/graph.cc | 26 ++++++++--- paddle/fluid/framework/ir/graph.h | 45 ++++++++++++++++--- paddle/fluid/framework/ir/graph_helper.cc | 19 ++++---- paddle/fluid/framework/ir/graph_helper.h | 1 + paddle/fluid/framework/ir/graph_test.cc | 29 ++++++------ paddle/fluid/operators/CMakeLists.txt | 4 +- paddle/fluid/operators/send_recv_util.h | 5 ++- 14 files changed, 155 insertions(+), 61 deletions(-) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 620d202d33..9df7df1f42 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -1,4 +1,4 @@ -cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto) +cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node) cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor) cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index f5e99c5748..150f1534c8 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -94,12 +94,11 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result, } std::vector MultiDevSSAGraphBuilder::FindDistTrainSendVars( - const std::vector> &nodes) const { + const std::vector &nodes) const { std::vector send_vars; // since parameters are all in block 0, // it's enough to only scan send ops in block 0 for (auto &node : nodes) { - if (node->NodeType() != ir::Node::Type::kOperation) continue; OpDesc *op = node->Op(); // TODO(Yancey1989): use a graceful method to find send op, // instead of the the hard code string @@ -114,10 +113,9 @@ std::vector MultiDevSSAGraphBuilder::FindDistTrainSendVars( } std::vector MultiDevSSAGraphBuilder::FindDistTrainRecvVars( - const std::vector> &nodes) const { + const std::vector &nodes) const { std::vector recv_vars; for (auto &node : nodes) { - if (node->NodeType() != ir::Node::Type::kOperation) continue; OpDesc *op = node->Op(); // TODO(Yancey1989): use a graceful method to find recv op, // instead of the hard code string @@ -214,6 +212,19 @@ std::vector SortOpsAndDelayOptimizeOp(const ir::Graph &graph) { } } + // Verify that no operations before optimize ops depends on optimize ops. + std::unordered_set optimize_set(optimize_ops.begin(), + optimize_ops.end()); + for (size_t i = 0; i < last_backward; ++i) { + for (ir::Node *in : sorted_ret[i]->inputs) { + for (ir::Node *pre_n : in->inputs) { + PADDLE_ENFORCE(optimize_set.find(pre_n) == optimize_set.end(), + "optimize operations cannot be depended by forward " + "or backward node %s -> %s", + pre_n->Name(), sorted_ret[i]->Name()); + } + } + } sorted_ret.insert(sorted_ret.begin() + last_backward, optimize_ops.begin(), optimize_ops.end()); return sorted_ret; @@ -221,18 +232,16 @@ std::vector SortOpsAndDelayOptimizeOp(const ir::Graph &graph) { std::unique_ptr MultiDevSSAGraphBuilder::Apply( std::unique_ptr graph) const { - // Rebuild the graph structure. + // Give the topology sort order and rebuild the graph structure. std::vector sorted_ops = SortOpsAndDelayOptimizeOp(*graph); - auto nodes = std::move(graph->nodes); - graph->nodes.clear(); + auto nodes = graph->ReleaseNodes(); + ir::Graph &result = *graph; for (auto &node : nodes) { if (node->NodeType() == ir::Node::Type::kVariable) { all_vars_.emplace(node->Name(), node->Var()); } } - - ir::Graph &result = *graph; std::unordered_set og_has_been_broadcast; // We cannot invoke resize. It is a bug of GCC 4.8 @@ -242,8 +251,8 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( // find send/recv vars so that we can place the distributed training // realted op in the place 0 - auto send_vars = FindDistTrainSendVars(nodes); - auto recv_vars = FindDistTrainRecvVars(nodes); + auto send_vars = FindDistTrainSendVars(sorted_ops); + auto recv_vars = FindDistTrainRecvVars(sorted_ops); std::vector> bcast_var_name_set; bcast_var_name_set.resize(places_.size()); @@ -589,8 +598,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op, const std::string &prev_op_name) const { for (auto &prev_op : result->Get("ops")) { if (prev_op->Name() == prev_op_name) { - auto *dep_var = new DummyVarHandle( - result->CreateEmptyNode("dummy", ir::Node::Type::kVariable)); + auto *dep_var = new DummyVarHandle(result->CreateControlDepVar()); prev_op->AddOutput(dep_var); result->Get("dep_vars").emplace(dep_var); op->AddInput(dep_var); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index c2c764bb94..55076f227b 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -76,10 +76,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const std::vector &recv_vars) const; std::vector FindDistTrainSendVars( - const std::vector> &nodes) const; + const std::vector &nodes) const; std::vector FindDistTrainRecvVars( - const std::vector> &nodes) const; + const std::vector &nodes) const; void ConnectOp(ir::Graph *result, OpHandleBase *op, const std::string &prev_op_name) const; diff --git a/paddle/fluid/framework/details/rpc_op_handle.cc b/paddle/fluid/framework/details/rpc_op_handle.cc index 924ff4d118..c52a07530e 100644 --- a/paddle/fluid/framework/details/rpc_op_handle.cc +++ b/paddle/fluid/framework/details/rpc_op_handle.cc @@ -33,7 +33,8 @@ void RPCOpHandle::RunImpl() { for (auto *in : inputs_) { auto &p = static_cast(in)->place_; // FIXME(Yancey1989): need a better solution instead of use DebugString() - if (in->DebugString() == "dummy") { // HACK + if (in->Node()->Name().find(ir::Node::kControlDepVarName) != + std::string::npos) { // HACK continue; } if (in->GeneratedOp()) { diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index dcdcb28ac4..e7db8729cf 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -17,6 +17,36 @@ namespace paddle { namespace framework { namespace details { +void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) { + for (auto &var_map : graph->Get("vars")) { + for (auto &name_pair : var_map) { + if (name_pair.second.size() <= 1) { + continue; + } + auto it_new = name_pair.second.rbegin(); + auto it_old = name_pair.second.rbegin(); + ++it_old; + for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { + OpHandleBase *write_op = (*it_new)->GeneratedOp(); + const auto &read_ops = (*it_old)->PendingOps(); + + for (auto *read_op : read_ops) { + // Manually add a dependency var from read_op to write_op; + if (read_op == write_op) { + // Read Write is the same op. + continue; + } + + auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); + read_op->AddOutput(dep_var); + write_op->AddInput(dep_var); + graph->Get("dep_vars").emplace(dep_var); + } + } + } + } +} + VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( ir::Graph *graph, ir::Node *node, const platform::Place &place, size_t place_offset) { @@ -56,8 +86,7 @@ void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) { if (!op->Outputs().empty()) { continue; } - auto *dummy_leaf = new DummyVarHandle( - graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable)); + auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); graph->Get("dep_vars").emplace(dummy_leaf); op->AddOutput(dummy_leaf); } diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index e99bab518e..f64445b470 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -57,6 +57,15 @@ class SSAGraphBuilder : public ir::Pass { DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); protected: + /** + * We only handle write after read(WAR), since it should not have a write + * after write in program. If there are write after write operators, we need + * prune them. + * + * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) + */ + static void PolishGraphToSupportDataHazards(ir::Graph *graph); + static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node, const platform::Place &place, size_t place_offset); diff --git a/paddle/fluid/framework/details/var_handle.cc b/paddle/fluid/framework/details/var_handle.cc index 6f00abd947..5457870e9f 100644 --- a/paddle/fluid/framework/details/var_handle.cc +++ b/paddle/fluid/framework/details/var_handle.cc @@ -26,7 +26,7 @@ std::string VarHandle::DebugString() const { return ss.str(); } -std::string DummyVarHandle::DebugString() const { return "dummy"; } +std::string DummyVarHandle::DebugString() const { return node_->Name(); } } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 769dddbc59..6ad8773567 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -34,7 +34,8 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { std::map> var_nodes; for (auto *op : program.Block(0).AllOps()) { ir::Node *node = CreateOpNode(op); - + // For input args, reuse the same var name if it was created before. + // Otherwise, create a new one. for (auto &each_var_name : op->InputArgumentNames()) { ir::Node *var = nullptr; if (var_nodes.find(each_var_name) != var_nodes.end()) { @@ -43,16 +44,16 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { var = CreateVarNode(all_vars.at(each_var_name)); var_nodes[each_var_name].push_back(var); } else { - // TODO(paddle-dev): Seems some assumption doesn't hold? - VLOG(3) << op->Type() - << " input var not in all_var list: " << each_var_name; + // Operation input var can be optional (dispensable). Which means + // the operation doesn't really need the var at runtime. In this + // case, the no-existed var is ready at the beginning. var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable); var_nodes[each_var_name].push_back(var); } node->inputs.push_back(var); var->outputs.push_back(node); } - + // For output args, always create a new var. for (auto &each_var_name : op->OutputArgumentNames()) { ir::Node *var = CreateVarNode(all_vars.at(each_var_name)); var_nodes[each_var_name].push_back(var); @@ -67,6 +68,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { * * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) */ + for (auto &var : var_nodes) { auto &versions = var.second; if (versions.size() <= 1) continue; @@ -85,8 +87,18 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { // Read Write is the same op. continue; } - ir::Node *dep_var = CreateEmptyNode(ir::Node::kControlDepVarName, - ir::Node::Type::kVariable); + // 2 ops might have been connected via other vars. + bool has_dep = false; + for (ir::Node *r_out : read_op->outputs) { + for (ir::Node *w_in : write_op->inputs) { + if (r_out == w_in) { + has_dep = true; + break; + } + } + } + if (has_dep) continue; + ir::Node *dep_var = CreateControlDepVar(); read_op->outputs.push_back(dep_var); dep_var->inputs.push_back(read_op); write_op->inputs.push_back(dep_var); diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index a1e39b08a4..fccad9fd4f 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -27,6 +27,7 @@ limitations under the License. */ namespace paddle { namespace framework { namespace ir { + class Graph { public: explicit Graph(const ProgramDesc &program); @@ -54,28 +55,58 @@ class Graph { }; } + const std::unordered_set &Nodes() const { return node_set_; } + ir::Node *CreateVarNode(VarDesc *var_desc) { - nodes.emplace_back(new ir::Node(var_desc)); - return nodes.back().get(); + return AddNode(new ir::Node(var_desc)); } ir::Node *CreateOpNode(OpDesc *op_desc) { - nodes.emplace_back(new ir::Node(op_desc)); - return nodes.back().get(); + return AddNode(new ir::Node(op_desc)); + } + + ir::Node *CreateControlDepVar() { + // TODO(panyx0718): control var name should be unique. + const std::string name = string::Sprintf( + "%s@%llu", ir::Node::kControlDepVarName, node_set_.size()); + return AddNode(new ir::Node(name, ir::Node::Type::kVariable)); } ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) { - nodes.emplace_back(new ir::Node(name, type)); - return nodes.back().get(); + return AddNode(new ir::Node(name, type)); } - std::vector> nodes; + std::vector> ReleaseNodes() { + std::vector> ret; + for (auto &n : nodes_) { + ret.emplace_back(n.second.release()); + } + nodes_.clear(); + node_set_.clear(); + return ret; + } private: + // This method takes ownership of `node`. + ir::Node *AddNode(ir::Node *node) { + PADDLE_ENFORCE(node_set_.find(node) == node_set_.end()); + nodes_[node].reset(node); + node_set_.insert(node); + return node; + } + + void RemoveNode(ir::Node *node) { + PADDLE_ENFORCE(node_set_.find(node) != node_set_.end()); + node_set_.erase(node); + nodes_.erase(node); + } + // NOTE: program_ shouldn't be exposed to user. const ProgramDesc &program_; std::map attrs_; std::map> attr_dels_; + std::map> nodes_; + std::unordered_set node_set_; }; } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 76458be135..b829cf204d 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -33,9 +33,8 @@ void SortHelper( } } - LOG(ERROR) << "topology sort insert: " << node->Name() - << reinterpret_cast(node) << " input " - << node->inputs.size(); + VLOG(3) << "topology sort insert: " << node->Name() + << reinterpret_cast(node) << " input " << node->inputs.size(); ret->push_back(node); } @@ -93,18 +92,18 @@ std::map> BuildOperationAdjList( const Graph &graph) { std::map> adj_list; - for (auto &n : graph.nodes) { + for (auto &n : graph.Nodes()) { if (n->NodeType() != ir::Node::Type::kOperation) continue; - if (adj_list.find(n.get()) == adj_list.end()) { - adj_list[n.get()] = std::unordered_set(); + if (adj_list.find(n) == adj_list.end()) { + adj_list[n] = std::unordered_set(); } for (auto &var : n->inputs) { for (auto &adj_n : var->inputs) { PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation); - adj_list[n.get()].insert(adj_n); - LOG(ERROR) << "adj " << adj_n->Name() << reinterpret_cast(adj_n) - << " -> " << n->Name() << reinterpret_cast(n.get()) - << " via " << var->Name() << reinterpret_cast(var); + adj_list[n].insert(adj_n); + VLOG(3) << "adj " << adj_n->Name() << reinterpret_cast(adj_n) + << " -> " << n->Name() << reinterpret_cast(n) + << " via " << var->Name() << reinterpret_cast(var); } } } diff --git a/paddle/fluid/framework/ir/graph_helper.h b/paddle/fluid/framework/ir/graph_helper.h index 55b2e3f5ca..118c16ab7a 100644 --- a/paddle/fluid/framework/ir/graph_helper.h +++ b/paddle/fluid/framework/ir/graph_helper.h @@ -30,6 +30,7 @@ std::vector TopologySortOperations(const Graph &graph); std::map> BuildOperationAdjList( const Graph &graph); + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index f8fbd2242d..feb4e8e76b 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -94,20 +94,21 @@ TEST(GraphTest, Basic) { prog.MutableBlock(0)->Var("test_out")->GetType()); std::unique_ptr g(new ir::Graph(prog)); - ASSERT_EQ(g->nodes[0]->Name(), "sum"); - ASSERT_EQ(g->nodes[0]->inputs[0]->Name(), "test_a"); - ASSERT_EQ(g->nodes[0]->inputs[1]->Name(), "test_b"); - ASSERT_EQ(g->nodes[0]->inputs[2]->Name(), "test_c"); - ASSERT_EQ(g->nodes[0]->outputs[0]->Name(), "test_out"); - ASSERT_EQ(g->nodes[1]->Name(), "test_a"); - ASSERT_EQ(g->nodes[1]->outputs[0]->Name(), "sum"); - ASSERT_EQ(g->nodes[2]->Name(), "test_b"); - ASSERT_EQ(g->nodes[2]->outputs[0]->Name(), "sum"); - ASSERT_EQ(g->nodes[3]->Name(), "test_c"); - ASSERT_EQ(g->nodes[3]->outputs[0]->Name(), "sum"); - ASSERT_EQ(g->nodes[4]->Name(), "test_out"); - ASSERT_EQ(g->nodes[4]->inputs[0]->Name(), "sum"); - ASSERT_EQ(g->nodes.size(), 5); + std::vector nodes(g->Nodes().begin(), g->Nodes().end()); + ASSERT_EQ(nodes[0]->Name(), "sum"); + ASSERT_EQ(nodes[0]->inputs[0]->Name(), "test_a"); + ASSERT_EQ(nodes[0]->inputs[1]->Name(), "test_b"); + ASSERT_EQ(nodes[0]->inputs[2]->Name(), "test_c"); + ASSERT_EQ(nodes[0]->outputs[0]->Name(), "test_out"); + ASSERT_EQ(nodes[1]->Name(), "test_a"); + ASSERT_EQ(nodes[1]->outputs[0]->Name(), "sum"); + ASSERT_EQ(nodes[2]->Name(), "test_b"); + ASSERT_EQ(nodes[2]->outputs[0]->Name(), "sum"); + ASSERT_EQ(nodes[3]->Name(), "test_c"); + ASSERT_EQ(nodes[3]->outputs[0]->Name(), "sum"); + ASSERT_EQ(nodes[4]->Name(), "test_out"); + ASSERT_EQ(nodes[4]->inputs[0]->Name(), "sum"); + ASSERT_EQ(nodes.size(), 5); } } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 4e2002ad24..9b56ad4c55 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -192,9 +192,9 @@ if(WITH_DISTRIBUTE) set(DISTRIBUTE_DEPS "") if(WITH_GRPC) - set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf) + set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node) else() - set(DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib) + set(DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node) if(WITH_BRPC_RDMA) find_library(IBVERBS_LIBRARY NAMES ibverbs) ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL) diff --git a/paddle/fluid/operators/send_recv_util.h b/paddle/fluid/operators/send_recv_util.h index deab005149..500230d553 100644 --- a/paddle/fluid/operators/send_recv_util.h +++ b/paddle/fluid/operators/send_recv_util.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include +#include "paddle/fluid/framework/ir/node.h" namespace paddle { namespace operators { @@ -22,7 +23,9 @@ inline bool NeedSend(const framework::Scope& scope, const std::string& varname) { // dummy variable is only used in parallel executor to represent // some dependency relationship, we don't need to send/recv it. - if (varname == "dummy") return false; + if (varname.find(framework::ir::Node::kControlDepVarName) != + std::string::npos) + return false; auto* var = scope.FindVar(varname); PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.", varname); -- GitLab