From ad4193fe957fe2eccbc2c9fd36b1f8395e2ecf1d Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Thu, 7 Apr 2022 13:47:13 +0800 Subject: [PATCH] fix get tensor backend set bug (#41478) --- paddle/phi/api/lib/kernel_dispatch.cc | 34 ++++++++++++++++++++++++--- paddle/phi/core/string_tensor_utils.h | 5 ++++ paddle/phi/core/tensor_utils.h | 5 ++++ 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/paddle/phi/api/lib/kernel_dispatch.cc b/paddle/phi/api/lib/kernel_dispatch.cc index 1ca6e2ce0b..6d97dc7657 100644 --- a/paddle/phi/api/lib/kernel_dispatch.cc +++ b/paddle/phi/api/lib/kernel_dispatch.cc @@ -14,18 +14,46 @@ limitations under the License. */ #include "paddle/phi/api/lib/kernel_dispatch.h" -#include "paddle/phi/api/include/context_pool.h" -#include "paddle/phi/core/compat/convert_utils.h" #ifdef _MSC_VER #include #endif +#include "paddle/phi/api/include/context_pool.h" +#include "paddle/phi/core/compat/convert_utils.h" +#include "paddle/phi/core/string_tensor_utils.h" +#include "paddle/phi/core/tensor_utils.h" + namespace paddle { namespace experimental { namespace detail { +// We need judge whether the allocation is nullptr, +// whether the allocation is initialized, wo we need GetHolder method +bool HasAllocation(const phi::TensorBase& t) { + if (phi::DenseTensor::classof(&t)) { + return phi::DenseTensorUtils::GetHolder( + static_cast(t)) != nullptr; + } else if (phi::SelectedRows::classof(&t)) { + return phi::DenseTensorUtils::GetHolder( + static_cast(t).value()) != nullptr; + } else if (phi::SparseCsrTensor::classof(&t)) { + return phi::DenseTensorUtils::GetHolder( + static_cast(t) + .non_zero_elements()) != nullptr; + } else if (phi::SparseCooTensor::classof(&t)) { + return phi::DenseTensorUtils::GetHolder( + static_cast(t) + .non_zero_elements()) != nullptr; + } else if (phi::StringTensor::classof(&t)) { + return phi::StringTensorUtils::GetHolder( + static_cast(t)) != nullptr; + } else { + return false; + } +} + BackendSet GetTensorBackendSet(const phi::TensorBase& t) { - if (t.initialized()) { + if (HasAllocation(t)) { BackendSet backend_set(phi::TransToPhiBackend(t.place())); switch (t.layout()) { case DataLayout::MKLDNN: diff --git a/paddle/phi/core/string_tensor_utils.h b/paddle/phi/core/string_tensor_utils.h index c1b0d09647..777a24c9ad 100644 --- a/paddle/phi/core/string_tensor_utils.h +++ b/paddle/phi/core/string_tensor_utils.h @@ -23,6 +23,11 @@ class StringTensorUtils { static StringTensorMeta* GetMutableMeta(StringTensor* tensor) { return &(tensor->meta_); } + + static const std::shared_ptr& GetHolder( + const StringTensor& tensor) { + return tensor.holder_; + } }; } // namespace phi diff --git a/paddle/phi/core/tensor_utils.h b/paddle/phi/core/tensor_utils.h index 676a590ecb..abf8aeff4d 100644 --- a/paddle/phi/core/tensor_utils.h +++ b/paddle/phi/core/tensor_utils.h @@ -25,6 +25,11 @@ class DenseTensorUtils { return &(tensor->meta_); } + static const std::shared_ptr& GetHolder( + const DenseTensor& tensor) { + return tensor.holder_; + } + static DenseTensor Slice(const DenseTensor& tensor, int64_t begin_idx, int64_t end_idx) { -- GitLab