diff --git a/paddle/fluid/eager/tensor_wrapper.h b/paddle/fluid/eager/tensor_wrapper.h index 405105771b9b1870dc0747444c113fbc9d4cdb36..3ee1603a53ab4682d3eabf7f5ceac84e243fe589 100644 --- a/paddle/fluid/eager/tensor_wrapper.h +++ b/paddle/fluid/eager/tensor_wrapper.h @@ -55,6 +55,20 @@ class TensorWrapper { if (full_reserved_) { VLOG(6) << "Fully reserved tensor: " << tensor.name(); intermidiate_tensor_ = tensor; + if (no_need_buffer_) { + if (phi::DenseTensor::classof(tensor.impl().get())) { + // Only Copy Meta + phi::DenseTensor* dense_tensor = + static_cast(tensor.impl().get()); + auto tw_dense_tensor = + std::make_shared(*dense_tensor); + tw_dense_tensor->clear(); + intermidiate_tensor_.set_impl(tw_dense_tensor); + } else { + PADDLE_THROW(paddle::platform::errors::Fatal( + "Unrecognized tensor type for no_need_buffer feature")); + } + } return; } diff --git a/paddle/phi/api/lib/tensor.cc b/paddle/phi/api/lib/tensor.cc index be0a937c91e4f1a99a5c12497cfda6e8e98e453b..a7b89d7a4dca9348278803a47e1cf3665bb2a53d 100644 --- a/paddle/phi/api/lib/tensor.cc +++ b/paddle/phi/api/lib/tensor.cc @@ -341,7 +341,11 @@ bool Tensor::is_initialized() const { return defined() && impl_->initialized(); } -void Tensor::reset() { impl_.reset(); } +void Tensor::reset() { + impl_.reset(); + autograd_meta_.reset(); + name_ = ""; +} /* Part 6: Operator overloading */