From 5c10be4f440d05a46134b922f891837ddc6e77a9 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Sat, 20 May 2023 06:12:39 +0800 Subject: [PATCH] [IR] Add types and attributes to builtin and pd dialect (#53953) * add types and attributes * remove some const_cast * refine code --- paddle/fluid/dialect/pd_attribute.cc | 32 +++++ paddle/fluid/dialect/pd_attribute.h | 95 +++++++++++++ paddle/fluid/dialect/pd_attribute_storage.h | 141 ++++++++++++++++++++ paddle/fluid/dialect/pd_dialect.cc | 4 +- paddle/fluid/dialect/pd_type.h | 2 +- paddle/ir/builtin_attribute.cc | 14 ++ paddle/ir/builtin_attribute.h | 73 ++++++++-- paddle/ir/builtin_attribute_storage.h | 81 ++++++++++- paddle/ir/builtin_type.cc | 20 +++ paddle/ir/builtin_type.h | 81 ++++++++--- paddle/ir/builtin_type_storage.h | 78 +++++++++++ test/cpp/ir/type_test.cc | 10 ++ 12 files changed, 595 insertions(+), 36 deletions(-) create mode 100644 paddle/fluid/dialect/pd_attribute.cc create mode 100644 paddle/fluid/dialect/pd_attribute.h create mode 100644 paddle/fluid/dialect/pd_attribute_storage.h create mode 100644 paddle/ir/builtin_type.cc create mode 100644 paddle/ir/builtin_type_storage.h diff --git a/paddle/fluid/dialect/pd_attribute.cc b/paddle/fluid/dialect/pd_attribute.cc new file mode 100644 index 00000000000..49e6865160d --- /dev/null +++ b/paddle/fluid/dialect/pd_attribute.cc @@ -0,0 +1,32 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/dialect/pd_attribute.h" + +namespace paddle { +namespace dialect { +phi::IntArray IntArrayAttribute::data() const { return storage()->GetAsKey(); } + +phi::Scalar ScalarAttribute::data() const { return storage()->GetAsKey(); } + +phi::DataType DataTypeAttribute::data() const { return storage()->GetAsKey(); } + +phi::Place PlaceAttribute::data() const { return storage()->GetAsKey(); } + +phi::DataLayout DataLayoutAttribute::data() const { + return storage()->GetAsKey(); +} + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/dialect/pd_attribute.h b/paddle/fluid/dialect/pd_attribute.h new file mode 100644 index 00000000000..75eed82dfc4 --- /dev/null +++ b/paddle/fluid/dialect/pd_attribute.h @@ -0,0 +1,95 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/dialect/pd_attribute_storage.h" +#include "paddle/ir/attribute.h" + +namespace paddle { +namespace dialect { +#define GET_PD_DIALECT_ATTRIBUTE_LIST \ + IntArrayAttribute, ScalarAttribute, DataTypeAttribute, PlaceAttribute, \ + DataLayoutAttribute + +class IntArrayAttribute : public ir::Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(IntArrayAttribute, + IntArrayAttributeStorage); + + bool operator<(const IntArrayAttribute &right) const { + return storage() < right.storage(); + } + + phi::IntArray data() const; +}; + +class ScalarAttribute : public ir::Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(ScalarAttribute, ScalarAttributeStorage); + + bool operator<(const ScalarAttribute &right) const { + return storage() < right.storage(); + } + + phi::Scalar data() const; +}; + +class DataTypeAttribute : public ir::Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(DataTypeAttribute, + DataTypeAttributeStorage); + + bool operator<(const DataTypeAttribute &right) const { + return storage() < right.storage(); + } + + phi::DataType data() const; +}; + +class PlaceAttribute : public ir::Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(PlaceAttribute, PlaceAttributeStorage); + + bool operator<(const PlaceAttribute &right) const { + return storage() < right.storage(); + } + + phi::Place data() const; +}; + +class DataLayoutAttribute : public ir::Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(DataLayoutAttribute, + DataLayoutAttributeStorage); + + bool operator<(const DataLayoutAttribute &right) const { + return storage() < right.storage(); + } + + phi::DataLayout data() const; +}; + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/dialect/pd_attribute_storage.h b/paddle/fluid/dialect/pd_attribute_storage.h new file mode 100644 index 00000000000..352dcc8b0e4 --- /dev/null +++ b/paddle/fluid/dialect/pd_attribute_storage.h @@ -0,0 +1,141 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/attribute.h" +#include "paddle/ir/utils.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/common/scalar.h" + +namespace paddle { +namespace dialect { +struct IntArrayAttributeStorage : public ir::AttributeStorage { + using ParamKey = phi::IntArray; + + explicit IntArrayAttributeStorage(const ParamKey &key) { data_ = key; } + + static IntArrayAttributeStorage *Construct(ParamKey key) { + return new IntArrayAttributeStorage(key); + } + + static std::size_t HashValue(const ParamKey &key) { + size_t hash_value = 0; + hash_value = + ir::hash_combine(hash_value, std::hash()(key.FromTensor())); + for (auto value : key.GetData()) { + hash_value = ir::hash_combine(hash_value, std::hash()(value)); + } + return hash_value; + } + + bool operator==(const ParamKey &key) const { + return (data_.GetData() == key.GetData()) && + (data_.FromTensor() == key.FromTensor()); + } + + ParamKey GetAsKey() const { return ParamKey(data_); } + + private: + phi::IntArray data_; +}; + +struct ScalarAttributeStorage : public ir::AttributeStorage { + using ParamKey = phi::Scalar; + + explicit ScalarAttributeStorage(const ParamKey &key) { data_ = key; } + + static ScalarAttributeStorage *Construct(ParamKey key) { + return new ScalarAttributeStorage(key); + } + + static std::size_t HashValue(const ParamKey &key) { + return ir::hash_combine(std::hash()(key.ToString()), + std::hash()(key.FromTensor())); + } + + bool operator==(const ParamKey &key) const { return data_ == key; } + + ParamKey GetAsKey() const { return ParamKey(data_); } + + private: + phi::Scalar data_; +}; + +struct DataTypeAttributeStorage : public ir::AttributeStorage { + using ParamKey = phi::DataType; + + explicit DataTypeAttributeStorage(const ParamKey &key) { data_ = key; } + + static DataTypeAttributeStorage *Construct(ParamKey key) { + return new DataTypeAttributeStorage(key); + } + + static std::size_t HashValue(const ParamKey &key) { + return std::hash()(key); + } + + bool operator==(const ParamKey &key) const { return data_ == key; } + + ParamKey GetAsKey() const { return data_; } + + private: + phi::DataType data_; +}; + +struct PlaceAttributeStorage : public ir::AttributeStorage { + using ParamKey = phi::Place; + + explicit PlaceAttributeStorage(const ParamKey &key) { data_ = key; } + + static PlaceAttributeStorage *Construct(ParamKey key) { + return new PlaceAttributeStorage(key); + } + + static std::size_t HashValue(const ParamKey &key) { return key.HashValue(); } + + bool operator==(const ParamKey &key) const { return data_ == key; } + + ParamKey GetAsKey() const { return data_; } + + private: + phi::Place data_; +}; + +struct DataLayoutAttributeStorage : public ir::AttributeStorage { + using ParamKey = phi::DataLayout; + + explicit DataLayoutAttributeStorage(const ParamKey &key) { data_ = key; } + + static DataLayoutAttributeStorage *Construct(ParamKey key) { + return new DataLayoutAttributeStorage(key); + } + + static std::size_t HashValue(const ParamKey &key) { + return std::hash()(key); + } + + bool operator==(const ParamKey &key) const { return data_ == key; } + + ParamKey GetAsKey() const { return data_; } + + private: + phi::DataLayout data_; +}; + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/dialect/pd_dialect.cc b/paddle/fluid/dialect/pd_dialect.cc index df98a997880..4439110adcb 100644 --- a/paddle/fluid/dialect/pd_dialect.cc +++ b/paddle/fluid/dialect/pd_dialect.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/dialect/pd_dialect.h" +#include "paddle/fluid/dialect/pd_attribute.h" #include "paddle/fluid/dialect/pd_type.h" #include "paddle/fluid/dialect/utils.h" #include "paddle/fluid/framework/convert_utils.h" @@ -89,7 +90,8 @@ PaddleDialect::PaddleDialect(ir::IrContext* context) } void PaddleDialect::initialize() { - RegisterTypes(); + RegisterTypes(); + RegisterAttributes(); RegisterInterfaces(); } diff --git a/paddle/fluid/dialect/pd_type.h b/paddle/fluid/dialect/pd_type.h index 0d4f8aa27b4..b644f12c1f8 100644 --- a/paddle/fluid/dialect/pd_type.h +++ b/paddle/fluid/dialect/pd_type.h @@ -19,7 +19,7 @@ namespace paddle { namespace dialect { -#define GET_PADDLE_TYPE_LIST paddle::dialect::DenseTensorType +#define GET_PD_DIALECT_TYPE_LIST paddle::dialect::DenseTensorType /// /// \brief Define built-in parametric types. diff --git a/paddle/ir/builtin_attribute.cc b/paddle/ir/builtin_attribute.cc index 40c0204ce9c..ba95318df6f 100644 --- a/paddle/ir/builtin_attribute.cc +++ b/paddle/ir/builtin_attribute.cc @@ -19,4 +19,18 @@ std::string StrAttribute::data() const { return storage()->GetAsKey(); } uint32_t StrAttribute::size() const { return storage()->GetAsKey().size(); } +bool BoolAttribute::data() const { return storage()->GetAsKey(); } + +float FloatAttribute::data() const { return storage()->GetAsKey(); } + +double DoubleAttribute::data() const { return storage()->GetAsKey(); } + +int32_t Int32_tAttribute::data() const { return storage()->GetAsKey(); } + +int64_t Int64_tAttribute::data() const { return storage()->GetAsKey(); } + +std::vector ArrayAttribute::data() const { + return storage()->GetAsKey(); +} + } // namespace ir diff --git a/paddle/ir/builtin_attribute.h b/paddle/ir/builtin_attribute.h index 4e2164a1e4d..93eb8599b20 100644 --- a/paddle/ir/builtin_attribute.h +++ b/paddle/ir/builtin_attribute.h @@ -22,9 +22,11 @@ namespace ir { /// /// \brief All built-in attributes. /// -#define GET_BUILT_IN_ATTRIBUTE_LIST ir::StrAttribute +#define GET_BUILT_IN_ATTRIBUTE_LIST \ + StrAttribute, BoolAttribute, FloatAttribute, DoubleAttribute, \ + Int32_tAttribute, Int64_tAttribute, ArrayAttribute -class StrAttribute : public ir::Attribute { +class StrAttribute : public Attribute { public: using Attribute::Attribute; @@ -39,13 +41,64 @@ class StrAttribute : public ir::Attribute { uint32_t size() const; }; -} // namespace ir +class BoolAttribute : public Attribute { + public: + using Attribute::Attribute; -namespace std { -template <> -struct hash { - std::size_t operator()(const ir::StrAttribute &obj) const { - return std::hash()(obj.storage()); - } + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(BoolAttribute, BoolAttributeStorage); + + bool data() const; +}; + +class FloatAttribute : public Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(FloatAttribute, FloatAttributeStorage); + + float data() const; +}; + +class DoubleAttribute : public Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(DoubleAttribute, DoubleAttributeStorage); + + double data() const; +}; + +class Int32_tAttribute : public Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Int32_tAttribute, Int32_tAttributeStorage); + + int32_t data() const; +}; + +class Int64_tAttribute : public Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Int64_tAttribute, Int64_tAttributeStorage); + + int64_t data() const; +}; + +class ArrayAttribute : public Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(ArrayAttribute, ArrayAttributeStorage); + + std::vector data() const; + + size_t size() const { return data().size(); } + + bool empty() const { return data().empty(); } + + Attribute operator[](size_t index) const { return data()[index]; } }; -} // namespace std + +} // namespace ir diff --git a/paddle/ir/builtin_attribute_storage.h b/paddle/ir/builtin_attribute_storage.h index f6f97d5e616..a61b2d561df 100644 --- a/paddle/ir/builtin_attribute_storage.h +++ b/paddle/ir/builtin_attribute_storage.h @@ -19,17 +19,41 @@ #include #include "paddle/ir/attribute.h" +#include "paddle/ir/utils.h" namespace ir { + +#define DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(concrete_storage, base_type) \ + struct concrete_storage : public ir::AttributeStorage { \ + using ParamKey = bool; \ + \ + explicit concrete_storage(const ParamKey &key) { data_ = key; } \ + \ + static concrete_storage *Construct(ParamKey key) { \ + return new concrete_storage(key); \ + } \ + \ + static std::size_t HashValue(const ParamKey &key) { \ + return std::hash()(key); \ + } \ + \ + bool operator==(const ParamKey &key) const { return data_ == key; } \ + \ + ParamKey GetAsKey() const { return data_; } \ + \ + private: \ + ParamKey data_; \ + }; + /// /// \brief Define Parameteric AttributeStorage for StrAttribute. /// -struct StrAttributeStorage : public ir::AttributeStorage { +struct StrAttributeStorage : public AttributeStorage { using ParamKey = std::string; explicit StrAttributeStorage(const ParamKey &key) { data_ = reinterpret_cast(malloc(key.size())); - memcpy(data_, const_cast(key.c_str()), key.size()); + memcpy(data_, key.c_str(), key.size()); size_ = key.size(); } @@ -44,7 +68,7 @@ struct StrAttributeStorage : public ir::AttributeStorage { } bool operator==(const ParamKey &key) const { - return std::equal(data_, data_ + size_, const_cast(key.c_str())); + return std::equal(data_, data_ + size_, key.c_str()); } ParamKey GetAsKey() const { return ParamKey(data_, size_); } @@ -54,4 +78,55 @@ struct StrAttributeStorage : public ir::AttributeStorage { uint32_t size_; }; +DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(BoolAttributeStorage, bool); +DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(FloatAttributeStorage, float); +DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(DoubleAttributeStorage, double); +DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int32_tAttributeStorage, int32_t); +DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int64_tAttributeStorage, int64_t); + +struct ArrayAttributeStorage : public AttributeStorage { + using ParamKey = std::vector; + + explicit ArrayAttributeStorage(const ParamKey &key) { + data_ = + reinterpret_cast(malloc(sizeof(Attribute) * key.size())); + memcpy(reinterpret_cast(data_), + reinterpret_cast(key.data()), + sizeof(Attribute) * key.size()); + length_ = key.size(); + } + + ~ArrayAttributeStorage() { free(reinterpret_cast(data_)); } + + static ArrayAttributeStorage *Construct(ParamKey key) { + return new ArrayAttributeStorage(key); + } + + static std::size_t HashValue(const ParamKey &key) { + std::size_t hash_value = 0; + for (size_t i = 0; i < key.size(); ++i) { + hash_value = hash_combine(hash_value, std::hash()(key[i])); + } + return hash_value; + } + + bool operator==(const ParamKey &key) const { + if (key.size() != length_) { + return false; + } + for (size_t i = 0; i < length_; ++i) { + if (data_[i] != key[i]) { + return false; + } + } + return true; + } + + ParamKey GetAsKey() const { return ParamKey(data_, data_ + length_); } + + private: + Attribute *data_ = nullptr; + size_t length_ = 0; +}; + } // namespace ir diff --git a/paddle/ir/builtin_type.cc b/paddle/ir/builtin_type.cc new file mode 100644 index 00000000000..8929496b5de --- /dev/null +++ b/paddle/ir/builtin_type.cc @@ -0,0 +1,20 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/builtin_type.h" + +namespace ir { +std::vector VectorType::data() const { return storage()->GetAsKey(); } + +} // namespace ir diff --git a/paddle/ir/builtin_type.h b/paddle/ir/builtin_type.h index d195cca2939..803638750cb 100644 --- a/paddle/ir/builtin_type.h +++ b/paddle/ir/builtin_type.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/ir/builtin_type_storage.h" #include "paddle/ir/type.h" namespace ir { @@ -22,9 +23,9 @@ namespace ir { /// The built-in Dialect will use this macro to quickly register all built-in /// types. /// -#define GET_BUILT_IN_TYPE_LIST \ - ir::Float16Type, ir::Float32Type, ir::Float64Type, ir::Int16Type, \ - ir::Int32Type, ir::Int64Type +#define GET_BUILT_IN_TYPE_LIST \ + BFloat16Type, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, \ + Int32Type, Int64Type, BoolType, VectorType /// /// \brief Define built-in parameterless types. Please add the necessary @@ -42,58 +43,96 @@ namespace ir { /// Type fp32 = Float32Type::get(ctx); /// \endcode /// -class Float16Type : public ir::Type { +class BFloat16Type : public Type { public: using Type::Type; - DECLARE_TYPE_UTILITY_FUNCTOR(Float16Type, ir::TypeStorage); + DECLARE_TYPE_UTILITY_FUNCTOR(BFloat16Type, TypeStorage); +}; + +class Float16Type : public Type { + public: + using Type::Type; - static Float16Type get(ir::IrContext *context); + DECLARE_TYPE_UTILITY_FUNCTOR(Float16Type, TypeStorage); + + static Float16Type get(IrContext *context); }; -class Float32Type : public ir::Type { +class Float32Type : public Type { public: using Type::Type; - DECLARE_TYPE_UTILITY_FUNCTOR(Float32Type, ir::TypeStorage); + DECLARE_TYPE_UTILITY_FUNCTOR(Float32Type, TypeStorage); - static Float32Type get(ir::IrContext *context); + static Float32Type get(IrContext *context); }; -class Float64Type : public ir::Type { +class Float64Type : public Type { public: using Type::Type; - DECLARE_TYPE_UTILITY_FUNCTOR(Float64Type, ir::TypeStorage); + DECLARE_TYPE_UTILITY_FUNCTOR(Float64Type, TypeStorage); - static Float64Type get(ir::IrContext *context); + static Float64Type get(IrContext *context); }; -class Int16Type : public ir::Type { +class Int8Type : public Type { public: using Type::Type; - DECLARE_TYPE_UTILITY_FUNCTOR(Int16Type, ir::TypeStorage); + DECLARE_TYPE_UTILITY_FUNCTOR(Int8Type, TypeStorage); +}; - static Int16Type get(ir::IrContext *context); +class Int16Type : public Type { + public: + using Type::Type; + + DECLARE_TYPE_UTILITY_FUNCTOR(Int16Type, TypeStorage); + + static Int16Type get(IrContext *context); }; -class Int32Type : public ir::Type { +class Int32Type : public Type { public: using Type::Type; - DECLARE_TYPE_UTILITY_FUNCTOR(Int32Type, ir::TypeStorage); + DECLARE_TYPE_UTILITY_FUNCTOR(Int32Type, TypeStorage); - static Int32Type get(ir::IrContext *context); + static Int32Type get(IrContext *context); }; -class Int64Type : public ir::Type { +class Int64Type : public Type { public: using Type::Type; - DECLARE_TYPE_UTILITY_FUNCTOR(Int64Type, ir::TypeStorage); + DECLARE_TYPE_UTILITY_FUNCTOR(Int64Type, TypeStorage); + + static Int64Type get(IrContext *context); +}; + +class BoolType : public Type { + public: + using Type::Type; + + DECLARE_TYPE_UTILITY_FUNCTOR(BoolType, TypeStorage); + + static BoolType get(IrContext *context); +}; + +class VectorType : public Type { + public: + using Type::Type; + + DECLARE_TYPE_UTILITY_FUNCTOR(VectorType, VectorTypeStorage); + + std::vector data() const; + + size_t size() const { return data().size(); } + + bool empty() const { return data().empty(); } - static Int64Type get(ir::IrContext *context); + Type operator[](size_t index) const { return data()[index]; } }; } // namespace ir diff --git a/paddle/ir/builtin_type_storage.h b/paddle/ir/builtin_type_storage.h new file mode 100644 index 00000000000..576457fe119 --- /dev/null +++ b/paddle/ir/builtin_type_storage.h @@ -0,0 +1,78 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/type.h" +#include "paddle/ir/utils.h" + +namespace ir { +struct VectorTypeStorage : public TypeStorage { + using ParamKey = std::vector; + + explicit VectorTypeStorage(const ParamKey &key) { + data_ = reinterpret_cast(malloc(key.size() * sizeof(Type))); + memcpy(reinterpret_cast(data_), + reinterpret_cast(key.data()), + key.size() * sizeof(Type)); + size_ = key.size(); + } + + ~VectorTypeStorage() { free(data_); } + + /// + /// \brief Each derived TypeStorage must define a Construc method, which + /// StorageManager uses to construct a derived TypeStorage. + /// + static VectorTypeStorage *Construct(ParamKey key) { + return new VectorTypeStorage(key); + } + + /// + /// \brief Each derived TypeStorage must provide a HashValue method. + /// + static std::size_t HashValue(const ParamKey &key) { + std::size_t hash_value = 0; + for (size_t i = 0; i < key.size(); ++i) { + hash_value = hash_combine(hash_value, std::hash()(key[i])); + } + return hash_value; + } + + /// + /// \brief Each derived TypeStorage needs to overload operator==. + /// + bool operator==(const ParamKey &key) const { + if (key.size() != size_) { + return false; + } + for (size_t i = 0; i < size_; ++i) { + if (data_[i] != key[i]) { + return false; + } + } + return true; + } + + ParamKey GetAsKey() const { return ParamKey(data_, data_ + size_); } + + /// + /// \brief DenseTensorTypeStorage include five parameters: dims, dtype, + /// layout, lod, offset. + /// + Type *data_; + size_t size_; +}; + +} // namespace ir diff --git a/test/cpp/ir/type_test.cc b/test/cpp/ir/type_test.cc index 8613c9d6afa..36917e541e3 100644 --- a/test/cpp/ir/type_test.cc +++ b/test/cpp/ir/type_test.cc @@ -128,6 +128,16 @@ TEST(type_test, built_in_type) { EXPECT_EQ(fp16_1.isa(), true); EXPECT_EQ(fp16_1.isa(), false); EXPECT_EQ(fp16_1.isa(), true); + + // Test 3: Test VectorType + std::vector vec_type = {int32_1, int64_1}; + ir::Type vector_type = ir::VectorType::get(ctx, vec_type); + EXPECT_EQ(vector_type.isa(), true); + EXPECT_EQ(vector_type.dyn_cast().size() == 2, true); + EXPECT_EQ(vector_type.dyn_cast()[0].isa(), + true); + EXPECT_EQ(vector_type.dyn_cast()[1].isa(), + true); } // Customize a parameterized TypeStorage IntegerTypeStorage. -- GitLab