未验证 提交 3a3ff942 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Type system stage4: Add some built-in types and type conversion methods (#51112)

* add builtin-type DenseTensorType Float16Type Float64Type Int16Type Int64Type

* refine comment

* refine comment

* add classof for Type class

* refine test code

* add get param func for DenseTensorType

* add dyn_cast and refine isa

* set default WITH_NEWIR=OFF

* refine cast_utils

* Refine code by comment

* refine code by comment

* refine code by comment

* refine code by comment

* fix bug of dyn_cast

* set WITH_NEWIR=OFF

* refine code by comment
上级 0e900864
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/builtin_type.h"
namespace ir {
const ir::Type& DenseTensorType::dtype() const { return storage()->dtype_; }
const ir::DenseTensorTypeStorage::Dim& DenseTensorType::dim() const {
return storage()->dims_;
}
const ir::DenseTensorTypeStorage::DataLayout& DenseTensorType::data_layout()
const {
return storage()->layout_;
}
const ir::DenseTensorTypeStorage::LoD& DenseTensorType::lod() const {
return storage()->lod_;
}
const size_t& DenseTensorType::offset() const { return storage()->offset_; }
} // namespace ir
...@@ -14,22 +14,44 @@ ...@@ -14,22 +14,44 @@
#pragma once #pragma once
#include "paddle/ir/builtin_type_storage.h"
#include "paddle/ir/type.h" #include "paddle/ir/type.h"
namespace ir { namespace ir {
/// ///
/// \brief This macro is used to get a list of all built-in types in this file. /// \brief This macro is used to get a list of all built-in types in this file.
/// The built-in Dialect will use this macro to quickly register all built-in
/// types.
/// ///
#define GET_BUILT_IN_TYPE_LIST ir::Float32Type, ir::Int32Type #define GET_BUILT_IN_TYPE_LIST \
ir::Float16Type, ir::Float32Type, ir::Float64Type, ir::Int16Type, \
ir::Int32Type, ir::Int64Type, ir::DenseTensorType
/// ///
/// \brief Definitions of built-in type classes. The built-in type object get /// \brief Define built-in parameterless types. Please add the necessary
/// method is as follows: /// interface functions for built-in types through the macro
/// DECLARE_TYPE_UTILITY_FUNCTOR.
///
/// NOTE(zhangbo9674): If you need to directly
/// cache the object of this built-in type in IrContext, please overload the get
/// method, and construct and cache the object in IrContext. For the specific
/// implementation method, please refer to Float16Type.
///
/// The built-in type object get method is as follows:
/// \code{cpp} /// \code{cpp}
/// ir::IrContext *ctx = ir::IrContext::Instance(); /// ir::IrContext *ctx = ir::IrContext::Instance();
/// Type fp32 = Float32Type::get(ctx); /// Type fp32 = Float32Type::get(ctx);
/// \endcode /// \endcode
/// ///
class Float16Type : public ir::Type {
public:
using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(Float16Type, ir::TypeStorage);
static Float16Type get(ir::IrContext *context);
};
class Float32Type : public ir::Type { class Float32Type : public ir::Type {
public: public:
using Type::Type; using Type::Type;
...@@ -39,6 +61,24 @@ class Float32Type : public ir::Type { ...@@ -39,6 +61,24 @@ class Float32Type : public ir::Type {
static Float32Type get(ir::IrContext *context); static Float32Type get(ir::IrContext *context);
}; };
class Float64Type : public ir::Type {
public:
using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(Float64Type, ir::TypeStorage);
static Float64Type get(ir::IrContext *context);
};
class Int16Type : public ir::Type {
public:
using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(Int16Type, ir::TypeStorage);
static Int16Type get(ir::IrContext *context);
};
class Int32Type : public ir::Type { class Int32Type : public ir::Type {
public: public:
using Type::Type; using Type::Type;
...@@ -48,4 +88,33 @@ class Int32Type : public ir::Type { ...@@ -48,4 +88,33 @@ class Int32Type : public ir::Type {
static Int32Type get(ir::IrContext *context); static Int32Type get(ir::IrContext *context);
}; };
class Int64Type : public ir::Type {
public:
using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(Int64Type, ir::TypeStorage);
static Int64Type get(ir::IrContext *context);
};
///
/// \brief Define built-in parameteric types.
///
class DenseTensorType : public ir::Type {
public:
using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(DenseTensorType, DenseTensorTypeStorage);
const ir::Type &dtype() const;
const ir::DenseTensorTypeStorage::Dim &dim() const;
const ir::DenseTensorTypeStorage::DataLayout &data_layout() const;
const ir::DenseTensorTypeStorage::LoD &lod() const;
const size_t &offset() const;
};
} // namespace ir } // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
#include "paddle/ir/type.h"
namespace std {
///
/// \brief Enable hashing std::vector<T> instances.
///
template <typename T>
struct hash<std::vector<T>> {
std::size_t operator()(const std::vector<T> &dim) const {
std::size_t seed = 0;
for (size_t i = 0; i < dim.size(); ++i) {
seed ^= std::hash<T>()(dim[i]) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
};
} // namespace std
namespace ir {
///
/// \brief Define Parameteric TypeStorage for DenseTensorType.
///
/// NOTE(zhangbo9674): The derived TypeStorage class needs to implement the
/// following methods: (1)declare ParamKey, (2)define Construction method,
/// (3)define HashValue method, (4)overload operator==.
///
struct DenseTensorTypeStorage : public ir::TypeStorage {
///
/// \brief It is consistent with the DataLayout defined by Phi operator
/// library. See the file for details: paddle/phi/common/layout.h.
///
enum class DataLayout : unsigned int {
UNDEFINED = 0,
NHWC,
NCHW,
NCDHW,
NDHWC,
ONEDNN,
SPARSE_COO,
SPARSE_CSR,
PSTRING_UNION,
NUM_DATA_LAYOUTS,
// See Note [ Why we need ALL in basic kernel key member? ]
ALL_LAYOUT = UNDEFINED,
// Note: Unify phi DataLayout and fluid::framework::DataLayout,
// for compatible with fluid DataLayout, here need prefix `k`
kNHWC = NHWC,
kNCHW = NCHW,
kMKLDNN = ONEDNN, // all layouts supported by ONEDNN internally
kNDHWC = NDHWC,
kNCDHW = NCDHW,
};
using Dim = std::vector<int64_t>;
using LoD = std::vector<std::vector<size_t>>;
///
/// \brief Declare ParamKey according to parameter type.
///
using ParamKey = std::tuple<ir::Type, Dim, DataLayout, LoD, size_t>;
DenseTensorTypeStorage(
ir::Type dtype, Dim dims, DataLayout layout, LoD lod, size_t offset)
: dtype_(dtype),
dims_(dims),
layout_(layout),
lod_(lod),
offset_(offset) {}
///
/// \brief Each derived TypeStorage must define a Construc method, which
/// StorageManager uses to construct a derived TypeStorage.
///
static DenseTensorTypeStorage *Construct(ParamKey key) {
return new DenseTensorTypeStorage(std::get<0>(key),
std::get<1>(key),
std::get<2>(key),
std::get<3>(key),
std::get<4>(key));
}
///
/// \brief Each derived TypeStorage must provide a HashValue method.
///
static std::size_t HashValue(const ParamKey &key) {
std::size_t hash_value = 0;
// hash dtype
hash_value =
hash_combine(hash_value, std::hash<ir::Type>()(std::get<0>(key)));
// hash dims
hash_value = hash_combine(hash_value, std::hash<Dim>()(std::get<1>(key)));
// hash layout
hash_value =
hash_combine(hash_value,
std::hash<std::underlying_type<DataLayout>::type>()(
static_cast<std::underlying_type<DataLayout>::type>(
std::get<2>(key))));
// hash lod
hash_value = hash_combine(hash_value, std::hash<LoD>()(std::get<3>(key)));
// hash offset
hash_value =
hash_combine(hash_value, std::hash<size_t>()(std::get<4>(key)));
return hash_value;
}
///
/// \brief Each derived TypeStorage needs to overload operator==.
///
bool operator==(const ParamKey &key) const {
return ParamKey(dtype_, dims_, layout_, lod_, offset_) == key;
}
ParamKey GetAsKey() const {
return ParamKey(dtype_, dims_, layout_, lod_, offset_);
}
///
/// \brief DenseTensorTypeStorage include five parameters: dims, dtype,
/// layout, lod, offset.
///
ir::Type dtype_;
Dim dims_;
DataLayout layout_;
LoD lod_;
size_t offset_;
private:
static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) {
return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
}
};
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
namespace ir {
///
/// \brief The template function actually called by isa_wrap.
///
template <typename Target, typename From, typename Enabler = void>
struct isa_impl {
static inline bool call(const From &Val) { return Target::classof(Val); }
};
template <typename Target, typename From>
struct isa_impl<
Target,
From,
typename std::enable_if<std::is_base_of<Target, From>::value>::type> {
static inline bool call(const From &) { return true; }
};
///
/// \brief The template function actually called by isa.
///
template <typename Target, typename From, typename Enable = void>
struct isa_wrap {
static inline bool call(const From &Val) {
return isa_impl<Target, From>::call(Val);
}
};
///
/// \brief typequalified specialization of the isa_wrap template parameter From.
/// Specialized types include: const T, T*, const T*, T* const, const T* const.
///
template <typename Target, typename From>
struct isa_wrap<Target, const From> {
static inline bool call(const From &Val) {
return isa_impl<Target, From>::call(Val);
}
};
template <typename Target, typename From>
struct isa_wrap<
Target,
From,
typename std::enable_if_t<std::is_pointer<std::decay_t<From>>::value>> {
static inline bool call(
std::remove_pointer_t<std::decay_t<From>> const *Val) {
if (Val == nullptr) {
throw("isa<> used on a null pointer");
}
return isa_impl<Target, std::remove_pointer_t<std::decay_t<From>>>::call(
*Val);
}
};
///
/// \brief isa template function, used to determine whether the value is a
/// Target type. Using method: if (isa<Target_Type>(value)) { ... }.
///
template <typename Target, typename From>
inline bool isa(const From &Val) {
return isa_wrap<typename std::remove_pointer<Target>::type, From>::call(Val);
}
///
/// \brief Derive cast return type by template parameter From and To.
///
template <typename To, typename From>
struct ReturnTypeDuductionWrap {
typedef To &type;
};
template <typename To, typename From>
struct ReturnTypeDuductionWrap<To, const From> {
typedef const To &type;
};
template <typename To, typename From>
struct ReturnTypeDuductionWrap<To, From *> {
typedef To *type;
};
template <typename To, typename From>
struct ReturnTypeDuductionWrap<To, const From *> {
typedef const To *type;
};
template <typename To, typename From>
struct ReturnTypeDuductionWrap<To, const From *const> {
typedef const To *type;
};
template <typename To, typename From>
struct ReturnTypeDuduction {
typedef typename ReturnTypeDuductionWrap<To, From>::type type;
};
///
/// cast From to To
///
template <typename To, typename From>
struct cast_impl {
// This _is_ a simple type, just cast it.
static typename ReturnTypeDuduction<To, From>::type call(const From &Val) {
typename ReturnTypeDuduction<To, From>::type ret =
(typename ReturnTypeDuduction<To, From>::type) const_cast<From &>(Val);
return ret;
}
};
template <typename To, typename From>
inline typename ReturnTypeDuduction<To, From>::type cast(From &Val) { // NOLINT
if (!isa<To>(Val)) {
throw("cast<To>() argument of incompatible type!");
}
return cast_impl<To, From>::call(Val);
}
template <typename To, typename From>
inline typename ReturnTypeDuduction<To, From *>::type cast(From *Val) {
if (!isa<To>(Val)) {
throw("cast<To>() argument of incompatible type!");
}
return cast_impl<To, From *>::call(Val);
}
///
/// \brief dyn_cast From to To.
///
template <typename To, typename From>
inline std::decay_t<typename ReturnTypeDuduction<To, From>::type> dyn_cast(
From &Val) { // NOLINT
return isa<To>(Val) ? cast<To>(Val) : nullptr;
}
template <typename To, typename From>
inline typename ReturnTypeDuduction<To, From *>::type dyn_cast(From *Val) {
return isa<To>(Val) ? cast<To>(Val) : nullptr;
}
} // namespace ir
...@@ -85,7 +85,6 @@ class IrContextImpl { ...@@ -85,7 +85,6 @@ class IrContextImpl {
// Cached AbstractType instances. // Cached AbstractType instances.
std::unordered_map<TypeId, AbstractType *> registed_abstract_types_; std::unordered_map<TypeId, AbstractType *> registed_abstract_types_;
ir::SpinLock registed_abstract_types_lock_; ir::SpinLock registed_abstract_types_lock_;
// TypeStorage uniquer and cache instances. // TypeStorage uniquer and cache instances.
...@@ -93,12 +92,15 @@ class IrContextImpl { ...@@ -93,12 +92,15 @@ class IrContextImpl {
// The dialcet registered in the context. // The dialcet registered in the context.
std::unordered_map<std::string, Dialect *> registed_dialect_; std::unordered_map<std::string, Dialect *> registed_dialect_;
ir::SpinLock registed_dialect_lock_; ir::SpinLock registed_dialect_lock_;
// Some built-in types. // Cache some built-in type objects.
Float16Type fp16_type;
Float32Type fp32_type; Float32Type fp32_type;
Float64Type fp64_type;
Int16Type int16_type;
Int32Type int32_type; Int32Type int32_type;
Int64Type int64_type;
ir::SpinLock destructor_lock_; ir::SpinLock destructor_lock_;
}; };
...@@ -113,8 +115,12 @@ IrContext::IrContext() : impl_(new IrContextImpl()) { ...@@ -113,8 +115,12 @@ IrContext::IrContext() : impl_(new IrContextImpl()) {
GetOrRegisterDialect<BuiltinDialect>(); GetOrRegisterDialect<BuiltinDialect>();
VLOG(4) << "=============================================="; VLOG(4) << "==============================================";
impl_->fp16_type = TypeManager::get<Float16Type>(this);
impl_->fp32_type = TypeManager::get<Float32Type>(this); impl_->fp32_type = TypeManager::get<Float32Type>(this);
impl_->fp64_type = TypeManager::get<Float64Type>(this);
impl_->int16_type = TypeManager::get<Int16Type>(this);
impl_->int32_type = TypeManager::get<Int32Type>(this); impl_->int32_type = TypeManager::get<Int32Type>(this);
impl_->int64_type = TypeManager::get<Int64Type>(this);
} }
void IrContext::RegisterAbstractType(ir::TypeId type_id, void IrContext::RegisterAbstractType(ir::TypeId type_id,
...@@ -173,8 +179,16 @@ const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) { ...@@ -173,8 +179,16 @@ const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
} }
} }
Float16Type Float16Type::get(IrContext *ctx) { return ctx->impl().fp16_type; }
Float32Type Float32Type::get(IrContext *ctx) { return ctx->impl().fp32_type; } Float32Type Float32Type::get(IrContext *ctx) { return ctx->impl().fp32_type; }
Float64Type Float64Type::get(IrContext *ctx) { return ctx->impl().fp64_type; }
Int16Type Int16Type::get(IrContext *ctx) { return ctx->impl().int16_type; }
Int32Type Int32Type::get(IrContext *ctx) { return ctx->impl().int32_type; } Int32Type Int32Type::get(IrContext *ctx) { return ctx->impl().int32_type; }
Int64Type Int64Type::get(IrContext *ctx) { return ctx->impl().int64_type; }
} // namespace ir } // namespace ir
...@@ -70,29 +70,99 @@ TEST(type_test, type_base) { ...@@ -70,29 +70,99 @@ TEST(type_test, type_base) {
} }
TEST(type_test, built_in_type) { TEST(type_test, built_in_type) {
// Test 1: Test the built-in type of IrContext. // Test the interfaces of class Type: judgment, type_id, abstract_type,
// classof.
ir::IrContext *ctx = ir::IrContext::Instance(); ir::IrContext *ctx = ir::IrContext::Instance();
ir::Type fp32_1 = ir::Float32Type::get(ctx);
// Test 2: Test the interfaces of class Type: judgment, type_id, // Test 1: Test the parameterless built-in type of IrContext.
// abstract_type, classof. ir::Type fp16_1 = ir::Float16Type::get(ctx);
ir::Type fp16_2 = ir::Float16Type::get(ctx);
EXPECT_EQ(fp16_1, fp16_2);
EXPECT_EQ(fp16_1.type_id(), fp16_2.type_id());
EXPECT_EQ(&fp16_1.abstract_type(),
&ir::AbstractType::lookup(fp16_1.type_id(), ctx));
EXPECT_EQ(ir::Float16Type::classof(fp16_1), 1);
ir::Type fp32_1 = ir::Float32Type::get(ctx);
ir::Type fp32_2 = ir::Float32Type::get(ctx); ir::Type fp32_2 = ir::Float32Type::get(ctx);
EXPECT_EQ(fp32_1 == fp32_2, 1); EXPECT_EQ(fp32_1, fp32_2);
EXPECT_EQ(fp32_1 != fp32_2, 0); EXPECT_EQ(fp32_1.type_id(), fp32_2.type_id());
EXPECT_EQ(fp32_1.type_id() == fp32_2.type_id(), 1); EXPECT_EQ(&fp32_1.abstract_type(),
EXPECT_EQ(&fp32_1.abstract_type() == &ir::AbstractType::lookup(fp32_1.type_id(), ctx));
&ir::AbstractType::lookup(fp32_1.type_id(), ctx),
1);
EXPECT_EQ(ir::Float32Type::classof(fp32_1), 1); EXPECT_EQ(ir::Float32Type::classof(fp32_1), 1);
ir::Type fp64_1 = ir::Float64Type::get(ctx);
ir::Type fp64_2 = ir::Float64Type::get(ctx);
EXPECT_EQ(fp64_1, fp64_2);
EXPECT_EQ(fp64_1.type_id(), fp64_2.type_id());
EXPECT_EQ(&fp64_1.abstract_type(),
&ir::AbstractType::lookup(fp64_1.type_id(), ctx));
EXPECT_EQ(ir::Float64Type::classof(fp64_1), 1);
ir::Type int16_1 = ir::Int16Type::get(ctx);
ir::Type int16_2 = ir::Int16Type::get(ctx);
EXPECT_EQ(int16_1, int16_2);
EXPECT_EQ(int16_1.type_id(), int16_2.type_id());
EXPECT_EQ(&int16_1.abstract_type(),
&ir::AbstractType::lookup(int16_1.type_id(), ctx));
EXPECT_EQ(ir::Int16Type::classof(int16_1), 1);
ir::Type int32_1 = ir::Int32Type::get(ctx); ir::Type int32_1 = ir::Int32Type::get(ctx);
ir::Type int32_2 = ir::Int32Type::get(ctx); ir::Type int32_2 = ir::Int32Type::get(ctx);
EXPECT_EQ(int32_1 == int32_2, 1); EXPECT_EQ(int32_1, int32_2);
EXPECT_EQ(int32_1.type_id() == int32_2.type_id(), 1); EXPECT_EQ(int32_1.type_id(), int32_2.type_id());
EXPECT_EQ(&int32_1.abstract_type() == EXPECT_EQ(&int32_1.abstract_type(),
&ir::AbstractType::lookup(int32_1.type_id(), ctx), &ir::AbstractType::lookup(int32_1.type_id(), ctx));
1);
EXPECT_EQ(ir::Int32Type::classof(int32_1), 1); EXPECT_EQ(ir::Int32Type::classof(int32_1), 1);
ir::Type int64_1 = ir::Int64Type::get(ctx);
ir::Type int64_2 = ir::Int64Type::get(ctx);
EXPECT_EQ(int64_1, int64_2);
EXPECT_EQ(int64_1.type_id(), int64_2.type_id());
EXPECT_EQ(&int64_1.abstract_type(),
&ir::AbstractType::lookup(int64_1.type_id(), ctx));
EXPECT_EQ(ir::Int64Type::classof(int64_1), 1);
// Test 2: Test the parameteric built-in type of IrContext.
ir::DenseTensorTypeStorage::Dim dims = {1, 2, 3};
ir::DenseTensorTypeStorage::DataLayout data_layout =
ir::DenseTensorTypeStorage::DataLayout::NCHW;
ir::DenseTensorTypeStorage::LoD lod = {{1, 2, 3}, {4, 5, 6}};
size_t offset = 0;
ir::Type dense_tensor_1 =
ir::DenseTensorType::get(ctx, fp32_1, dims, data_layout, lod, offset);
ir::Type dense_tensor_2 =
ir::DenseTensorType::get(ctx, fp32_2, dims, data_layout, lod, offset);
ir::Type dense_tensor_3 =
ir::DenseTensorType::get(ctx, fp32_1, dims, data_layout, lod, 2);
EXPECT_EQ(dense_tensor_1, dense_tensor_2);
EXPECT_NE(dense_tensor_1, dense_tensor_3);
EXPECT_EQ(dense_tensor_1.type_id(), dense_tensor_2.type_id());
EXPECT_EQ(ir::DenseTensorType::classof(dense_tensor_1), 1);
ir::DenseTensorType dense_tensor_4 =
ir::DenseTensorType::get(ctx, fp32_1, dims, data_layout, lod, 2);
EXPECT_EQ(dense_tensor_4.offset() == 2, 1);
EXPECT_EQ(dense_tensor_4.dtype().isa<ir::Float32Type>(), true);
EXPECT_EQ(dense_tensor_4.data_layout(), data_layout);
// Test 3: Test isa and dyn_cast.
EXPECT_EQ(fp16_1.isa<ir::Float16Type>(), true);
EXPECT_EQ(fp16_1.isa<ir::Float32Type>(), false);
EXPECT_EQ(fp16_1.isa<ir::DenseTensorType>(), false);
EXPECT_EQ(fp16_1.isa<ir::Type>(), true);
EXPECT_EQ(dense_tensor_1.isa<ir::DenseTensorType>(), true);
ir::DenseTensorType dense_tensor_cast_1 =
dense_tensor_1.dyn_cast<ir::DenseTensorType>();
EXPECT_EQ(dense_tensor_cast_1.isa<ir::DenseTensorType>(), true);
EXPECT_EQ(dense_tensor_cast_1.offset() == 0, 1);
const ir::DenseTensorType dense_tensor_cast_2 =
ir::dyn_cast<ir::DenseTensorType>(dense_tensor_1);
EXPECT_EQ(dense_tensor_cast_2.isa<ir::DenseTensorType>(), true);
EXPECT_EQ(dense_tensor_cast_2.offset() == 0, 1);
} }
// Customize a parameterized TypeStorage IntegerTypeStorage. // Customize a parameterized TypeStorage IntegerTypeStorage.
...@@ -150,15 +220,15 @@ TEST(type_test, custom_type_dialect) { ...@@ -150,15 +220,15 @@ TEST(type_test, custom_type_dialect) {
ir::Type int1_1 = IntegerType::get(ctx, 1, 0); ir::Type int1_1 = IntegerType::get(ctx, 1, 0);
ir::Type int1_2 = IntegerType::get(ctx, 1, 0); ir::Type int1_2 = IntegerType::get(ctx, 1, 0);
EXPECT_EQ(int1_1 == int1_2, 1); EXPECT_EQ(int1_1, int1_2);
ir::Type int8 = IntegerType::get(ctx, 8, 0); ir::Type int8 = IntegerType::get(ctx, 8, 0);
EXPECT_EQ(int8 == int1_2, 0); EXPECT_NE(int8, int1_2);
// Test 2: Test Dialect interfaces // Test 2: Test Dialect interfaces
EXPECT_EQ(ctx == int8.ir_context(), 1); EXPECT_EQ(ctx, int8.ir_context());
EXPECT_EQ(int8.dialect().id() == ir::TypeId::get<IntegerDialect>(), 1); EXPECT_EQ(int8.dialect().id(), ir::TypeId::get<IntegerDialect>());
std::vector<ir::Dialect *> dialect_list = ctx->GetRegisteredDialects(); std::vector<ir::Dialect *> dialect_list = ctx->GetRegisteredDialects();
EXPECT_EQ(dialect_list.size() == 3, 1); // integer, builtin, fake EXPECT_EQ(dialect_list.size() == 3, 1); // integer, builtin, fake
...@@ -166,9 +236,9 @@ TEST(type_test, custom_type_dialect) { ...@@ -166,9 +236,9 @@ TEST(type_test, custom_type_dialect) {
ir::Dialect *dialect_builtin1 = ctx->GetRegisteredDialect("builtin"); ir::Dialect *dialect_builtin1 = ctx->GetRegisteredDialect("builtin");
ir::Dialect *dialect_builtin2 = ir::Dialect *dialect_builtin2 =
ctx->GetRegisteredDialect<ir::BuiltinDialect>(); ctx->GetRegisteredDialect<ir::BuiltinDialect>();
EXPECT_EQ(dialect_builtin1 == dialect_builtin2, 1); EXPECT_EQ(dialect_builtin1, dialect_builtin2);
ir::Dialect *dialect_integer1 = ctx->GetRegisteredDialect("integer"); ir::Dialect *dialect_integer1 = ctx->GetRegisteredDialect("integer");
ir::Dialect *dialect_integer2 = ctx->GetRegisteredDialect<IntegerDialect>(); ir::Dialect *dialect_integer2 = ctx->GetRegisteredDialect<IntegerDialect>();
EXPECT_EQ(dialect_integer1 == dialect_integer2, 1); EXPECT_EQ(dialect_integer1, dialect_integer2);
} }
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include "paddle/ir/cast_utils.h"
#include "paddle/ir/type_base.h" #include "paddle/ir/type_base.h"
namespace ir { namespace ir {
...@@ -37,15 +38,19 @@ class Type { ...@@ -37,15 +38,19 @@ class Type {
Type &operator=(const Type &other) = default; Type &operator=(const Type &other) = default;
/// ///
/// \brief Comparison operations. /// \brief Some operators are overloaded.
/// ///
bool operator==(Type other) const { return storage_ == other.storage_; } bool operator==(Type other) const { return storage_ == other.storage_; }
bool operator!=(Type other) const { return storage_ != other.storage_; } bool operator!=(Type other) const { return storage_ != other.storage_; }
explicit operator bool() const { return storage_; } explicit operator bool() const { return storage_; }
bool operator!() const { return storage_ == nullptr; } bool operator!() const { return storage_ == nullptr; }
///
/// \brief Some type attribute acquisition interfaces.
///
TypeId type_id() { return storage_->abstract_type().type_id(); } TypeId type_id() { return storage_->abstract_type().type_id(); }
const AbstractType &abstract_type() { return storage_->abstract_type(); } const AbstractType &abstract_type() { return storage_->abstract_type(); }
...@@ -56,6 +61,21 @@ class Type { ...@@ -56,6 +61,21 @@ class Type {
IrContext *ir_context() const; IrContext *ir_context() const;
///
/// \brief Methods for type judgment and cast.
///
static bool classof(Type) { return true; }
template <typename T>
bool isa() const {
return ir::isa<T>(*this);
}
template <typename U>
U dyn_cast() const {
return ir::dyn_cast<U>(*this);
}
/// ///
/// \brief Enable hashing Type. /// \brief Enable hashing Type.
/// ///
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册