From 34e905744eaf4d0fa20592066ae66030bbb61a75 Mon Sep 17 00:00:00 2001 From: Bo Zhang <105368690+zhangbopd@users.noreply.github.com> Date: Fri, 11 Aug 2023 16:40:46 +0800 Subject: [PATCH] add indextype (#56112) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit IR 的 builtin dialect 中加入 IndexType --- paddle/cinn/utils/attribute_util.h | 1 + paddle/fluid/ir/dialect/utils.h | 4 ++++ paddle/ir/core/builder.cc | 1 + paddle/ir/core/builder.h | 2 ++ paddle/ir/core/builtin_dialect.cc | 1 + paddle/ir/core/builtin_type.cc | 1 + paddle/ir/core/builtin_type.h | 2 ++ paddle/ir/core/ir_context.cc | 4 ++++ paddle/ir/core/ir_printer.cc | 2 ++ paddle/ir/core/type.h | 21 +++++++++++++++++++++ test/cpp/ir/core/ir_builder_test.cc | 1 + test/cpp/ir/core/ir_type_converter_test.cc | 20 ++++++++++++++++++++ test/cpp/ir/core/type_test.cc | 8 ++++++++ 13 files changed, 68 insertions(+) diff --git a/paddle/cinn/utils/attribute_util.h b/paddle/cinn/utils/attribute_util.h index b5caaf667b9..02e7fc1bf28 100644 --- a/paddle/cinn/utils/attribute_util.h +++ b/paddle/cinn/utils/attribute_util.h @@ -87,6 +87,7 @@ common::Type ConvertIRType(::ir::Type type) { CASE_TYPE(Int16Type, I16) CASE_TYPE(Int32Type, I32) CASE_TYPE(Int64Type, I64) + CASE_TYPE(IndexType, I32) CASE_TYPE(BoolType, UI1) LOG(FATAL) << "unknown ir::Type " << type; diff --git a/paddle/fluid/ir/dialect/utils.h b/paddle/fluid/ir/dialect/utils.h index a81febc0cba..13a9f3d7ac8 100644 --- a/paddle/fluid/ir/dialect/utils.h +++ b/paddle/fluid/ir/dialect/utils.h @@ -66,6 +66,8 @@ static inline phi::DataType TransToPhiDataType(ir::Type dtype) { return phi::DataType::INT32; } else if (dtype.isa()) { return phi::DataType::INT64; + } else if (dtype.isa()) { + return phi::DataType::INT32; } else if (dtype.isa()) { return phi::DataType::BOOL; } else if (dtype.isa()) { @@ -79,6 +81,8 @@ static inline phi::DataType TransToPhiDataType(ir::Type dtype) { } } +// use phi::DataType::INT32 for IndexType from builtin type to phi::DataType, +// but only use INT32 not IndexType from phi::DataType type to builtin type. static inline ir::Type TransToIrDataType(phi::DataType dtype, ir::IrContext* ctx = nullptr) { if (ctx == nullptr) { diff --git a/paddle/ir/core/builder.cc b/paddle/ir/core/builder.cc index 954b46b08f8..1bfbd2e2a8c 100644 --- a/paddle/ir/core/builder.cc +++ b/paddle/ir/core/builder.cc @@ -49,6 +49,7 @@ BFloat16Type Builder::bfloat16_type() { return BFloat16Type::get(context_); } Float32Type Builder::float32_type() { return Float32Type::get(context_); } Float64Type Builder::float64_type() { return Float64Type::get(context_); } +IndexType Builder::index_type() { return IndexType::get(context_); } Int16Type Builder::int16_type() { return Int16Type::get(context_); } BoolType Builder::bool_type() { return BoolType::get(context_); } Complex64Type Builder::complex64_type() { return Complex64Type::get(context_); } diff --git a/paddle/ir/core/builder.h b/paddle/ir/core/builder.h index 74856cdaf7c..f3ae837ea97 100644 --- a/paddle/ir/core/builder.h +++ b/paddle/ir/core/builder.h @@ -29,6 +29,7 @@ class BFloat16Type; class Float32Type; class Float64Type; class Int16Type; +class IndexType; class BoolType; class Complex64Type; class Complex128Type; @@ -114,6 +115,7 @@ class Builder { IR_API Int8Type int8_type(); IR_API VectorType vec_type(const std::vector &); IR_API BFloat16Type bfloat16_type(); + IR_API IndexType index_type(); IR_API Float32Type float32_type(); IR_API Float64Type float64_type(); IR_API Int16Type int16_type(); diff --git a/paddle/ir/core/builtin_dialect.cc b/paddle/ir/core/builtin_dialect.cc index a5e9605c283..3284a96c8b5 100644 --- a/paddle/ir/core/builtin_dialect.cc +++ b/paddle/ir/core/builtin_dialect.cc @@ -34,6 +34,7 @@ void BuiltinDialect::initialize() { Int16Type, Int32Type, Int64Type, + IndexType, BoolType, Complex64Type, Complex128Type, diff --git a/paddle/ir/core/builtin_type.cc b/paddle/ir/core/builtin_type.cc index 3a8e1030fb0..8a0aea5745a 100644 --- a/paddle/ir/core/builtin_type.cc +++ b/paddle/ir/core/builtin_type.cc @@ -29,6 +29,7 @@ IR_DEFINE_EXPLICIT_TYPE_ID(ir::Float64Type) IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int16Type) IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int32Type) IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int64Type) +IR_DEFINE_EXPLICIT_TYPE_ID(ir::IndexType) IR_DEFINE_EXPLICIT_TYPE_ID(ir::BoolType) IR_DEFINE_EXPLICIT_TYPE_ID(ir::Complex64Type) IR_DEFINE_EXPLICIT_TYPE_ID(ir::Complex128Type) diff --git a/paddle/ir/core/builtin_type.h b/paddle/ir/core/builtin_type.h index aa043f206d2..9a2939110de 100644 --- a/paddle/ir/core/builtin_type.h +++ b/paddle/ir/core/builtin_type.h @@ -73,6 +73,7 @@ class IR_API VectorType : public Type { __macro(Int16Type); \ __macro(Int32Type); \ __macro(Int64Type); \ + __macro(IndexType); \ __macro(BoolType); \ __macro(Complex64Type); \ __macro(Complex128Type); @@ -95,5 +96,6 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int16Type) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int32Type) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int64Type) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::BoolType) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::IndexType) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Complex64Type) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Complex128Type) diff --git a/paddle/ir/core/ir_context.cc b/paddle/ir/core/ir_context.cc index 5c609f183c4..54865e1ec38 100644 --- a/paddle/ir/core/ir_context.cc +++ b/paddle/ir/core/ir_context.cc @@ -156,6 +156,7 @@ class IrContextImpl { Float16Type fp16_type; Float32Type fp32_type; Float64Type fp64_type; + IndexType index_type; UInt8Type uint8_type; Int8Type int8_type; Int16Type int16_type; @@ -203,6 +204,7 @@ IrContext::IrContext() : impl_(new IrContextImpl()) { impl_->int16_type = TypeManager::get(this); impl_->int32_type = TypeManager::get(this); impl_->int64_type = TypeManager::get(this); + impl_->index_type = TypeManager::get(this); impl_->bool_type = TypeManager::get(this); impl_->complex64_type = TypeManager::get(this); impl_->complex128_type = TypeManager::get(this); @@ -343,6 +345,8 @@ Int32Type Int32Type::get(IrContext *ctx) { return ctx->impl().int32_type; } Int64Type Int64Type::get(IrContext *ctx) { return ctx->impl().int64_type; } +IndexType IndexType::get(IrContext *ctx) { return ctx->impl().index_type; } + Int8Type Int8Type::get(IrContext *ctx) { return ctx->impl().int8_type; } UInt8Type UInt8Type::get(IrContext *ctx) { return ctx->impl().uint8_type; } diff --git a/paddle/ir/core/ir_printer.cc b/paddle/ir/core/ir_printer.cc index f8549433f75..080e0bafc96 100644 --- a/paddle/ir/core/ir_printer.cc +++ b/paddle/ir/core/ir_printer.cc @@ -59,6 +59,8 @@ void BasicIrPrinter::PrintType(Type type) { os << "i32"; } else if (type.isa()) { os << "i64"; + } else if (type.isa()) { + os << "index"; } else if (type.isa()) { os << "c64"; } else if (type.isa()) { diff --git a/paddle/ir/core/type.h b/paddle/ir/core/type.h index 62dcefdf3ba..df148f17a23 100644 --- a/paddle/ir/core/type.h +++ b/paddle/ir/core/type.h @@ -97,6 +97,27 @@ IR_API std::ostream &operator<<(std::ostream &os, Type type); } // namespace ir +/// +/// \brief This class represents the base of a type interface. +/// + +// template +// class TypeInterface : public ir::DialectInterface { +// public: +// using Base = TypeInterface; +// using DialectInterfaceBase = ir::DialectInterface; +// using DialectInterfaceBase::Base; + +// private: +// /// Returns the impl interface instance for the given type. +// static typename InterfaceBase::Concept *getInterfaceFor(Type type) { +// return type.getAbstractType().getInterface(); +// } + +// /// Allow access to 'getInterfaceFor'. +// friend InterfaceBase; +// }; + namespace std { /// /// \brief Enable hashing Type. diff --git a/test/cpp/ir/core/ir_builder_test.cc b/test/cpp/ir/core/ir_builder_test.cc index 3b70220a8d3..863bac72da9 100644 --- a/test/cpp/ir/core/ir_builder_test.cc +++ b/test/cpp/ir/core/ir_builder_test.cc @@ -31,6 +31,7 @@ TEST(builder_test, type_api) { EXPECT_EQ(ir::BFloat16Type::get(&ctx), builder.bfloat16_type()); EXPECT_EQ(ir::Float32Type::get(&ctx), builder.float32_type()); EXPECT_EQ(ir::Float64Type::get(&ctx), builder.float64_type()); + EXPECT_EQ(ir::IndexType::get(&ctx), builder.index_type()); EXPECT_EQ(ir::Int16Type::get(&ctx), builder.int16_type()); EXPECT_EQ(ir::BoolType::get(&ctx), builder.bool_type()); EXPECT_EQ(ir::Complex64Type::get(&ctx), builder.complex64_type()); diff --git a/test/cpp/ir/core/ir_type_converter_test.cc b/test/cpp/ir/core/ir_type_converter_test.cc index 896c1059dc6..26f4cde5891 100644 --- a/test/cpp/ir/core/ir_type_converter_test.cc +++ b/test/cpp/ir/core/ir_type_converter_test.cc @@ -65,3 +65,23 @@ TEST(TypeConverterTest, paramterless_type) { ir::Complex64Type, ir::Complex128Type>(); } + +void test_index_type() { + ir::IrContext* ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + + ir::Type type = ir::IndexType::get(ctx); + std::stringstream ss; + ss << type; + EXPECT_GT(ss.str().size(), 0u); + EXPECT_EQ(ss.str(), "index"); + EXPECT_NE(ss.str(), "<>"); + phi::DataType phi_type = paddle::dialect::TransToPhiDataType(type); + auto& type_translator = paddle::translator::TypeTranslator::instance(); + paddle::framework::VarDesc empty_var_desc("empty"); + auto proto_type = paddle::framework::TransToProtoVarType(phi_type); + ir::Type final_type = type_translator[proto_type](ctx, empty_var_desc); + EXPECT_EQ(paddle::dialect::TransToIrDataType(phi_type), final_type); +} + +TEST(IndexTypeConverterTest, index_type) { test_index_type(); } diff --git a/test/cpp/ir/core/type_test.cc b/test/cpp/ir/core/type_test.cc index a748e1d5db8..24bf92446c2 100644 --- a/test/cpp/ir/core/type_test.cc +++ b/test/cpp/ir/core/type_test.cc @@ -89,6 +89,14 @@ TEST(type_test, built_in_type) { &ir::AbstractType::lookup(bfp16_1.type_id(), ctx)); EXPECT_EQ(ir::BFloat16Type::classof(bfp16_1), 1); + ir::Type index_1 = ir::IndexType::get(ctx); + ir::Type index_2 = ir::IndexType::get(ctx); + EXPECT_EQ(index_1, index_2); + EXPECT_EQ(index_1.type_id(), index_2.type_id()); + EXPECT_EQ(&index_1.abstract_type(), + &ir::AbstractType::lookup(index_1.type_id(), ctx)); + EXPECT_EQ(ir::IndexType::classof(index_1), 1); + ir::Type fp16_1 = ir::Float16Type::get(ctx); ir::Type fp16_2 = ir::Float16Type::get(ctx); EXPECT_EQ(fp16_1, fp16_2); -- GitLab