From 3371f98b5b2218e14c0b33d9d1f4b1c1d8f6c8c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=A8=E5=91=A8=E5=91=A8?= <39978853+zhoutianzi666@users.noreply.github.com> Date: Wed, 21 Jun 2023 14:52:41 +0800 Subject: [PATCH] [Paddle Inference]add info in memory_optimize_pass.cc (#54789) * add info in memory_optimize_pass.cc --- .../analysis/passes/memory_optimize_pass.cc | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc index cb80dbffc0c..d6baea5e65c 100644 --- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc +++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc @@ -59,6 +59,7 @@ void MemoryOptimizePass::CollectLifeCycle( std::unordered_map* lifecycles, int sort_kind) const { int max_lifecycle = 0; + double persis_byte = 0; for (auto* op_node : framework::ir::TopologyVarientSort( *graph, static_cast(sort_kind))) { if (!op_node->IsOp()) continue; @@ -80,7 +81,28 @@ void MemoryOptimizePass::CollectLifeCycle( // Normal operators. for (const Node* node : requires) { if (!node->Var()) continue; - if (node->Var()->Persistable()) continue; + if (node->Var()->Persistable()) { + // "Getting 'tensor_desc' is not supported by the fetch type + // variable." + bool is_break = false; + for (auto op_op : node->inputs) { + if (op_op->Name() == "fetch") is_break = true; + } + if (is_break) continue; + + auto in_shape = node->Var()->GetShape(); + for (auto i : in_shape) { + CHECK_GT(i, 0); + } + auto var_bytes = std::accumulate(in_shape.begin(), + in_shape.end(), + (int64_t)1, + std::multiplies()); + persis_byte += + paddle::framework::SizeOfType(node->Var()->GetDataType()) * + var_bytes; + continue; + } std::string var = node->Name(); if (!lifecycles->count(var)) { (*lifecycles)[var] = std::make_pair(max_lifecycle, max_lifecycle); @@ -93,6 +115,8 @@ void MemoryOptimizePass::CollectLifeCycle( ++max_lifecycle; } + LOG(INFO) << "The persistable params in main graph are : " + << (persis_byte / (1 << 20)) << "MB"; } void MemoryOptimizePass::CollectVarMemorySize( -- GitLab