From bbf2bc2b8664d00a669a70a800b534d1976c38d7 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com> Date: Tue, 28 Feb 2023 11:21:41 +0800 Subject: [PATCH] forbid tensorrt_engine op's output is a persistable var (#50932) * forbid tensorrt_engine op's output is a persistable var --- paddle/fluid/framework/ir/subgraph_detector.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/ir/subgraph_detector.cc b/paddle/fluid/framework/ir/subgraph_detector.cc index 987194752ad..cce72ae161b 100644 --- a/paddle/fluid/framework/ir/subgraph_detector.cc +++ b/paddle/fluid/framework/ir/subgraph_detector.cc @@ -45,7 +45,9 @@ ExtractInputAndOutputOfSubGraph(std::vector &graph) { // NOLINT } } for (auto *out : node->outputs) { - if (!nodes.count(out) && out->IsVar()) { + // we forbid out is a persistable var, for case when weight is shared + // between within and outside this tensorrt_engine op. + if (!nodes.count(out) && out->IsVar() && !out->Var()->Persistable()) { outputs.insert(out); } } @@ -416,7 +418,7 @@ void DetachDeletedNodes(framework::ir::Graph *graph) { void SubGraphFuser::ReplaceNodesWithSubGraphs() { auto subgraphs = SubgraphDetector(graph_, node_inside_subgraph_teller_)(); for (auto &subgraph : subgraphs) { - if (subgraph.size() <= (size_t)min_subgraph_size_) continue; + if (subgraph.size() <= static_cast(min_subgraph_size_)) continue; std::unordered_set subgraph_uniq(subgraph.begin(), subgraph.end()); // replace this sub-graph with the first node. Two steps: 1. Create a Block // Node that contains this subgraph 2. Mark the nodes inside the sub-graph -- GitLab