未验证 提交 1568d64f 编写于 作者: Q Qi Li 提交者: GitHub

[NPU] add more attrs into npu storiages, test=develop (#47645)

* [NPU] add more attrs into npu storiages, test=develop

* rename to storage_properties_initialized
上级 40cd5271
...@@ -266,6 +266,10 @@ template const NPUStorageProperties& DenseTensor::storage_properties() const; ...@@ -266,6 +266,10 @@ template const NPUStorageProperties& DenseTensor::storage_properties() const;
template const OneDNNStorageProperties& DenseTensor::storage_properties() const; template const OneDNNStorageProperties& DenseTensor::storage_properties() const;
#endif #endif
bool DenseTensor::storage_properties_initialized() const {
return storage_properties_ != nullptr;
}
void DenseTensor::set_storage_properties( void DenseTensor::set_storage_properties(
std::unique_ptr<StorageProperties>&& storage_properties) { std::unique_ptr<StorageProperties>&& storage_properties) {
storage_properties_ = std::move(storage_properties); storage_properties_ = std::move(storage_properties);
......
...@@ -164,6 +164,10 @@ class DenseTensor : public TensorBase, ...@@ -164,6 +164,10 @@ class DenseTensor : public TensorBase,
void* data(); void* data();
/// \brief Get whether the storage_properties is inited.
/// \return The init status of storage_properties.
bool storage_properties_initialized() const;
/// \brief Returns the storage_properties of the tensor. /// \brief Returns the storage_properties of the tensor.
/// \return The storage_properties of the tensor. /// \return The storage_properties of the tensor.
template <typename DeviceT> template <typename DeviceT>
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <memory> #include <memory>
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/utils/type_registry.h" #include "paddle/phi/core/utils/type_registry.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
...@@ -42,8 +42,10 @@ struct NPUStorageProperties ...@@ -42,8 +42,10 @@ struct NPUStorageProperties
virtual ~NPUStorageProperties() = default; virtual ~NPUStorageProperties() = default;
static const char* name() { return "NPUStorageProperties"; } static const char* name() { return "NPUStorageProperties"; }
int64_t storage_format; int64_t origin_format{-1};
int64_t storage_layout; int64_t storage_format{-1};
DDim origin_dims;
DDim storage_dims;
}; };
// Add OneDNNStorageProperties firstly for unittest covergae // Add OneDNNStorageProperties firstly for unittest covergae
...@@ -74,10 +76,14 @@ static std::unique_ptr<StorageProperties> CopyStorageProperties( ...@@ -74,10 +76,14 @@ static std::unique_ptr<StorageProperties> CopyStorageProperties(
if (sp) { if (sp) {
if (NPUStorageProperties::classof(sp.get())) { if (NPUStorageProperties::classof(sp.get())) {
auto result = std::make_unique<NPUStorageProperties>(); auto result = std::make_unique<NPUStorageProperties>();
result->origin_format =
static_cast<NPUStorageProperties*>(sp.get())->origin_format;
result->storage_format = result->storage_format =
static_cast<NPUStorageProperties*>(sp.get())->storage_format; static_cast<NPUStorageProperties*>(sp.get())->storage_format;
result->storage_layout = result->origin_dims =
static_cast<NPUStorageProperties*>(sp.get())->storage_layout; static_cast<NPUStorageProperties*>(sp.get())->origin_dims;
result->storage_dims =
static_cast<NPUStorageProperties*>(sp.get())->storage_dims;
return result; return result;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
} else if (OneDNNStorageProperties::classof(sp.get())) { } else if (OneDNNStorageProperties::classof(sp.get())) {
......
...@@ -154,13 +154,19 @@ TEST(dense_tensor, storage_properties) { ...@@ -154,13 +154,19 @@ TEST(dense_tensor, storage_properties) {
EXPECT_TRUE(caught_exception); EXPECT_TRUE(caught_exception);
// test custom device storage properties // test custom device storage properties
EXPECT_FALSE(tensor.storage_properties_initialized());
auto npu_properties = std::make_unique<NPUStorageProperties>(); auto npu_properties = std::make_unique<NPUStorageProperties>();
npu_properties->storage_format = 1; npu_properties->origin_format = 0;
npu_properties->storage_layout = 2; npu_properties->storage_format = 3;
npu_properties->origin_dims = {1, 8, 5, 5};
npu_properties->storage_dims = {1, 1, 5, 5, 16};
tensor.set_storage_properties(std::move(npu_properties)); tensor.set_storage_properties(std::move(npu_properties));
EXPECT_TRUE(tensor.storage_properties_initialized());
auto get_npu_properties = tensor.storage_properties<NPUStorageProperties>(); auto get_npu_properties = tensor.storage_properties<NPUStorageProperties>();
CHECK_EQ(get_npu_properties.storage_format, 1); CHECK_EQ(get_npu_properties.origin_format, 0);
CHECK_EQ(get_npu_properties.storage_layout, 2); CHECK_EQ(get_npu_properties.storage_format, 3);
CHECK_EQ(get_npu_properties.origin_dims.size(), 4);
CHECK_EQ(get_npu_properties.storage_dims.size(), 5);
// test error type storage properties // test error type storage properties
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
...@@ -177,8 +183,10 @@ TEST(dense_tensor, storage_properties) { ...@@ -177,8 +183,10 @@ TEST(dense_tensor, storage_properties) {
auto cp_tensor = tensor; auto cp_tensor = tensor;
auto get_cp_npu_properties = auto get_cp_npu_properties =
cp_tensor.storage_properties<NPUStorageProperties>(); cp_tensor.storage_properties<NPUStorageProperties>();
CHECK_EQ(get_cp_npu_properties.storage_format, 1); CHECK_EQ(get_cp_npu_properties.origin_format, 0);
CHECK_EQ(get_cp_npu_properties.storage_layout, 2); CHECK_EQ(get_cp_npu_properties.storage_format, 3);
CHECK_EQ(get_cp_npu_properties.origin_dims.size(), 4);
CHECK_EQ(get_cp_npu_properties.storage_dims.size(), 5);
} }
} // namespace tests } // namespace tests
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册