From afddcb97184ace40d863df8d90ae4429c51fa498 Mon Sep 17 00:00:00 2001
From: wenbin
Date: Tue, 6 Jul 2021 18:38:45 +0800
Subject: [PATCH] MemoryOptimizePass enhancement (#33933)

* modify logic

* test=allcase

* test=document_fix
---
 .../analysis/passes/memory_optimize_pass.cc    | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
index fdfd2c60af..7153163872 100644
--- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
@@ -123,12 +123,27 @@ void MemoryOptimizePass::CollectVarMemorySize(
     }
     return true;
   };
+
+  // MemoryOptimizePass assumes the input model is a directed acyclic graph,
+  // although that is not always the case, so a black list is the best
+  // compromise between performance and the underlying principle.
+  std::unordered_set<std::string> black_list;
+  for (auto* node : graph_->Nodes()) {
+    if (node->IsVar() &&
+        node->Var()->GetType() ==
+            framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) {
+      if (!valid_var(node)) {
+        black_list.emplace(node->Var()->Name());
+      }
+    }
+  }
+
   // Collect tensors from graph.
   for (auto* node : graph_->Nodes()) {
     if (node->IsVar() &&
         node->Var()->GetType() ==
            framework::proto::VarType::Type::VarType_Type_LOD_TENSOR &&
-        valid_var(node)) {
+        !black_list.count(node->Var()->Name())) {
       // Parameters will not be reused.
       if (node->Var()->Persistable()) continue;
       auto shape = node->Var()->GetShape();
-- 
GitLab
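
For readers outside the Paddle codebase, below is a minimal standalone sketch
(plain C++, not Paddle code) of the two-pass black-list pattern this patch
introduces. Node, is_lod_tensor, and valid are hypothetical stand-ins for
Paddle's graph node type, the VarType_Type_LOD_TENSOR check, and the
valid_var() lambda. The point of the pattern: the same variable name can back
several nodes when the graph is not a DAG, so one invalid occurrence must
disqualify that name everywhere, which the old per-node valid_var(node) check
could not guarantee.

    // sketch.cc -- illustrative only; all names are hypothetical stand-ins.
    #include <iostream>
    #include <string>
    #include <unordered_set>
    #include <vector>

    struct Node {
      std::string name;
      bool is_lod_tensor;  // stands in for the VarType_Type_LOD_TENSOR check
      bool valid;          // stands in for valid_var(node)
    };

    int main() {
      // "y" appears twice: once valid, once not, as can happen in a
      // graph that is not a DAG.
      std::vector<Node> graph = {
          {"x", true, true}, {"y", true, true}, {"y", true, false}};

      // Pass 1: black-list every LoD tensor name that has at least one
      // invalid node.
      std::unordered_set<std::string> black_list;
      for (const auto& node : graph) {
        if (node.is_lod_tensor && !node.valid) black_list.emplace(node.name);
      }

      // Pass 2: collect only tensors whose names never hit the black list;
      // "y" is skipped even at its valid occurrence.
      for (const auto& node : graph) {
        if (node.is_lod_tensor && !black_list.count(node.name)) {
          std::cout << "collect: " << node.name << "\n";  // prints only "x"
        }
      }
      return 0;
    }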