diff --git a/paddle/fluid/inference/analysis/data_flow_graph.cc b/paddle/fluid/inference/analysis/data_flow_graph.cc index 8a3af0a8ebd5bad7be7046fa399cca4920da3d71..7f64bc75ae8ad40a268739cdc36051e76af9f49a 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph.cc @@ -337,6 +337,34 @@ ExtractInputAndOutputOfSubGraph(std::vector &graph) { // NOLINT std::vector(outputs.begin(), outputs.end())); } +void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) { + std::vector op_nodes; + for (auto &node : GraphTraits(graph).nodes_in_TS()) { + if (node.type() == Node::Type::kValue || node.deleted()) { + continue; + } + op_nodes.push_back(&node); + } + size_t op_num = op_nodes.size(); + for (size_t i = 0; i < op_num; i++) { + if (op_nodes[i]->type() == Node::Type::kFunction) continue; + std::unordered_set follow_up_input_names; + for (size_t j = i + 1; j < op_num; j++) { + for (auto *in : op_nodes[j]->inlinks) { + follow_up_input_names.insert(in->name()); + } + } + std::vector filtered_subgraph_outlinks; + for (auto *out : op_nodes[i]->outlinks) { + if (follow_up_input_names.count(out->name())) { + filtered_subgraph_outlinks.push_back(out); + } + } + PADDLE_ENFORCE_GE(filtered_subgraph_outlinks.size(), 1UL); + op_nodes[i]->outlinks = filtered_subgraph_outlinks; + } +} + } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/data_flow_graph.h b/paddle/fluid/inference/analysis/data_flow_graph.h index 16aeae4d35e7bd54646053190da7f47eaca69aa0..bb3ec6bbc1d9555386aba8837b019d2511653258 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph.h +++ b/paddle/fluid/inference/analysis/data_flow_graph.h @@ -178,6 +178,7 @@ struct GraphTraits { std::pair, std::vector> ExtractInputAndOutputOfSubGraph(std::vector &graph); // NOLINT +void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph); } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc index aaf7ca67011fb7bd4a74f6d8f57317594c528ca4..18c32fa09199003f17183207828cdfe4e627ae1a 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc @@ -52,6 +52,7 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) { bool DataFlowGraphToFluidPass::Finalize() { return true; } void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) { + FilterRedundantOutputOfSubGraph(graph); LOG(INFO) << "graph.inputs " << graph->inputs.size(); for (auto &node : GraphTraits(graph).nodes_in_TS()) { if (node.deleted()) continue; diff --git a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc index a6f85484756417e103cbb60bcb664e8b800b9f28..c05b0e5d4690d0a447edf63a149903704bc2c9be 100644 --- a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc +++ b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc @@ -46,9 +46,9 @@ std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) { for (size_t i = 0; i < graph->nodes.size(); i++) { const Node &node = graph->nodes.Get(i); if (!config_.display_deleted_node && node.deleted()) continue; - for (auto &in : node.inlinks) { - if (!config_.display_deleted_node && in->deleted()) continue; - dot.AddEdge(in->repr(), node.repr(), {}); + for (auto &out : node.outlinks) { + if (!config_.display_deleted_node && out->deleted()) continue; + dot.AddEdge(node.repr(), out->repr(), {}); } } return dot.Build();