diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 62f94a1c0e5a300438bbe5fea34b9a07df5d9ebf..c54766d95a61ac1a4b61566c6de62cbc86685a1d 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/framework/ir/graph_helper.h" #include +#include #include -#include "paddle/fluid/framework/ir/graph_helper.h" - namespace paddle { namespace framework { namespace ir { @@ -113,6 +113,74 @@ std::map> BuildOperationAdjList( return adj_list; } +size_t GraphNum(const Graph &graph) { + std::unordered_set nodes = graph.Nodes(); + std::unordered_set visited_nodes; + visited_nodes.reserve(nodes.size()); + std::deque q_nodes; + std::vector> graph_nodes; + std::unordered_set g_nodes; + size_t graph_count = 0; + + auto traverse_nodes = [&visited_nodes, + &q_nodes](const std::vector &nodes) { + std::copy_if( + nodes.begin(), nodes.end(), std::back_inserter(q_nodes), + [&visited_nodes](Node *node) { return !visited_nodes.count(node); }); + }; + + while (visited_nodes.size() != nodes.size()) { + if (!q_nodes.empty()) { + auto cur_node = q_nodes.front(); + q_nodes.pop_front(); + visited_nodes.insert(cur_node); + g_nodes.insert(cur_node); + traverse_nodes(cur_node->inputs); + traverse_nodes(cur_node->outputs); + } else { + ++graph_count; + if (g_nodes.size()) { + graph_nodes.emplace_back(g_nodes); + } + g_nodes.clear(); + for (auto &n : nodes) { + if (visited_nodes.count(n) == 0) { + q_nodes.push_back(n); + break; + } + } + } + } + + if (g_nodes.size()) { + graph_nodes.emplace_back(g_nodes); + } + + if (VLOG_IS_ON(10)) { + VLOG(10) << "graph_num: " << graph_nodes.size(); + for (auto &g_n : graph_nodes) { + VLOG(10) << "graph_nodes: " << g_n.size(); + if (g_n.size() < 10) { + std::stringstream out; + for (auto &node : g_n) { + out << "\nNode: " << node->Name() << " in ["; + for (auto &n : node->inputs) { + out << n->Name() << ", "; + } + out << "], out["; + for (auto &n : node->outputs) { + out << n->Name() << ", "; + } + out << "]"; + } + VLOG(10) << out.str(); + } + } + } + + return graph_count; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_helper.h b/paddle/fluid/framework/ir/graph_helper.h index cd6c53a07f8f56781989739d995226bd02b3d3d0..ec46b38c01b8c369ab37b4fbd5497ec120d8db91 100644 --- a/paddle/fluid/framework/ir/graph_helper.h +++ b/paddle/fluid/framework/ir/graph_helper.h @@ -27,6 +27,8 @@ namespace ir { // Test if the graph contains circle. bool HasCircle(const Graph &graph); +size_t GraphNum(const Graph &graph); + // Topology Sort the operations in the graph from inputs to outputs. // `graph` cannot contain circle. std::vector TopologySortOperations(const Graph &graph); diff --git a/paddle/fluid/framework/ir/graph_helper_test.cc b/paddle/fluid/framework/ir/graph_helper_test.cc index a260dd3da2a7863c06e51aa4feafd824ea254139..cea902809339f9d45b0e2525163f08a3c1c44c95 100644 --- a/paddle/fluid/framework/ir/graph_helper_test.cc +++ b/paddle/fluid/framework/ir/graph_helper_test.cc @@ -120,6 +120,97 @@ TEST(GraphHelperTest, Basic) { ASSERT_EQ(node_map.at("op2"), 1UL); ASSERT_TRUE(node_map.at("op3") < node_map.at("op5")); } + +void BuildZeroGraph(Graph* g) {} + +void BuildOneGraph(Graph* g) { + ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation); + ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation); + ir::Node* o3 = g->CreateEmptyNode("op3", Node::Type::kOperation); + ir::Node* o4 = g->CreateEmptyNode("op4", Node::Type::kOperation); + ir::Node* o5 = g->CreateEmptyNode("op5", Node::Type::kOperation); + ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable); + ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable); + ir::Node* v3 = g->CreateEmptyNode("var3", Node::Type::kVariable); + ir::Node* v4 = g->CreateEmptyNode("var4", Node::Type::kVariable); + + // o1->v1->o2 + o1->outputs.push_back(v1); + o2->inputs.push_back(v1); + v1->inputs.push_back(o1); + v1->outputs.push_back(o2); + // o2->v2->o3 + // o2->v2->o4 + o2->outputs.push_back(v2); + o3->inputs.push_back(v2); + o4->inputs.push_back(v2); + v2->inputs.push_back(o2); + v2->outputs.push_back(o3); + v2->outputs.push_back(o4); + // o2->v3->o5 + o2->outputs.push_back(v3); + o5->inputs.push_back(v3); + v3->inputs.push_back(o2); + v3->outputs.push_back(o5); + // o3-v4->o5 + o3->outputs.push_back(v4); + o5->inputs.push_back(v4); + v4->inputs.push_back(o3); + v4->outputs.push_back(o5); +} + +void BuildTwoGraphs(Graph* g) { + ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation); + ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation); + ir::Node* o3 = g->CreateEmptyNode("op3", Node::Type::kOperation); + ir::Node* o4 = g->CreateEmptyNode("op4", Node::Type::kOperation); + ir::Node* o5 = g->CreateEmptyNode("op5", Node::Type::kOperation); + ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable); + ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable); + ir::Node* v3 = g->CreateEmptyNode("var3", Node::Type::kVariable); + ir::Node* v4 = g->CreateEmptyNode("var4", Node::Type::kVariable); + + // o1->v1->o2 + o1->outputs.push_back(v1); + o2->inputs.push_back(v1); + v1->inputs.push_back(o1); + v1->outputs.push_back(o2); + // o2->v2->o3 + // o2->v2->o4 + o2->outputs.push_back(v2); + o3->inputs.push_back(v2); + o4->inputs.push_back(v2); + v2->inputs.push_back(o2); + v2->outputs.push_back(o3); + v2->outputs.push_back(o4); + // o2->v3->o5 + // o2->outputs.push_back(v3); + o5->inputs.push_back(v3); + // v3->inputs.push_back(o2); + v3->outputs.push_back(o5); + // o3-v4->o5 + o3->outputs.push_back(v4); + // o5->inputs.push_back(v4); + v4->inputs.push_back(o3); + // v4->outputs.push_back(o5); +} + +TEST(GraphHelperTest, GraphNum) { + ProgramDesc prog; + + Graph g(prog); + BuildZeroGraph(&g); + ASSERT_EQ(GraphNum(g), 0); + + Graph g2(prog); + BuildOneGraph(&g2); + ASSERT_EQ(GraphNum(g2), 1); + + Graph g3(prog); + BuildTwoGraphs(&g3); + ASSERT_EQ(GraphNum(g3), 2); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 855870b41c9629aaae88f009ffa908a91b25a931..720d17a654bf96ca2bad43cc0c4374b2303ac233 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/parallel_executor.h" - #include #include #include +#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph.h" @@ -156,6 +156,12 @@ ParallelExecutor::ParallelExecutor( params, member_->local_scopes_, member_->use_cuda_); #endif + // If the loss_var_name is given, the number of graph should be only one. + if (loss_var_name.size()) { + PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1, + "The number of graph should be only one"); + } + if (exec_strategy.type_ == ExecutionStrategy::kDefault) { member_->executor_.reset(new details::ThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, places, std::move(graph))); diff --git a/python/paddle/fluid/tests/unittests/transformer_model.py b/python/paddle/fluid/tests/unittests/transformer_model.py index ab7a18d4c5c4ce1e490e2951ff9fbb023324e753..143d187edc3a154418f9e639b7d492c8ce994d42 100644 --- a/python/paddle/fluid/tests/unittests/transformer_model.py +++ b/python/paddle/fluid/tests/unittests/transformer_model.py @@ -246,6 +246,7 @@ def prepare_encoder(src_word, padding_idx=pos_pad_idx, param_attr=fluid.ParamAttr( name=pos_enc_param_name, trainable=False)) + src_pos_enc.stop_gradient = True enc_input = src_word_emb + src_pos_enc # FIXME(guosheng): Decouple the program desc with batch_size.