未验证 提交 2175d199 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix memory_reuse_pass memory_size calculation error, test=develop (#19020)

上级 de975be1
...@@ -100,8 +100,10 @@ VarDesc *MemoryReusePass::GetVarDesc(const details::VarHandle &var) const { ...@@ -100,8 +100,10 @@ VarDesc *MemoryReusePass::GetVarDesc(const details::VarHandle &var) const {
int64_t MemoryReusePass::GetMemorySize(const details::VarHandle &var) const { int64_t MemoryReusePass::GetMemorySize(const details::VarHandle &var) const {
auto *var_desc = GetVarDesc(var); auto *var_desc = GetVarDesc(var);
auto shapes = var_desc->GetShape(); auto shapes = var_desc->GetShape();
auto sizeof_dtype = static_cast<int64_t>(SizeOfType(var_desc->GetDataType()));
return std::accumulate(shapes.begin(), shapes.end(), static_cast<int64_t>(1), return std::accumulate(shapes.begin(), shapes.end(), static_cast<int64_t>(1),
std::multiplies<int64_t>()); std::multiplies<int64_t>()) *
sizeof_dtype;
} }
void MemoryReusePass::CollectShareTensorBufferOpHandles() const { void MemoryReusePass::CollectShareTensorBufferOpHandles() const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册