未验证 提交 5dc57b71 编写于 作者: Z Zhaolong Xing 提交者: GitHub

Merge pull request #12593 from NHZlX/filter_redundant_output

filter redundant output
...@@ -337,6 +337,34 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT ...@@ -337,6 +337,34 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
std::vector<Node *>(outputs.begin(), outputs.end())); std::vector<Node *>(outputs.begin(), outputs.end()));
} }
void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
std::vector<Node *> op_nodes;
for (auto &node : GraphTraits<DataFlowGraph>(graph).nodes_in_TS()) {
if (node.type() == Node::Type::kValue || node.deleted()) {
continue;
}
op_nodes.push_back(&node);
}
size_t op_num = op_nodes.size();
for (size_t i = 0; i < op_num; i++) {
if (op_nodes[i]->type() == Node::Type::kFunction) continue;
std::unordered_set<std::string> follow_up_input_names;
for (size_t j = i + 1; j < op_num; j++) {
for (auto *in : op_nodes[j]->inlinks) {
follow_up_input_names.insert(in->name());
}
}
std::vector<Node *> filtered_subgraph_outlinks;
for (auto *out : op_nodes[i]->outlinks) {
if (follow_up_input_names.count(out->name())) {
filtered_subgraph_outlinks.push_back(out);
}
}
PADDLE_ENFORCE_GE(filtered_subgraph_outlinks.size(), 1UL);
op_nodes[i]->outlinks = filtered_subgraph_outlinks;
}
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -178,6 +178,7 @@ struct GraphTraits<DataFlowGraph> { ...@@ -178,6 +178,7 @@ struct GraphTraits<DataFlowGraph> {
std::pair<std::vector<Node *>, std::vector<Node *>> std::pair<std::vector<Node *>, std::vector<Node *>>
ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph); // NOLINT ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph); // NOLINT
void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph);
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -52,6 +52,7 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) { ...@@ -52,6 +52,7 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) {
bool DataFlowGraphToFluidPass::Finalize() { return true; } bool DataFlowGraphToFluidPass::Finalize() { return true; }
void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) { void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
FilterRedundantOutputOfSubGraph(graph);
LOG(INFO) << "graph.inputs " << graph->inputs.size(); LOG(INFO) << "graph.inputs " << graph->inputs.size();
for (auto &node : GraphTraits<DataFlowGraph>(graph).nodes_in_TS()) { for (auto &node : GraphTraits<DataFlowGraph>(graph).nodes_in_TS()) {
if (node.deleted()) continue; if (node.deleted()) continue;
......
...@@ -46,9 +46,9 @@ std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) { ...@@ -46,9 +46,9 @@ std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) {
for (size_t i = 0; i < graph->nodes.size(); i++) { for (size_t i = 0; i < graph->nodes.size(); i++) {
const Node &node = graph->nodes.Get(i); const Node &node = graph->nodes.Get(i);
if (!config_.display_deleted_node && node.deleted()) continue; if (!config_.display_deleted_node && node.deleted()) continue;
for (auto &in : node.inlinks) { for (auto &out : node.outlinks) {
if (!config_.display_deleted_node && in->deleted()) continue; if (!config_.display_deleted_node && out->deleted()) continue;
dot.AddEdge(in->repr(), node.repr(), {}); dot.AddEdge(node.repr(), out->repr(), {});
} }
} }
return dot.Build(); return dot.Build();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册