diff --git a/paddle/phi/core/dense_tensor.cc b/paddle/phi/core/dense_tensor.cc index 8fa1679476abdb805a8dc765b9119906b2d9d33d..8a2d0e8a46bd460b104d08640d637ca2ab25dba2 100644 --- a/paddle/phi/core/dense_tensor.cc +++ b/paddle/phi/core/dense_tensor.cc @@ -266,10 +266,6 @@ template const NPUStorageProperties& DenseTensor::storage_properties() const; template const OneDNNStorageProperties& DenseTensor::storage_properties() const; #endif -bool DenseTensor::storage_properties_initialized() const { - return storage_properties_ != nullptr; -} - void DenseTensor::set_storage_properties( std::unique_ptr&& storage_properties) { storage_properties_ = std::move(storage_properties); diff --git a/paddle/phi/core/dense_tensor.h b/paddle/phi/core/dense_tensor.h index a07ab2b06fe4b1f6331c3c6345b2633c85219e32..e0d620ac3a53e05704ec4995370e009bb78b4644 100644 --- a/paddle/phi/core/dense_tensor.h +++ b/paddle/phi/core/dense_tensor.h @@ -164,10 +164,6 @@ class DenseTensor : public TensorBase, void* data(); - /// \brief Get whether the storage_properties is inited. - /// \return The init status of storage_properties. - bool storage_properties_initialized() const; - /// \brief Returns the storage_properties of the tensor. /// \return The storage_properties of the tensor. template diff --git a/paddle/phi/core/storage_properties.h b/paddle/phi/core/storage_properties.h index 82c2a55bcd2904541efd6160aa49e7ea2a187f74..908abd8d9d35d040734885bdb5dfd7f9a773ca1c 100644 --- a/paddle/phi/core/storage_properties.h +++ b/paddle/phi/core/storage_properties.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #include -#include "paddle/phi/core/ddim.h" + #include "paddle/phi/core/utils/type_registry.h" #ifdef PADDLE_WITH_MKLDNN @@ -42,10 +42,8 @@ struct NPUStorageProperties virtual ~NPUStorageProperties() = default; static const char* name() { return "NPUStorageProperties"; } - int64_t origin_format{-1}; - int64_t storage_format{-1}; - DDim origin_dims; - DDim storage_dims; + int64_t storage_format; + int64_t storage_layout; }; // Add OneDNNStorageProperties firstly for unittest covergae @@ -76,14 +74,10 @@ static std::unique_ptr CopyStorageProperties( if (sp) { if (NPUStorageProperties::classof(sp.get())) { auto result = std::make_unique(); - result->origin_format = - static_cast(sp.get())->origin_format; result->storage_format = static_cast(sp.get())->storage_format; - result->origin_dims = - static_cast(sp.get())->origin_dims; - result->storage_dims = - static_cast(sp.get())->storage_dims; + result->storage_layout = + static_cast(sp.get())->storage_layout; return result; #ifdef PADDLE_WITH_MKLDNN } else if (OneDNNStorageProperties::classof(sp.get())) { diff --git a/paddle/phi/tests/core/test_dense_tensor.cc b/paddle/phi/tests/core/test_dense_tensor.cc index cb66fe7259a1d7b2a9570805e2eea96dab01f119..b997d8f1e76aebba2d0cafcac05a9f737d8815fe 100644 --- a/paddle/phi/tests/core/test_dense_tensor.cc +++ b/paddle/phi/tests/core/test_dense_tensor.cc @@ -154,19 +154,13 @@ TEST(dense_tensor, storage_properties) { EXPECT_TRUE(caught_exception); // test custom device storage properties - EXPECT_FALSE(tensor.storage_properties_initialized()); auto npu_properties = std::make_unique(); - npu_properties->origin_format = 0; - npu_properties->storage_format = 3; - npu_properties->origin_dims = {1, 8, 5, 5}; - npu_properties->storage_dims = {1, 1, 5, 5, 16}; + npu_properties->storage_format = 1; + npu_properties->storage_layout = 2; tensor.set_storage_properties(std::move(npu_properties)); - EXPECT_TRUE(tensor.storage_properties_initialized()); auto get_npu_properties = tensor.storage_properties(); - CHECK_EQ(get_npu_properties.origin_format, 0); - CHECK_EQ(get_npu_properties.storage_format, 3); - CHECK_EQ(get_npu_properties.origin_dims.size(), 4); - CHECK_EQ(get_npu_properties.storage_dims.size(), 5); + CHECK_EQ(get_npu_properties.storage_format, 1); + CHECK_EQ(get_npu_properties.storage_layout, 2); // test error type storage properties #ifdef PADDLE_WITH_MKLDNN @@ -183,10 +177,8 @@ TEST(dense_tensor, storage_properties) { auto cp_tensor = tensor; auto get_cp_npu_properties = cp_tensor.storage_properties(); - CHECK_EQ(get_cp_npu_properties.origin_format, 0); - CHECK_EQ(get_cp_npu_properties.storage_format, 3); - CHECK_EQ(get_cp_npu_properties.origin_dims.size(), 4); - CHECK_EQ(get_cp_npu_properties.storage_dims.size(), 5); + CHECK_EQ(get_cp_npu_properties.storage_format, 1); + CHECK_EQ(get_cp_npu_properties.storage_layout, 2); } } // namespace tests