diff --git a/paddle/phi/core/dense_tensor.cc b/paddle/phi/core/dense_tensor.cc index 8a2d0e8a46bd460b104d08640d637ca2ab25dba2..8fa1679476abdb805a8dc765b9119906b2d9d33d 100644 --- a/paddle/phi/core/dense_tensor.cc +++ b/paddle/phi/core/dense_tensor.cc @@ -266,6 +266,10 @@ template const NPUStorageProperties& DenseTensor::storage_properties() const; template const OneDNNStorageProperties& DenseTensor::storage_properties() const; #endif +bool DenseTensor::storage_properties_initialized() const { + return storage_properties_ != nullptr; +} + void DenseTensor::set_storage_properties( std::unique_ptr&& storage_properties) { storage_properties_ = std::move(storage_properties); diff --git a/paddle/phi/core/dense_tensor.h b/paddle/phi/core/dense_tensor.h index e0d620ac3a53e05704ec4995370e009bb78b4644..a07ab2b06fe4b1f6331c3c6345b2633c85219e32 100644 --- a/paddle/phi/core/dense_tensor.h +++ b/paddle/phi/core/dense_tensor.h @@ -164,6 +164,10 @@ class DenseTensor : public TensorBase, void* data(); + /// \brief Get whether the storage_properties is inited. + /// \return The init status of storage_properties. + bool storage_properties_initialized() const; + /// \brief Returns the storage_properties of the tensor. /// \return The storage_properties of the tensor. template diff --git a/paddle/phi/core/storage_properties.h b/paddle/phi/core/storage_properties.h index 908abd8d9d35d040734885bdb5dfd7f9a773ca1c..82c2a55bcd2904541efd6160aa49e7ea2a187f74 100644 --- a/paddle/phi/core/storage_properties.h +++ b/paddle/phi/core/storage_properties.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #include - +#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/utils/type_registry.h" #ifdef PADDLE_WITH_MKLDNN @@ -42,8 +42,10 @@ struct NPUStorageProperties virtual ~NPUStorageProperties() = default; static const char* name() { return "NPUStorageProperties"; } - int64_t storage_format; - int64_t storage_layout; + int64_t origin_format{-1}; + int64_t storage_format{-1}; + DDim origin_dims; + DDim storage_dims; }; // Add OneDNNStorageProperties firstly for unittest covergae @@ -74,10 +76,14 @@ static std::unique_ptr CopyStorageProperties( if (sp) { if (NPUStorageProperties::classof(sp.get())) { auto result = std::make_unique(); + result->origin_format = + static_cast(sp.get())->origin_format; result->storage_format = static_cast(sp.get())->storage_format; - result->storage_layout = - static_cast(sp.get())->storage_layout; + result->origin_dims = + static_cast(sp.get())->origin_dims; + result->storage_dims = + static_cast(sp.get())->storage_dims; return result; #ifdef PADDLE_WITH_MKLDNN } else if (OneDNNStorageProperties::classof(sp.get())) { diff --git a/paddle/phi/tests/core/test_dense_tensor.cc b/paddle/phi/tests/core/test_dense_tensor.cc index b997d8f1e76aebba2d0cafcac05a9f737d8815fe..cb66fe7259a1d7b2a9570805e2eea96dab01f119 100644 --- a/paddle/phi/tests/core/test_dense_tensor.cc +++ b/paddle/phi/tests/core/test_dense_tensor.cc @@ -154,13 +154,19 @@ TEST(dense_tensor, storage_properties) { EXPECT_TRUE(caught_exception); // test custom device storage properties + EXPECT_FALSE(tensor.storage_properties_initialized()); auto npu_properties = std::make_unique(); - npu_properties->storage_format = 1; - npu_properties->storage_layout = 2; + npu_properties->origin_format = 0; + npu_properties->storage_format = 3; + npu_properties->origin_dims = {1, 8, 5, 5}; + npu_properties->storage_dims = {1, 1, 5, 5, 16}; tensor.set_storage_properties(std::move(npu_properties)); + EXPECT_TRUE(tensor.storage_properties_initialized()); auto get_npu_properties = tensor.storage_properties(); - CHECK_EQ(get_npu_properties.storage_format, 1); - CHECK_EQ(get_npu_properties.storage_layout, 2); + CHECK_EQ(get_npu_properties.origin_format, 0); + CHECK_EQ(get_npu_properties.storage_format, 3); + CHECK_EQ(get_npu_properties.origin_dims.size(), 4); + CHECK_EQ(get_npu_properties.storage_dims.size(), 5); // test error type storage properties #ifdef PADDLE_WITH_MKLDNN @@ -177,8 +183,10 @@ TEST(dense_tensor, storage_properties) { auto cp_tensor = tensor; auto get_cp_npu_properties = cp_tensor.storage_properties(); - CHECK_EQ(get_cp_npu_properties.storage_format, 1); - CHECK_EQ(get_cp_npu_properties.storage_layout, 2); + CHECK_EQ(get_cp_npu_properties.origin_format, 0); + CHECK_EQ(get_cp_npu_properties.storage_format, 3); + CHECK_EQ(get_cp_npu_properties.origin_dims.size(), 4); + CHECK_EQ(get_cp_npu_properties.storage_dims.size(), 5); } } // namespace tests