diff --git a/src/framework/tensor.h b/src/framework/tensor.h index 8e30b4dc4692ed3d71d9476b8a78a6d250bb2700..4fb06c654983b9e1b8441b074d5a30220fb960a2 100644 --- a/src/framework/tensor.h +++ b/src/framework/tensor.h @@ -67,6 +67,15 @@ class Tensor : public TensorBase { /*! The internal of two tensors share the same memory block. */ inline Tensor &ShareDataWith(const Tensor &src) { + src.check_memory_size(); + if (holder_.get() != src.holder_.get()) { + *this = src; + } + return *this; + } + + /*! The internal of two tensors share the same memory block. */ + inline Tensor &ShareHolderWith(const Tensor &src) { src.check_memory_size(); if (holder_.get() != src.holder_.get()) { holder_ = src.holder_; diff --git a/src/pass/memory_optimize.cpp b/src/pass/memory_optimize.cpp index 6d6475d7a1b72302e1febdec297cf4d42721e895..cc754491fae9ba9a1604a4941a67015314bb2b13 100644 --- a/src/pass/memory_optimize.cpp +++ b/src/pass/memory_optimize.cpp @@ -131,7 +131,7 @@ void MemoryOptPass::operator()(const framework::ProgramDesc *program, DLOG << node->name; auto *var = scope->Var(node->name); auto *tensor = var->template GetMutable(); - tensor->ShareDataWith(*reuse_tensor); + tensor->ShareHolderWith(*reuse_tensor); } } }