diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.cc index 1d71661da1ba0abca64deaf3f811ce5fde202d95..2ef119e4401c2ac5cdfcd1a2c7718a05bfab449f 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.cc @@ -100,8 +100,10 @@ VarDesc *MemoryReusePass::GetVarDesc(const details::VarHandle &var) const { int64_t MemoryReusePass::GetMemorySize(const details::VarHandle &var) const { auto *var_desc = GetVarDesc(var); auto shapes = var_desc->GetShape(); + auto sizeof_dtype = static_cast(SizeOfType(var_desc->GetDataType())); return std::accumulate(shapes.begin(), shapes.end(), static_cast(1), - std::multiplies()); + std::multiplies()) * + sizeof_dtype; } void MemoryReusePass::CollectShareTensorBufferOpHandles() const {