From 3a3ff942fb347cf6e60b9e8830f28036f72d4358 Mon Sep 17 00:00:00 2001
From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com>
Date: Tue, 14 Mar 2023 11:19:42 +0800
Subject: [PATCH] [IR] Type system stage4: Add some built-in types and type
 conversion methods (#51112)

* add builtin-type DenseTensorType Float16Type Float64Type Int16Type Int64Type
* refine comment
* refine comment
* add classof for Type class
* refine test code
* add get param func for DenseTensorType
* add dyn_cast and refine isa
* set default WITH_NEWIR=OFF
* refine cast_utils
* Refine code by comment
* refine code by comment
* refine code by comment
* refine code by comment
* fix bug of dyn_cast
* set WITH_NEWIR=OFF
* refine code by comment
---
 paddle/ir/builtin_type.cc        |  34 +++++++
 paddle/ir/builtin_type.h         |  75 ++++++++++++++-
 paddle/ir/builtin_type_storage.h | 156 ++++++++++++++++++++++++++++++
 paddle/ir/cast_utils.h           | 157 +++++++++++++++++++++++++++++++
 paddle/ir/ir_context.cc          |  20 +++-
 paddle/ir/tests/type_test.cc     | 112 +++++++++++++++++-----
 paddle/ir/type.h                 |  22 ++++-
 7 files changed, 548 insertions(+), 28 deletions(-)
 create mode 100644 paddle/ir/builtin_type.cc
 create mode 100644 paddle/ir/builtin_type_storage.h
 create mode 100644 paddle/ir/cast_utils.h

diff --git a/paddle/ir/builtin_type.cc b/paddle/ir/builtin_type.cc
new file mode 100644
index 00000000000..5e18b945016
--- /dev/null
+++ b/paddle/ir/builtin_type.cc
@@ -0,0 +1,34 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/ir/builtin_type.h"
+
+namespace ir {
+const ir::Type& DenseTensorType::dtype() const { return storage()->dtype_; }
+
+const ir::DenseTensorTypeStorage::Dim& DenseTensorType::dim() const {
+  return storage()->dims_;
+}
+
+const ir::DenseTensorTypeStorage::DataLayout& DenseTensorType::data_layout()
+    const {
+  return storage()->layout_;
+}
+
+const ir::DenseTensorTypeStorage::LoD& DenseTensorType::lod() const {
+  return storage()->lod_;
+}
+
+const size_t& DenseTensorType::offset() const { return storage()->offset_; }
+} // namespace ir
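[Editor's note] These accessors simply forward to the fields of the parameterized storage. As an illustration, reading the parameters back from a constructed type looks like the fragment below (a sketch, not part of the patch; it assumes the IrContext singleton and the same get arguments used by the unit test later in this commit, and omits the <cassert> include):

    // Sketch: construct a DenseTensorType, then read its parameters back.
    ir::IrContext *ctx = ir::IrContext::Instance();
    ir::Type fp32 = ir::Float32Type::get(ctx);
    ir::DenseTensorTypeStorage::Dim dims = {1, 2, 3};
    ir::DenseTensorTypeStorage::LoD lod = {{1, 2, 3}};
    ir::DenseTensorType t = ir::DenseTensorType::get(
        ctx, fp32, dims, ir::DenseTensorTypeStorage::DataLayout::NCHW, lod, 0);
    assert(t.dtype() == fp32);  // forwards to storage()->dtype_
    assert(t.dim() == dims);    // forwards to storage()->dims_
    assert(t.offset() == 0);    // forwards to storage()->offset_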
diff --git a/paddle/ir/builtin_type.h b/paddle/ir/builtin_type.h
index 8b15ae6eed0..0f4db31d9d6 100644
--- a/paddle/ir/builtin_type.h
+++ b/paddle/ir/builtin_type.h
@@ -14,22 +14,44 @@
 
 #pragma once
 
+#include "paddle/ir/builtin_type_storage.h"
 #include "paddle/ir/type.h"
 
 namespace ir {
 ///
 /// \brief This macro is used to get a list of all built-in types in this file.
+/// The built-in Dialect will use this macro to quickly register all built-in
+/// types.
 ///
-#define GET_BUILT_IN_TYPE_LIST ir::Float32Type, ir::Int32Type
+#define GET_BUILT_IN_TYPE_LIST                                       \
+  ir::Float16Type, ir::Float32Type, ir::Float64Type, ir::Int16Type, \
+      ir::Int32Type, ir::Int64Type, ir::DenseTensorType
 
 ///
-/// \brief Definitions of built-in type classes. The built-in type object get
-/// method is as follows:
+/// \brief Define built-in parameterless types. Please add the necessary
+/// interface functions for built-in types through the macro
+/// DECLARE_TYPE_UTILITY_FUNCTOR.
+///
+/// NOTE(zhangbo9674): If you need to cache the object of a built-in type
+/// directly in IrContext, overload the get method, then construct and cache
+/// the object in IrContext. For a concrete example, see Float16Type.
+///
+/// The built-in type object get method is as follows:
 /// \code{cpp}
 ///   ir::IrContext *ctx = ir::IrContext::Instance();
 ///   Type fp32 = Float32Type::get(ctx);
 /// \endcode
 ///
+class Float16Type : public ir::Type {
+ public:
+  using Type::Type;
+
+  DECLARE_TYPE_UTILITY_FUNCTOR(Float16Type, ir::TypeStorage);
+
+  static Float16Type get(ir::IrContext *context);
+};
+
 class Float32Type : public ir::Type {
  public:
   using Type::Type;
@@ -39,6 +61,24 @@ class Float32Type : public ir::Type {
   static Float32Type get(ir::IrContext *context);
 };
 
+class Float64Type : public ir::Type {
+ public:
+  using Type::Type;
+
+  DECLARE_TYPE_UTILITY_FUNCTOR(Float64Type, ir::TypeStorage);
+
+  static Float64Type get(ir::IrContext *context);
+};
+
+class Int16Type : public ir::Type {
+ public:
+  using Type::Type;
+
+  DECLARE_TYPE_UTILITY_FUNCTOR(Int16Type, ir::TypeStorage);
+
+  static Int16Type get(ir::IrContext *context);
+};
+
 class Int32Type : public ir::Type {
  public:
   using Type::Type;
@@ -48,4 +88,33 @@ class Int32Type : public ir::Type {
   static Int32Type get(ir::IrContext *context);
 };
 
+class Int64Type : public ir::Type {
+ public:
+  using Type::Type;
+
+  DECLARE_TYPE_UTILITY_FUNCTOR(Int64Type, ir::TypeStorage);
+
+  static Int64Type get(ir::IrContext *context);
+};
+
+///
+/// \brief Define built-in parametric types.
+///
+class DenseTensorType : public ir::Type {
+ public:
+  using Type::Type;
+
+  DECLARE_TYPE_UTILITY_FUNCTOR(DenseTensorType, DenseTensorTypeStorage);
+
+  const ir::Type &dtype() const;
+
+  const ir::DenseTensorTypeStorage::Dim &dim() const;
+
+  const ir::DenseTensorTypeStorage::DataLayout &data_layout() const;
+
+  const ir::DenseTensorTypeStorage::LoD &lod() const;
+
+  const size_t &offset() const;
+};
+
 } // namespace ir
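[Editor's note] GET_BUILT_IN_TYPE_LIST exists so the built-in Dialect can register all seven types in one expansion. A hypothetical sketch of such a registration site follows; the name RegisterTypes and its variadic signature are assumptions, since the dialect side is not part of this patch:

    // Hypothetical: how a dialect could expand the macro (helper name assumed).
    class BuiltinDialect : public ir::Dialect {
      void initialize() {
        // One variadic call registers every built-in type in the list.
        RegisterTypes<GET_BUILT_IN_TYPE_LIST>();
      }
    };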
diff --git a/paddle/ir/builtin_type_storage.h b/paddle/ir/builtin_type_storage.h
new file mode 100644
index 00000000000..876b6ceeffd
--- /dev/null
+++ b/paddle/ir/builtin_type_storage.h
@@ -0,0 +1,156 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <type_traits>
+
+#include "paddle/ir/type.h"
+
+namespace std {
+///
+/// \brief Enable hashing std::vector<T> instances.
+///
+template <typename T>
+struct hash<std::vector<T>> {
+  std::size_t operator()(const std::vector<T> &dim) const {
+    std::size_t seed = 0;
+    for (size_t i = 0; i < dim.size(); ++i) {
+      seed ^= std::hash<T>()(dim[i]) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+    }
+    return seed;
+  }
+};
+
+}  // namespace std
+
+namespace ir {
+///
+/// \brief Define the parametric TypeStorage for DenseTensorType.
+///
+/// NOTE(zhangbo9674): The derived TypeStorage class needs to implement the
+/// following methods: (1) declare ParamKey, (2) define a Construct method,
+/// (3) define a HashValue method, (4) overload operator==.
+///
+struct DenseTensorTypeStorage : public ir::TypeStorage {
+  ///
+  /// \brief It is consistent with the DataLayout defined by the Phi operator
+  /// library. See the file for details: paddle/phi/common/layout.h.
+  ///
+  enum class DataLayout : unsigned int {
+    UNDEFINED = 0,
+    NHWC,
+    NCHW,
+    NCDHW,
+    NDHWC,
+    ONEDNN,
+    SPARSE_COO,
+    SPARSE_CSR,
+    PSTRING_UNION,
+
+    NUM_DATA_LAYOUTS,
+
+    // See Note [ Why we need ALL in basic kernel key member? ]
+    ALL_LAYOUT = UNDEFINED,
+
+    // Note: Unify phi DataLayout and fluid::framework::DataLayout,
+    // for compatible with fluid DataLayout, here need prefix `k`
+    kNHWC = NHWC,
+    kNCHW = NCHW,
+    kMKLDNN = ONEDNN,  // all layouts supported by ONEDNN internally
+    kNDHWC = NDHWC,
+    kNCDHW = NCDHW,
+  };
+
+  using Dim = std::vector<int64_t>;
+
+  using LoD = std::vector<std::vector<size_t>>;
+
+  ///
+  /// \brief Declare ParamKey according to parameter type.
+  ///
+  using ParamKey = std::tuple<ir::Type, Dim, DataLayout, LoD, size_t>;
+
+  DenseTensorTypeStorage(
+      ir::Type dtype, Dim dims, DataLayout layout, LoD lod, size_t offset)
+      : dtype_(dtype),
+        dims_(dims),
+        layout_(layout),
+        lod_(lod),
+        offset_(offset) {}
+
+  ///
+  /// \brief Each derived TypeStorage must define a Construct method, which
+  /// StorageManager uses to construct a derived TypeStorage.
+  ///
+  static DenseTensorTypeStorage *Construct(ParamKey key) {
+    return new DenseTensorTypeStorage(std::get<0>(key),
+                                      std::get<1>(key),
+                                      std::get<2>(key),
+                                      std::get<3>(key),
+                                      std::get<4>(key));
+  }
+
+  ///
+  /// \brief Each derived TypeStorage must provide a HashValue method.
+  ///
+  static std::size_t HashValue(const ParamKey &key) {
+    std::size_t hash_value = 0;
+    // hash dtype
+    hash_value =
+        hash_combine(hash_value, std::hash<ir::Type>()(std::get<0>(key)));
+    // hash dims
+    hash_value = hash_combine(hash_value, std::hash<Dim>()(std::get<1>(key)));
+    // hash layout
+    hash_value =
+        hash_combine(hash_value,
+                     std::hash<std::underlying_type<DataLayout>::type>()(
+                         static_cast<std::underlying_type<DataLayout>::type>(
+                             std::get<2>(key))));
+    // hash lod
+    hash_value = hash_combine(hash_value, std::hash<LoD>()(std::get<3>(key)));
+    // hash offset
+    hash_value =
+        hash_combine(hash_value, std::hash<size_t>()(std::get<4>(key)));
+    return hash_value;
+  }
+
+  ///
+  /// \brief Each derived TypeStorage needs to overload operator==.
+  ///
+  bool operator==(const ParamKey &key) const {
+    return ParamKey(dtype_, dims_, layout_, lod_, offset_) == key;
+  }
+
+  ParamKey GetAsKey() const {
+    return ParamKey(dtype_, dims_, layout_, lod_, offset_);
+  }
+
+  ///
+  /// \brief DenseTensorTypeStorage includes five parameters: dtype, dims,
+  /// layout, lod, offset.
+  ///
+  ir::Type dtype_;
+  Dim dims_;
+  DataLayout layout_;
+  LoD lod_;
+  size_t offset_;
+
+ private:
+  static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) {
+    return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
+  }
+};
+
+}  // namespace ir
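[Editor's note] The ParamKey/HashValue/operator== trio is what lets the storage uniquer treat structurally equal parameter tuples as one type. A sketch of the invariant (a fragment, not part of the patch; it reuses fp32, dims, and lod from the earlier sketch, relies on the std::hash<ir::Type> enabled in type.h, and leaks the storage that the real uniquer would own):

    using Storage = ir::DenseTensorTypeStorage;
    Storage::ParamKey k1(fp32, dims, Storage::DataLayout::NCHW, lod, 0);
    Storage::ParamKey k2(fp32, dims, Storage::DataLayout::NCHW, lod, 0);
    // Equal keys must hash equal, so both land in the same uniquer bucket...
    assert(Storage::HashValue(k1) == Storage::HashValue(k2));
    // ...and Construct is expected to run only on a miss; operator==(ParamKey)
    // then confirms a hit, so equal parameters yield the same cached storage.
    Storage *s = Storage::Construct(k1);
    assert(*s == k2);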
diff --git a/paddle/ir/cast_utils.h b/paddle/ir/cast_utils.h
new file mode 100644
index 00000000000..dcc4b89fe8b
--- /dev/null
+++ b/paddle/ir/cast_utils.h
@@ -0,0 +1,157 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <type_traits>
+
+namespace ir {
+///
+/// \brief The implementation actually called by isa_wrap.
+///
+template <typename Target, typename From, typename Enabler = void>
+struct isa_impl {
+  static inline bool call(const From &Val) { return Target::classof(Val); }
+};
+
+template <typename Target, typename From>
+struct isa_impl<
+    Target,
+    From,
+    typename std::enable_if<std::is_base_of<Target, From>::value>::type> {
+  static inline bool call(const From &) { return true; }
+};
+
+///
+/// \brief The implementation actually called by isa.
+///
+template <typename Target, typename From, typename Enabler = void>
+struct isa_wrap {
+  static inline bool call(const From &Val) {
+    return isa_impl<Target, From>::call(Val);
+  }
+};
+
+///
+/// \brief Type-qualified specializations of the isa_wrap template parameter
+/// From. Specialized types include: const T, T*, const T*, T* const,
+/// const T* const.
+///
+template <typename Target, typename From>
+struct isa_wrap<Target, const From> {
+  static inline bool call(const From &Val) {
+    return isa_impl<Target, From>::call(Val);
+  }
+};
+
+template <typename Target, typename From>
+struct isa_wrap<
+    Target,
+    From,
+    typename std::enable_if_t<std::is_pointer<std::decay_t<From>>::value>> {
+  static inline bool call(
+      std::remove_pointer_t<std::decay_t<From>> const *Val) {
+    if (Val == nullptr) {
+      throw("isa<> used on a null pointer");
+    }
+    return isa_impl<Target, std::remove_pointer_t<std::decay_t<From>>>::call(
+        *Val);
+  }
+};
+
+///
+/// \brief isa template function, used to determine whether the value is a
+/// Target type. Usage: if (isa<Target>(value)) { ... }.
+///
+template <typename Target, typename From>
+inline bool isa(const From &Val) {
+  return isa_wrap<typename std::remove_pointer<Target>::type, From>::call(Val);
+}
+
+///
+/// \brief Derive the cast return type from the template parameters To and
+/// From.
+///
+template <typename To, typename From>
+struct ReturnTypeDuductionWrap {
+  typedef To &type;
+};
+
+template <typename To, typename From>
+struct ReturnTypeDuductionWrap<To, const From> {
+  typedef const To &type;
+};
+
+template <typename To, typename From>
+struct ReturnTypeDuductionWrap<To, From *> {
+  typedef To *type;
+};
+
+template <typename To, typename From>
+struct ReturnTypeDuductionWrap<To, const From *> {
+  typedef const To *type;
+};
+
+template <typename To, typename From>
+struct ReturnTypeDuductionWrap<To, const From *const> {
+  typedef const To *type;
+};
+
+template <typename To, typename From>
+struct ReturnTypeDuduction {
+  typedef typename ReturnTypeDuductionWrap<To, From>::type type;
+};
+
+///
+/// Cast From to To.
+///
+template <typename To, typename From>
+struct cast_impl {
+  // This _is_ a simple type, just cast it.
+  static typename ReturnTypeDuduction<To, From>::type call(const From &Val) {
+    typename ReturnTypeDuduction<To, From>::type ret =
+        (typename ReturnTypeDuduction<To, From>::type) const_cast<From &>(Val);
+    return ret;
+  }
+};
+
+template <typename To, typename From>
+inline typename ReturnTypeDuduction<To, From>::type cast(
+    From &Val) {  // NOLINT
+  if (!isa<To>(Val)) {
+    throw("cast<To>() argument of incompatible type!");
+  }
+  return cast_impl<To, From>::call(Val);
+}
+
+template <typename To, typename From>
+inline typename ReturnTypeDuduction<To, From *>::type cast(From *Val) {
+  if (!isa<To>(Val)) {
+    throw("cast<To>() argument of incompatible type!");
+  }
+  return cast_impl<To, From *>::call(Val);
+}
+
+///
+/// \brief dyn_cast From to To.
+///
+template <typename To, typename From>
+inline std::decay_t<typename ReturnTypeDuduction<To, From>::type> dyn_cast(
+    From &Val) {  // NOLINT
+  return isa<To>(Val) ? cast<To>(Val) : nullptr;
+}
+
+template <typename To, typename From>
+inline typename ReturnTypeDuduction<To, From *>::type dyn_cast(From *Val) {
+  return isa<To>(Val) ? cast<To>(Val) : nullptr;
+}
+
+}  // namespace ir
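[Editor's note] Taken together, these templates give LLVM-style casting: isa<> tests, dyn_cast<> tests then casts, and cast<> throws on an incompatible type. Typical use, mirroring the checks in the unit test below (a fragment, not part of the patch):

    ir::Type t = ir::Float16Type::get(ir::IrContext::Instance());
    if (ir::isa<ir::Float16Type>(t)) {           // dispatches to classof
      ir::Float16Type fp16 = ir::dyn_cast<ir::Float16Type>(t);  // succeeds
      (void)fp16;
    }
    // ir::cast<ir::Float32Type>(t) would throw here, since t is fp16.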
 
   // Cached AbstractType instances.
   std::unordered_map<ir::TypeId, AbstractType *> registed_abstract_types_;
-
   ir::SpinLock registed_abstract_types_lock_;
 
   // TypeStorage uniquer and cache instances.
@@ -93,12 +92,15 @@ class IrContextImpl {
 
   // The dialcet registered in the context.
   std::unordered_map<std::string, Dialect *> registed_dialect_;
-
   ir::SpinLock registed_dialect_lock_;
 
-  // Some built-in types.
+  // Cache some built-in type objects.
+  Float16Type fp16_type;
   Float32Type fp32_type;
+  Float64Type fp64_type;
+  Int16Type int16_type;
   Int32Type int32_type;
+  Int64Type int64_type;
 
   ir::SpinLock destructor_lock_;
 };
@@ -113,8 +115,12 @@ IrContext::IrContext() : impl_(new IrContextImpl()) {
   GetOrRegisterDialect<BuiltinDialect>();
   VLOG(4) << "==============================================";
 
+  impl_->fp16_type = TypeManager::get<Float16Type>(this);
   impl_->fp32_type = TypeManager::get<Float32Type>(this);
+  impl_->fp64_type = TypeManager::get<Float64Type>(this);
+  impl_->int16_type = TypeManager::get<Int16Type>(this);
   impl_->int32_type = TypeManager::get<Int32Type>(this);
+  impl_->int64_type = TypeManager::get<Int64Type>(this);
 }
 
 void IrContext::RegisterAbstractType(ir::TypeId type_id,
@@ -173,8 +179,16 @@ const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
   }
 }
 
+Float16Type Float16Type::get(IrContext *ctx) { return ctx->impl().fp16_type; }
+
 Float32Type Float32Type::get(IrContext *ctx) { return ctx->impl().fp32_type; }
 
+Float64Type Float64Type::get(IrContext *ctx) { return ctx->impl().fp64_type; }
+
+Int16Type Int16Type::get(IrContext *ctx) { return ctx->impl().int16_type; }
+
 Int32Type Int32Type::get(IrContext *ctx) { return ctx->impl().int32_type; }
 
+Int64Type Int64Type::get(IrContext *ctx) { return ctx->impl().int64_type; }
+
 } // namespace ir
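[Editor's note] Because the six parameterless types are constructed once in the IrContext constructor and then handed out from the cache, every later get() returns the same object and comparison is a pointer comparison, as shown below (a fragment, not part of the patch):

    ir::IrContext *ctx = ir::IrContext::Instance();
    ir::Type a = ir::Float16Type::get(ctx);
    ir::Type b = ir::Float16Type::get(ctx);
    assert(a == b);  // both are the cached fp16_type; no storage lookup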
diff --git a/paddle/ir/tests/type_test.cc b/paddle/ir/tests/type_test.cc
index e8901be6c35..a11040e3656 100644
--- a/paddle/ir/tests/type_test.cc
+++ b/paddle/ir/tests/type_test.cc
@@ -70,29 +70,99 @@ TEST(type_test, type_base) {
 }
 
 TEST(type_test, built_in_type) {
-  // Test 1: Test the built-in type of IrContext.
+  // Test the interfaces of class Type: judgment, type_id, abstract_type,
+  // classof.
   ir::IrContext *ctx = ir::IrContext::Instance();
-  ir::Type fp32_1 = ir::Float32Type::get(ctx);
-
-  // Test 2: Test the interfaces of class Type: judgment, type_id,
-  // abstract_type, classof.
+
+  // Test 1: Test the parameterless built-in type of IrContext.
+  ir::Type fp16_1 = ir::Float16Type::get(ctx);
+  ir::Type fp16_2 = ir::Float16Type::get(ctx);
+  EXPECT_EQ(fp16_1, fp16_2);
+  EXPECT_EQ(fp16_1.type_id(), fp16_2.type_id());
+  EXPECT_EQ(&fp16_1.abstract_type(),
+            &ir::AbstractType::lookup(fp16_1.type_id(), ctx));
+  EXPECT_EQ(ir::Float16Type::classof(fp16_1), 1);
+
+  ir::Type fp32_1 = ir::Float32Type::get(ctx);
   ir::Type fp32_2 = ir::Float32Type::get(ctx);
-  EXPECT_EQ(fp32_1 == fp32_2, 1);
-  EXPECT_EQ(fp32_1 != fp32_2, 0);
-  EXPECT_EQ(fp32_1.type_id() == fp32_2.type_id(), 1);
-  EXPECT_EQ(&fp32_1.abstract_type() ==
-                &ir::AbstractType::lookup(fp32_1.type_id(), ctx),
-            1);
+  EXPECT_EQ(fp32_1, fp32_2);
+  EXPECT_EQ(fp32_1.type_id(), fp32_2.type_id());
+  EXPECT_EQ(&fp32_1.abstract_type(),
+            &ir::AbstractType::lookup(fp32_1.type_id(), ctx));
   EXPECT_EQ(ir::Float32Type::classof(fp32_1), 1);
 
+  ir::Type fp64_1 = ir::Float64Type::get(ctx);
+  ir::Type fp64_2 = ir::Float64Type::get(ctx);
+  EXPECT_EQ(fp64_1, fp64_2);
+  EXPECT_EQ(fp64_1.type_id(), fp64_2.type_id());
+  EXPECT_EQ(&fp64_1.abstract_type(),
+            &ir::AbstractType::lookup(fp64_1.type_id(), ctx));
+  EXPECT_EQ(ir::Float64Type::classof(fp64_1), 1);
+
+  ir::Type int16_1 = ir::Int16Type::get(ctx);
+  ir::Type int16_2 = ir::Int16Type::get(ctx);
+  EXPECT_EQ(int16_1, int16_2);
+  EXPECT_EQ(int16_1.type_id(), int16_2.type_id());
+  EXPECT_EQ(&int16_1.abstract_type(),
+            &ir::AbstractType::lookup(int16_1.type_id(), ctx));
+  EXPECT_EQ(ir::Int16Type::classof(int16_1), 1);
+
   ir::Type int32_1 = ir::Int32Type::get(ctx);
   ir::Type int32_2 = ir::Int32Type::get(ctx);
-  EXPECT_EQ(int32_1 == int32_2, 1);
-  EXPECT_EQ(int32_1.type_id() == int32_2.type_id(), 1);
-  EXPECT_EQ(&int32_1.abstract_type() ==
-                &ir::AbstractType::lookup(int32_1.type_id(), ctx),
-            1);
+  EXPECT_EQ(int32_1, int32_2);
+  EXPECT_EQ(int32_1.type_id(), int32_2.type_id());
+  EXPECT_EQ(&int32_1.abstract_type(),
+            &ir::AbstractType::lookup(int32_1.type_id(), ctx));
   EXPECT_EQ(ir::Int32Type::classof(int32_1), 1);
+
+  ir::Type int64_1 = ir::Int64Type::get(ctx);
+  ir::Type int64_2 = ir::Int64Type::get(ctx);
+  EXPECT_EQ(int64_1, int64_2);
+  EXPECT_EQ(int64_1.type_id(), int64_2.type_id());
+  EXPECT_EQ(&int64_1.abstract_type(),
+            &ir::AbstractType::lookup(int64_1.type_id(), ctx));
+  EXPECT_EQ(ir::Int64Type::classof(int64_1), 1);
+
+  // Test 2: Test the parametric built-in type of IrContext.
+  ir::DenseTensorTypeStorage::Dim dims = {1, 2, 3};
+  ir::DenseTensorTypeStorage::DataLayout data_layout =
+      ir::DenseTensorTypeStorage::DataLayout::NCHW;
+  ir::DenseTensorTypeStorage::LoD lod = {{1, 2, 3}, {4, 5, 6}};
+  size_t offset = 0;
+
+  ir::Type dense_tensor_1 =
+      ir::DenseTensorType::get(ctx, fp32_1, dims, data_layout, lod, offset);
+  ir::Type dense_tensor_2 =
+      ir::DenseTensorType::get(ctx, fp32_2, dims, data_layout, lod, offset);
+  ir::Type dense_tensor_3 =
+      ir::DenseTensorType::get(ctx, fp32_1, dims, data_layout, lod, 2);
+
+  EXPECT_EQ(dense_tensor_1, dense_tensor_2);
+  EXPECT_NE(dense_tensor_1, dense_tensor_3);
+  EXPECT_EQ(dense_tensor_1.type_id(), dense_tensor_2.type_id());
+  EXPECT_EQ(ir::DenseTensorType::classof(dense_tensor_1), 1);
+
+  ir::DenseTensorType dense_tensor_4 =
+      ir::DenseTensorType::get(ctx, fp32_1, dims, data_layout, lod, 2);
+  EXPECT_EQ(dense_tensor_4.offset() == 2, 1);
+  EXPECT_EQ(dense_tensor_4.dtype().isa<ir::Float32Type>(), true);
+  EXPECT_EQ(dense_tensor_4.data_layout(), data_layout);
+
+  // Test 3: Test isa and dyn_cast.
+  EXPECT_EQ(fp16_1.isa<ir::Float16Type>(), true);
+  EXPECT_EQ(fp16_1.isa<ir::Float32Type>(), false);
+  EXPECT_EQ(fp16_1.isa<ir::DenseTensorType>(), false);
+  EXPECT_EQ(fp16_1.isa<ir::Type>(), true);
+  EXPECT_EQ(dense_tensor_1.isa<ir::DenseTensorType>(), true);
+
+  ir::DenseTensorType dense_tensor_cast_1 =
+      dense_tensor_1.dyn_cast<ir::DenseTensorType>();
+  EXPECT_EQ(dense_tensor_cast_1.isa<ir::DenseTensorType>(), true);
+  EXPECT_EQ(dense_tensor_cast_1.offset() == 0, 1);
+  const ir::DenseTensorType dense_tensor_cast_2 =
+      ir::dyn_cast<ir::DenseTensorType>(dense_tensor_1);
+  EXPECT_EQ(dense_tensor_cast_2.isa<ir::DenseTensorType>(), true);
+  EXPECT_EQ(dense_tensor_cast_2.offset() == 0, 1);
 }
 
 // Customize a parameterized TypeStorage IntegerTypeStorage.
@@ -150,15 +220,15 @@ TEST(type_test, custom_type_dialect) {
 
   ir::Type int1_1 = IntegerType::get(ctx, 1, 0);
   ir::Type int1_2 = IntegerType::get(ctx, 1, 0);
-  EXPECT_EQ(int1_1 == int1_2, 1);
+  EXPECT_EQ(int1_1, int1_2);
 
   ir::Type int8 = IntegerType::get(ctx, 8, 0);
-  EXPECT_EQ(int8 == int1_2, 0);
+  EXPECT_NE(int8, int1_2);
 
   // Test 2: Test Dialect interfaces
-  EXPECT_EQ(ctx == int8.ir_context(), 1);
+  EXPECT_EQ(ctx, int8.ir_context());
 
-  EXPECT_EQ(int8.dialect().id() == ir::TypeId::get<IntegerDialect>(), 1);
+  EXPECT_EQ(int8.dialect().id(), ir::TypeId::get<IntegerDialect>());
 
   std::vector<ir::Dialect *> dialect_list = ctx->GetRegisteredDialects();
   EXPECT_EQ(dialect_list.size() == 3, 1);  // integer, builtin, fake
 
   ir::Dialect *dialect_builtin1 = ctx->GetRegisteredDialect("builtin");
   ir::Dialect *dialect_builtin2 =
       ctx->GetRegisteredDialect<ir::BuiltinDialect>();
-  EXPECT_EQ(dialect_builtin1 == dialect_builtin2, 1);
+  EXPECT_EQ(dialect_builtin1, dialect_builtin2);
 
   ir::Dialect *dialect_integer1 = ctx->GetRegisteredDialect("integer");
   ir::Dialect *dialect_integer2 = ctx->GetRegisteredDialect<IntegerDialect>();
-  EXPECT_EQ(dialect_integer1 == dialect_integer2, 1);
+  EXPECT_EQ(dialect_integer1, dialect_integer2);
 }
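[Editor's note] The IntegerTypeStorage referenced by the custom-dialect test follows the same four-part contract as DenseTensorTypeStorage. A condensed, hypothetical sketch of that contract for a two-parameter storage (names illustrative, not the test's actual code):

    struct MyTypeStorage : public ir::TypeStorage {
      using ParamKey = std::pair<unsigned, unsigned>;      // (1) declare ParamKey
      MyTypeStorage(unsigned width, unsigned sign) : width_(width), sign_(sign) {}
      static MyTypeStorage *Construct(ParamKey key) {      // (2) construction
        return new MyTypeStorage(key.first, key.second);
      }
      static std::size_t HashValue(const ParamKey &key) {  // (3) hashing
        return std::hash<unsigned>()(key.first) ^
               std::hash<unsigned>()(key.second);
      }
      bool operator==(const ParamKey &key) const {         // (4) key equality
        return key == ParamKey(width_, sign_);
      }
      unsigned width_;
      unsigned sign_;
    };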
diff --git a/paddle/ir/type.h b/paddle/ir/type.h
index 64e65e99688..79465d3aaf2 100644
--- a/paddle/ir/type.h
+++ b/paddle/ir/type.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include "paddle/ir/cast_utils.h"
 #include "paddle/ir/type_base.h"
 
 namespace ir {
@@ -37,15 +38,19 @@ class Type {
   Type &operator=(const Type &other) = default;
 
   ///
-  /// \brief Comparison operations.
+  /// \brief Overloaded operators.
   ///
   bool operator==(Type other) const { return storage_ == other.storage_; }
+
   bool operator!=(Type other) const { return storage_ != other.storage_; }
 
   explicit operator bool() const { return storage_; }
 
   bool operator!() const { return storage_ == nullptr; }
 
+  ///
+  /// \brief Interfaces for acquiring type attributes.
+  ///
   TypeId type_id() { return storage_->abstract_type().type_id(); }
 
   const AbstractType &abstract_type() { return storage_->abstract_type(); }
@@ -56,6 +61,21 @@ class Type {
 
   IrContext *ir_context() const;
 
+  ///
+  /// \brief Methods for type judgment and cast.
+  ///
+  static bool classof(Type) { return true; }
+
+  template <typename T>
+  bool isa() const {
+    return ir::isa<T>(*this);
+  }
+
+  template <typename U>
+  U dyn_cast() const {
+    return ir::dyn_cast<U>(*this);
+  }
+
   ///
   /// \brief Enable hashing Type.
   ///
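[Editor's note] With classof, isa, and dyn_cast available as members, type checks read directly off a Type handle, exactly as the tests above exercise (a fragment, not part of the patch; the visit function name is illustrative):

    void visit(ir::Type t) {
      if (t.isa<ir::DenseTensorType>()) {
        auto dense = t.dyn_cast<ir::DenseTensorType>();
        // dense.dtype(), dense.dim(), dense.lod(), ... are now available.
        (void)dense.offset();
      }
    }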