diff --git a/paddle/pten/api/include/tensor.h b/paddle/pten/api/include/tensor.h index 935c7d8e325d0d0cc0dbdf5ca321420001999cda..4a8a593561ad448561d5f06d6ccb49059fef942b 100644 --- a/paddle/pten/api/include/tensor.h +++ b/paddle/pten/api/include/tensor.h @@ -204,6 +204,14 @@ class PADDLE_API Tensor final { */ DataLayout layout() const; + /** + * @brief Determine whether tensor is DenseTensor + * + * @return true + * @return false + */ + bool is_dense_tensor() const; + /* Part 3: Device and Backend methods */ /** diff --git a/paddle/pten/api/lib/tensor.cc b/paddle/pten/api/lib/tensor.cc index 6ecc46ca8b53f140d1e8b286f9f56e6596ca4412..74451d00e554614a5b584541b157c4efc95ce79d 100644 --- a/paddle/pten/api/lib/tensor.cc +++ b/paddle/pten/api/lib/tensor.cc @@ -58,15 +58,6 @@ limitations under the License. */ namespace paddle { namespace experimental { -namespace detail { - -inline bool IsDenseTensor( - const std::shared_ptr &tensor_impl) { - return tensor_impl->type_info().name() == "DenseTensor"; -} - -} // namespace detail - // declare cast api Tensor cast(const Tensor &x, DataType out_dtype); @@ -118,7 +109,7 @@ void Tensor::reshape(const std::vector &shape) { "reason: `reshape` means changing the tensor shape without " "touching underlying data, this requires the total size of " "the tensor to remain constant."; - if (detail::IsDenseTensor(impl_)) { + if (is_dense_tensor()) { std::dynamic_pointer_cast(impl_)->set_meta( pten::DenseTensorMeta(dtype(), framework::make_ddim(shape))); } else { @@ -133,6 +124,10 @@ DataType Tensor::type() const { return impl_->dtype(); } DataLayout Tensor::layout() const { return impl_->layout(); } +bool Tensor::is_dense_tensor() const { + return pten::DenseTensor::classof(impl_.get()); +} + /* Part 3: Device and Backend methods */ PlaceType Tensor::place() const { @@ -153,7 +148,7 @@ bool Tensor::is_cuda() const { template T *Tensor::mutable_data() { - if (detail::IsDenseTensor(impl_)) { + if (is_dense_tensor()) { return std::dynamic_pointer_cast(impl_) ->mutable_data(); } @@ -209,7 +204,7 @@ Tensor::mutable_data(const PlaceType &place); template const T *Tensor::data() const { - if (detail::IsDenseTensor(impl_)) { + if (is_dense_tensor()) { return std::dynamic_pointer_cast(impl_)->data(); } return nullptr; @@ -259,7 +254,7 @@ Tensor::data(); // TODO(chenweihang): replace slice impl by API Tensor Tensor::slice(const int64_t begin_idx, const int64_t end_idx) const { - if (detail::IsDenseTensor(impl_)) { + if (is_dense_tensor()) { return Tensor(std::make_shared( std::move(pten::CompatibleDenseTensorUtils::Slice( std::dynamic_pointer_cast(impl_).get(), diff --git a/paddle/pten/tests/api/test_pten_tensor.cc b/paddle/pten/tests/api/test_pten_tensor.cc index bffc1b8d89fe023957813e158a81e352b6549b78..a28f7ca2ca2e685ffebfd9ceb9906e245fb80fce 100644 --- a/paddle/pten/tests/api/test_pten_tensor.cc +++ b/paddle/pten/tests/api/test_pten_tensor.cc @@ -205,6 +205,11 @@ void TestInitilized() { } } +void TestJudgeTensorType() { + experimental::Tensor test_tensor(paddle::PlaceType::kCPU, {1, 1}); + CHECK(test_tensor.is_dense_tensor() == true); +} + TEST(PtenTensor, All) { VLOG(2) << "TestCopy"; GroupTestCopy(); @@ -220,6 +225,8 @@ TEST(PtenTensor, All) { GroupTestCast(); VLOG(2) << "TestInitilized"; TestInitilized(); + VLOG(2) << "TestJudgeTensorType"; + TestJudgeTensorType(); } } // namespace tests