diff --git a/paddle/phi/api/lib/kernel_dispatch.cc b/paddle/phi/api/lib/kernel_dispatch.cc index 1ca6e2ce0bb9a31365a176b4ca7b595923acb19f..6d97dc7657f00616c8970e86b03f09f35eaa4a0f 100644 --- a/paddle/phi/api/lib/kernel_dispatch.cc +++ b/paddle/phi/api/lib/kernel_dispatch.cc @@ -14,18 +14,46 @@ limitations under the License. */ #include "paddle/phi/api/lib/kernel_dispatch.h" -#include "paddle/phi/api/include/context_pool.h" -#include "paddle/phi/core/compat/convert_utils.h" #ifdef _MSC_VER #include #endif +#include "paddle/phi/api/include/context_pool.h" +#include "paddle/phi/core/compat/convert_utils.h" +#include "paddle/phi/core/string_tensor_utils.h" +#include "paddle/phi/core/tensor_utils.h" + namespace paddle { namespace experimental { namespace detail { +// We need judge whether the allocation is nullptr, +// whether the allocation is initialized, wo we need GetHolder method +bool HasAllocation(const phi::TensorBase& t) { + if (phi::DenseTensor::classof(&t)) { + return phi::DenseTensorUtils::GetHolder( + static_cast(t)) != nullptr; + } else if (phi::SelectedRows::classof(&t)) { + return phi::DenseTensorUtils::GetHolder( + static_cast(t).value()) != nullptr; + } else if (phi::SparseCsrTensor::classof(&t)) { + return phi::DenseTensorUtils::GetHolder( + static_cast(t) + .non_zero_elements()) != nullptr; + } else if (phi::SparseCooTensor::classof(&t)) { + return phi::DenseTensorUtils::GetHolder( + static_cast(t) + .non_zero_elements()) != nullptr; + } else if (phi::StringTensor::classof(&t)) { + return phi::StringTensorUtils::GetHolder( + static_cast(t)) != nullptr; + } else { + return false; + } +} + BackendSet GetTensorBackendSet(const phi::TensorBase& t) { - if (t.initialized()) { + if (HasAllocation(t)) { BackendSet backend_set(phi::TransToPhiBackend(t.place())); switch (t.layout()) { case DataLayout::MKLDNN: diff --git a/paddle/phi/core/string_tensor_utils.h b/paddle/phi/core/string_tensor_utils.h index c1b0d09647d91c0529e0db952937d5585be9e9d9..777a24c9adfe15bf3dfafeda57c734b6c1c9a665 100644 --- a/paddle/phi/core/string_tensor_utils.h +++ b/paddle/phi/core/string_tensor_utils.h @@ -23,6 +23,11 @@ class StringTensorUtils { static StringTensorMeta* GetMutableMeta(StringTensor* tensor) { return &(tensor->meta_); } + + static const std::shared_ptr& GetHolder( + const StringTensor& tensor) { + return tensor.holder_; + } }; } // namespace phi diff --git a/paddle/phi/core/tensor_utils.h b/paddle/phi/core/tensor_utils.h index 676a590ecbce23a107bcc891c37ac69406854035..abf8aeff4d3ab047809bad8ba902075824cf263e 100644 --- a/paddle/phi/core/tensor_utils.h +++ b/paddle/phi/core/tensor_utils.h @@ -25,6 +25,11 @@ class DenseTensorUtils { return &(tensor->meta_); } + static const std::shared_ptr& GetHolder( + const DenseTensor& tensor) { + return tensor.holder_; + } + static DenseTensor Slice(const DenseTensor& tensor, int64_t begin_idx, int64_t end_idx) {