未验证 提交 1568d64f 编写于 作者: Q Qi Li 提交者: GitHub

[NPU] add more attrs into npu storiages, test=develop (#47645)

* [NPU] add more attrs into npu storiages, test=develop

* rename to storage_properties_initialized
上级 40cd5271
......@@ -266,6 +266,10 @@ template const NPUStorageProperties& DenseTensor::storage_properties() const;
template const OneDNNStorageProperties& DenseTensor::storage_properties() const;
#endif
bool DenseTensor::storage_properties_initialized() const {
return storage_properties_ != nullptr;
}
void DenseTensor::set_storage_properties(
std::unique_ptr<StorageProperties>&& storage_properties) {
storage_properties_ = std::move(storage_properties);
......
......@@ -164,6 +164,10 @@ class DenseTensor : public TensorBase,
void* data();
/// \brief Get whether the storage_properties is inited.
/// \return The init status of storage_properties.
bool storage_properties_initialized() const;
/// \brief Returns the storage_properties of the tensor.
/// \return The storage_properties of the tensor.
template <typename DeviceT>
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#include <memory>
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/utils/type_registry.h"
#ifdef PADDLE_WITH_MKLDNN
......@@ -42,8 +42,10 @@ struct NPUStorageProperties
virtual ~NPUStorageProperties() = default;
static const char* name() { return "NPUStorageProperties"; }
int64_t storage_format;
int64_t storage_layout;
int64_t origin_format{-1};
int64_t storage_format{-1};
DDim origin_dims;
DDim storage_dims;
};
// Add OneDNNStorageProperties firstly for unittest covergae
......@@ -74,10 +76,14 @@ static std::unique_ptr<StorageProperties> CopyStorageProperties(
if (sp) {
if (NPUStorageProperties::classof(sp.get())) {
auto result = std::make_unique<NPUStorageProperties>();
result->origin_format =
static_cast<NPUStorageProperties*>(sp.get())->origin_format;
result->storage_format =
static_cast<NPUStorageProperties*>(sp.get())->storage_format;
result->storage_layout =
static_cast<NPUStorageProperties*>(sp.get())->storage_layout;
result->origin_dims =
static_cast<NPUStorageProperties*>(sp.get())->origin_dims;
result->storage_dims =
static_cast<NPUStorageProperties*>(sp.get())->storage_dims;
return result;
#ifdef PADDLE_WITH_MKLDNN
} else if (OneDNNStorageProperties::classof(sp.get())) {
......
......@@ -154,13 +154,19 @@ TEST(dense_tensor, storage_properties) {
EXPECT_TRUE(caught_exception);
// test custom device storage properties
EXPECT_FALSE(tensor.storage_properties_initialized());
auto npu_properties = std::make_unique<NPUStorageProperties>();
npu_properties->storage_format = 1;
npu_properties->storage_layout = 2;
npu_properties->origin_format = 0;
npu_properties->storage_format = 3;
npu_properties->origin_dims = {1, 8, 5, 5};
npu_properties->storage_dims = {1, 1, 5, 5, 16};
tensor.set_storage_properties(std::move(npu_properties));
EXPECT_TRUE(tensor.storage_properties_initialized());
auto get_npu_properties = tensor.storage_properties<NPUStorageProperties>();
CHECK_EQ(get_npu_properties.storage_format, 1);
CHECK_EQ(get_npu_properties.storage_layout, 2);
CHECK_EQ(get_npu_properties.origin_format, 0);
CHECK_EQ(get_npu_properties.storage_format, 3);
CHECK_EQ(get_npu_properties.origin_dims.size(), 4);
CHECK_EQ(get_npu_properties.storage_dims.size(), 5);
// test error type storage properties
#ifdef PADDLE_WITH_MKLDNN
......@@ -177,8 +183,10 @@ TEST(dense_tensor, storage_properties) {
auto cp_tensor = tensor;
auto get_cp_npu_properties =
cp_tensor.storage_properties<NPUStorageProperties>();
CHECK_EQ(get_cp_npu_properties.storage_format, 1);
CHECK_EQ(get_cp_npu_properties.storage_layout, 2);
CHECK_EQ(get_cp_npu_properties.origin_format, 0);
CHECK_EQ(get_cp_npu_properties.storage_format, 3);
CHECK_EQ(get_cp_npu_properties.origin_dims.size(), 4);
CHECK_EQ(get_cp_npu_properties.storage_dims.size(), 5);
}
} // namespace tests
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册