diff --git a/paddle/phi/core/dense_tensor.cc b/paddle/phi/core/dense_tensor.cc index 844182ec3fc82ee3db46b866c13f57d88b90b48c..3fbf3560aff9519b29e2206634ca1a3fa5d2d1ba 100644 --- a/paddle/phi/core/dense_tensor.cc +++ b/paddle/phi/core/dense_tensor.cc @@ -263,6 +263,10 @@ template const NPUStorageProperties& DenseTensor::storage_properties() const; template const OneDNNStorageProperties& DenseTensor::storage_properties() const; #endif +bool DenseTensor::storage_properties_initialized() const { + return storage_properties_ != nullptr; +} + void DenseTensor::set_storage_properties( std::unique_ptr&& storage_properties) { storage_properties_ = std::move(storage_properties); diff --git a/paddle/phi/core/dense_tensor.h b/paddle/phi/core/dense_tensor.h index 0785af7b7037ca5c2769887c5119871ae5ab2be0..c5f38b762167f89855ee81562a975546419e97c0 100644 --- a/paddle/phi/core/dense_tensor.h +++ b/paddle/phi/core/dense_tensor.h @@ -164,6 +164,10 @@ class DenseTensor : public TensorBase, void* data(); + /// \brief Get whether the storage_properties is inited. + /// \return The init status of storage_properties. + bool storage_properties_initialized() const; + /// \brief Returns the storage_properties of the tensor. /// \return The storage_properties of the tensor. template diff --git a/paddle/phi/core/storage_properties.h b/paddle/phi/core/storage_properties.h index 908abd8d9d35d040734885bdb5dfd7f9a773ca1c..ff419387786302defbaca2e12cefc72aea8912fe 100644 --- a/paddle/phi/core/storage_properties.h +++ b/paddle/phi/core/storage_properties.h @@ -16,6 +16,7 @@ limitations under the License. */ #include +#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/utils/type_registry.h" #ifdef PADDLE_WITH_MKLDNN @@ -42,8 +43,8 @@ struct NPUStorageProperties virtual ~NPUStorageProperties() = default; static const char* name() { return "NPUStorageProperties"; } - int64_t storage_format; - int64_t storage_layout; + int64_t storage_format{-1}; + DDim storage_dims; }; // Add OneDNNStorageProperties firstly for unittest covergae @@ -76,8 +77,8 @@ static std::unique_ptr CopyStorageProperties( auto result = std::make_unique(); result->storage_format = static_cast(sp.get())->storage_format; - result->storage_layout = - static_cast(sp.get())->storage_layout; + result->storage_dims = + static_cast(sp.get())->storage_dims; return result; #ifdef PADDLE_WITH_MKLDNN } else if (OneDNNStorageProperties::classof(sp.get())) { diff --git a/paddle/phi/tests/core/test_dense_tensor.cc b/paddle/phi/tests/core/test_dense_tensor.cc index b997d8f1e76aebba2d0cafcac05a9f737d8815fe..6f08eeaefffb107bcffc9bb20571920225876e1b 100644 --- a/paddle/phi/tests/core/test_dense_tensor.cc +++ b/paddle/phi/tests/core/test_dense_tensor.cc @@ -154,13 +154,15 @@ TEST(dense_tensor, storage_properties) { EXPECT_TRUE(caught_exception); // test custom device storage properties + EXPECT_FALSE(tensor.storage_properties_initialized()); auto npu_properties = std::make_unique(); - npu_properties->storage_format = 1; - npu_properties->storage_layout = 2; + npu_properties->storage_format = 3; + npu_properties->storage_dims = {1, 1, 1, 1, 16}; tensor.set_storage_properties(std::move(npu_properties)); + EXPECT_TRUE(tensor.storage_properties_initialized()); auto get_npu_properties = tensor.storage_properties(); - CHECK_EQ(get_npu_properties.storage_format, 1); - CHECK_EQ(get_npu_properties.storage_layout, 2); + CHECK_EQ(get_npu_properties.storage_format, 3); + CHECK_EQ(get_npu_properties.storage_dims.size(), 5); // test error type storage properties #ifdef PADDLE_WITH_MKLDNN @@ -177,8 +179,8 @@ TEST(dense_tensor, storage_properties) { auto cp_tensor = tensor; auto get_cp_npu_properties = cp_tensor.storage_properties(); - CHECK_EQ(get_cp_npu_properties.storage_format, 1); - CHECK_EQ(get_cp_npu_properties.storage_layout, 2); + CHECK_EQ(get_cp_npu_properties.storage_format, 3); + CHECK_EQ(get_cp_npu_properties.storage_dims.size(), 5); } } // namespace tests