未验证 提交 bbf2bc2b 编写于 作者: Z zhoutianzi666 提交者: GitHub

forbid tensorrt_engine op's output is a persistable var (#50932)

* forbid tensorrt_engine op's output is a persistable var
上级 8220771b
...@@ -45,7 +45,9 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT ...@@ -45,7 +45,9 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
} }
} }
for (auto *out : node->outputs) { for (auto *out : node->outputs) {
if (!nodes.count(out) && out->IsVar()) { // we forbid out is a persistable var, for case when weight is shared
// between within and outside this tensorrt_engine op.
if (!nodes.count(out) && out->IsVar() && !out->Var()->Persistable()) {
outputs.insert(out); outputs.insert(out);
} }
} }
...@@ -416,7 +418,7 @@ void DetachDeletedNodes(framework::ir::Graph *graph) { ...@@ -416,7 +418,7 @@ void DetachDeletedNodes(framework::ir::Graph *graph) {
void SubGraphFuser::ReplaceNodesWithSubGraphs() { void SubGraphFuser::ReplaceNodesWithSubGraphs() {
auto subgraphs = SubgraphDetector(graph_, node_inside_subgraph_teller_)(); auto subgraphs = SubgraphDetector(graph_, node_inside_subgraph_teller_)();
for (auto &subgraph : subgraphs) { for (auto &subgraph : subgraphs) {
if (subgraph.size() <= (size_t)min_subgraph_size_) continue; if (subgraph.size() <= static_cast<size_t>(min_subgraph_size_)) continue;
std::unordered_set<Node *> subgraph_uniq(subgraph.begin(), subgraph.end()); std::unordered_set<Node *> subgraph_uniq(subgraph.begin(), subgraph.end());
// replace this sub-graph with the first node. Two steps: 1. Create a Block // replace this sub-graph with the first node. Two steps: 1. Create a Block
// Node that contains this subgraph 2. Mark the nodes inside the sub-graph // Node that contains this subgraph 2. Mark the nodes inside the sub-graph
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册