未验证 提交 3371f98b 编写于 作者: 周周周 提交者: GitHub

[Paddle Inference]add info in memory_optimize_pass.cc (#54789)

* add info in memory_optimize_pass.cc
上级 bdaf15ad
......@@ -59,6 +59,7 @@ void MemoryOptimizePass::CollectLifeCycle(
std::unordered_map<std::string, lifecycle_t>* lifecycles,
int sort_kind) const {
int max_lifecycle = 0;
double persis_byte = 0;
for (auto* op_node : framework::ir::TopologyVarientSort(
*graph, static_cast<framework::ir::SortKind>(sort_kind))) {
if (!op_node->IsOp()) continue;
......@@ -80,7 +81,28 @@ void MemoryOptimizePass::CollectLifeCycle(
// Normal operators.
for (const Node* node : requires) {
if (!node->Var()) continue;
if (node->Var()->Persistable()) continue;
if (node->Var()->Persistable()) {
// "Getting 'tensor_desc' is not supported by the fetch type
// variable."
bool is_break = false;
for (auto op_op : node->inputs) {
if (op_op->Name() == "fetch") is_break = true;
}
if (is_break) continue;
auto in_shape = node->Var()->GetShape();
for (auto i : in_shape) {
CHECK_GT(i, 0);
}
auto var_bytes = std::accumulate(in_shape.begin(),
in_shape.end(),
(int64_t)1,
std::multiplies<int64_t>());
persis_byte +=
paddle::framework::SizeOfType(node->Var()->GetDataType()) *
var_bytes;
continue;
}
std::string var = node->Name();
if (!lifecycles->count(var)) {
(*lifecycles)[var] = std::make_pair(max_lifecycle, max_lifecycle);
......@@ -93,6 +115,8 @@ void MemoryOptimizePass::CollectLifeCycle(
++max_lifecycle;
}
LOG(INFO) << "The persistable params in main graph are : "
<< (persis_byte / (1 << 20)) << "MB";
}
void MemoryOptimizePass::CollectVarMemorySize(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册