diff --git a/paddle/pten/api/lib/utils/tensor_utils.cc b/paddle/pten/api/lib/utils/tensor_utils.cc index edd5cde9386307cd0e33b3a9fb0c5a1d7110cf81..f304268bedf45d38315c5d47bd5278341736559a 100644 --- a/paddle/pten/api/lib/utils/tensor_utils.cc +++ b/paddle/pten/api/lib/utils/tensor_utils.cc @@ -325,9 +325,7 @@ void SharesStorageBase(pten::DenseTensor* src, paddle::framework::Tensor* dst) { platform::errors::InvalidArgument( "The destination Tensor is nullptr when move allocation.")); dst->Resize(src->dims()); - auto* storage = static_cast( - pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(src)); - dst->ResetHolderWithType(storage->GetAllocation(), + dst->ResetHolderWithType(src->Holder(), pten::TransToProtoVarType(src->dtype())); dst->set_offset(src->meta().offset); } @@ -345,19 +343,7 @@ void ReMakePtenDenseTensorBase(const paddle::framework::Tensor& src, meta->dtype = pten::TransToPtenDataType(src.type()); meta->layout = src.layout(); meta->offset = src.offset(); - - auto* shared_storage = static_cast( - pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(dst)); - PADDLE_ENFORCE_NOT_NULL( - shared_storage, - platform::errors::NotFound( - "Target DenseTensor's shared storage is nullptr.")); - - PADDLE_ENFORCE_EQ(src.IsInitialized(), - true, - paddle::platform::errors::InvalidArgument( - "Source Tensor is not initialized.")); - shared_storage->ResetAllocation(src.Holder()); + dst->ResetHolder(src.Holder()); } void ReMakePtenDenseTensor(const paddle::framework::Tensor& src, @@ -378,19 +364,12 @@ void ReMakePtenDenseTensorByArgDefBase(const paddle::framework::Tensor& src, meta->layout = src.layout(); meta->offset = src.offset(); - auto* shared_storage = static_cast( - pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(dst)); - PADDLE_ENFORCE_NOT_NULL( - shared_storage, - platform::errors::NotFound( - "Target DenseTensor's shared storage is nullptr.")); - if (src.IsInitialized() && src.place() == pten::TransToFluidPlace(arg_def.backend)) { - shared_storage->ResetAllocation(src.Holder()); + dst->ResetHolder(src.Holder()); } else { - shared_storage->ResetAllocationPlace( - pten::TransToFluidPlace(arg_def.backend)); + // This does not affect the correctness, and will be modified immediately. + // dst->mutable_data(pten::TransToFluidPlace(arg_def.backend)); } } @@ -481,14 +460,10 @@ void MakeVariableFromPtenTensor(pten::DenseTensor* src, tensor->Resize(src->dims()); SetLoD(tensor->mutable_lod(), src->lod()); - // here dynamic_cast is slow - auto* storage = static_cast( - pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(src)); - if (!tensor->IsInitialized() || (tensor->IsInitialized() && - !IsSameAllocation(tensor->Holder(), storage->GetAllocation()))) { - tensor->ResetHolderWithType(std::move(storage->GetAllocation()), dtype); + !IsSameAllocation(tensor->Holder(), src->Holder()))) { + tensor->ResetHolderWithType(std::move(src->Holder()), dtype); } else { // Even the pten tensor and Variable have the same Alloctation (both have // the same pointer address, same size and same place) @@ -502,10 +477,8 @@ void MakeVariableFromPtenTensor(pten::DenseTensor* src, auto dtype = pten::TransToProtoVarType(src->dtype()); if (!tensor->value().IsInitialized()) { - auto storage = dynamic_cast( - pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(src)); - tensor->mutable_value()->ResetHolderWithType( - std::move(storage->GetAllocation()), dtype); + tensor->mutable_value()->ResetHolderWithType(std::move(src->Holder()), + dtype); } } else { PADDLE_THROW(platform::errors::Unimplemented( diff --git a/paddle/pten/core/compat_utils.h b/paddle/pten/core/compat_utils.h index 0bd82080ddebcd35261f4d8c09f22fa149247cf7..46e53e3997cc1930e2c41c8e6602f7171413718f 100644 --- a/paddle/pten/core/compat_utils.h +++ b/paddle/pten/core/compat_utils.h @@ -31,10 +31,6 @@ namespace pten { class CompatibleDenseTensorUtils { public: - static Storage* UnsafeGetMutableStorage(DenseTensor* tensor) { - return tensor->storage_.get(); - } - static DenseTensorMeta* GetMutableMeta(DenseTensor* tensor) { return &(tensor->meta_); } @@ -42,10 +38,7 @@ class CompatibleDenseTensorUtils { // only can deal with SharedStorage now static void ClearStorage(DenseTensor* tensor) { // use static_cast to improve performance, replace by dynamic_cast later - if (tensor->storage_ != nullptr) { - static_cast(tensor->storage_.get()) - ->Reset(); - } + tensor->MoveMemoryHolder(); } static DenseTensor Slice(const DenseTensor& tensor,