diff --git a/paddle/cinn/utils/attribute_util.h b/paddle/cinn/utils/attribute_util.h index b5caaf667b9584fa18a429218e113ba27c9df4ed..02e7fc1bf28235f22718137257077b2007e98bc4 100644 --- a/paddle/cinn/utils/attribute_util.h +++ b/paddle/cinn/utils/attribute_util.h @@ -87,6 +87,7 @@ common::Type ConvertIRType(::ir::Type type) { CASE_TYPE(Int16Type, I16) CASE_TYPE(Int32Type, I32) CASE_TYPE(Int64Type, I64) + CASE_TYPE(IndexType, I32) CASE_TYPE(BoolType, UI1) LOG(FATAL) << "unknown ir::Type " << type; diff --git a/paddle/fluid/ir/dialect/utils.h b/paddle/fluid/ir/dialect/utils.h index a81febc0cbab99a3c193c1002768eccd864d80f9..13a9f3d7ac8b824953fb3d3d57c1958a00f34097 100644 --- a/paddle/fluid/ir/dialect/utils.h +++ b/paddle/fluid/ir/dialect/utils.h @@ -66,6 +66,8 @@ static inline phi::DataType TransToPhiDataType(ir::Type dtype) { return phi::DataType::INT32; } else if (dtype.isa()) { return phi::DataType::INT64; + } else if (dtype.isa()) { + return phi::DataType::INT32; } else if (dtype.isa()) { return phi::DataType::BOOL; } else if (dtype.isa()) { @@ -79,6 +81,8 @@ static inline phi::DataType TransToPhiDataType(ir::Type dtype) { } } +// use phi::DataType::INT32 for IndexType from builtin type to phi::DataType, +// but only use INT32 not IndexType from phi::DataType type to builtin type. static inline ir::Type TransToIrDataType(phi::DataType dtype, ir::IrContext* ctx = nullptr) { if (ctx == nullptr) { diff --git a/paddle/ir/core/builder.cc b/paddle/ir/core/builder.cc index 954b46b08f897f66abdbe19bcb5b37b66fbe97b8..1bfbd2e2a8ca8bb6e46fa3c63d4876519d65c3c2 100644 --- a/paddle/ir/core/builder.cc +++ b/paddle/ir/core/builder.cc @@ -49,6 +49,7 @@ BFloat16Type Builder::bfloat16_type() { return BFloat16Type::get(context_); } Float32Type Builder::float32_type() { return Float32Type::get(context_); } Float64Type Builder::float64_type() { return Float64Type::get(context_); } +IndexType Builder::index_type() { return IndexType::get(context_); } Int16Type Builder::int16_type() { return Int16Type::get(context_); } BoolType Builder::bool_type() { return BoolType::get(context_); } Complex64Type Builder::complex64_type() { return Complex64Type::get(context_); } diff --git a/paddle/ir/core/builder.h b/paddle/ir/core/builder.h index 74856cdaf7c0ca4c5a43a9a3cd983c4843ec355d..f3ae837ea9723bc208856728fd9afe74eb4d827d 100644 --- a/paddle/ir/core/builder.h +++ b/paddle/ir/core/builder.h @@ -29,6 +29,7 @@ class BFloat16Type; class Float32Type; class Float64Type; class Int16Type; +class IndexType; class BoolType; class Complex64Type; class Complex128Type; @@ -114,6 +115,7 @@ class Builder { IR_API Int8Type int8_type(); IR_API VectorType vec_type(const std::vector &); IR_API BFloat16Type bfloat16_type(); + IR_API IndexType index_type(); IR_API Float32Type float32_type(); IR_API Float64Type float64_type(); IR_API Int16Type int16_type(); diff --git a/paddle/ir/core/builtin_dialect.cc b/paddle/ir/core/builtin_dialect.cc index a5e9605c2835e67126d1d94c06aa7d8f961f3c95..3284a96c8b5193eb831927fbffb14836e24009bf 100644 --- a/paddle/ir/core/builtin_dialect.cc +++ b/paddle/ir/core/builtin_dialect.cc @@ -34,6 +34,7 @@ void BuiltinDialect::initialize() { Int16Type, Int32Type, Int64Type, + IndexType, BoolType, Complex64Type, Complex128Type, diff --git a/paddle/ir/core/builtin_type.cc b/paddle/ir/core/builtin_type.cc index 3a8e1030fb07f258a90e5117a35d00cae0a523a1..8a0aea5745a5b2a3535fac5bbdd80d8a07adbd07 100644 --- a/paddle/ir/core/builtin_type.cc +++ b/paddle/ir/core/builtin_type.cc @@ -29,6 +29,7 @@ IR_DEFINE_EXPLICIT_TYPE_ID(ir::Float64Type) IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int16Type) IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int32Type) IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int64Type) +IR_DEFINE_EXPLICIT_TYPE_ID(ir::IndexType) IR_DEFINE_EXPLICIT_TYPE_ID(ir::BoolType) IR_DEFINE_EXPLICIT_TYPE_ID(ir::Complex64Type) IR_DEFINE_EXPLICIT_TYPE_ID(ir::Complex128Type) diff --git a/paddle/ir/core/builtin_type.h b/paddle/ir/core/builtin_type.h index aa043f206d22e1540bc22d6234581abe84285323..9a2939110deaca170db1cb6618e3d2e2ab9327ab 100644 --- a/paddle/ir/core/builtin_type.h +++ b/paddle/ir/core/builtin_type.h @@ -73,6 +73,7 @@ class IR_API VectorType : public Type { __macro(Int16Type); \ __macro(Int32Type); \ __macro(Int64Type); \ + __macro(IndexType); \ __macro(BoolType); \ __macro(Complex64Type); \ __macro(Complex128Type); @@ -95,5 +96,6 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int16Type) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int32Type) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int64Type) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::BoolType) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::IndexType) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Complex64Type) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Complex128Type) diff --git a/paddle/ir/core/ir_context.cc b/paddle/ir/core/ir_context.cc index 5c609f183c40d7ac1303b1eff11a1b8da57166d2..54865e1ec38bdc2e2c7c1ff8b8640f1508d63776 100644 --- a/paddle/ir/core/ir_context.cc +++ b/paddle/ir/core/ir_context.cc @@ -156,6 +156,7 @@ class IrContextImpl { Float16Type fp16_type; Float32Type fp32_type; Float64Type fp64_type; + IndexType index_type; UInt8Type uint8_type; Int8Type int8_type; Int16Type int16_type; @@ -203,6 +204,7 @@ IrContext::IrContext() : impl_(new IrContextImpl()) { impl_->int16_type = TypeManager::get(this); impl_->int32_type = TypeManager::get(this); impl_->int64_type = TypeManager::get(this); + impl_->index_type = TypeManager::get(this); impl_->bool_type = TypeManager::get(this); impl_->complex64_type = TypeManager::get(this); impl_->complex128_type = TypeManager::get(this); @@ -343,6 +345,8 @@ Int32Type Int32Type::get(IrContext *ctx) { return ctx->impl().int32_type; } Int64Type Int64Type::get(IrContext *ctx) { return ctx->impl().int64_type; } +IndexType IndexType::get(IrContext *ctx) { return ctx->impl().index_type; } + Int8Type Int8Type::get(IrContext *ctx) { return ctx->impl().int8_type; } UInt8Type UInt8Type::get(IrContext *ctx) { return ctx->impl().uint8_type; } diff --git a/paddle/ir/core/ir_printer.cc b/paddle/ir/core/ir_printer.cc index f8549433f75c70c47c01101e508b3c1e9e54f71e..080e0bafc966a7f4d157661d5083cdffd1d51bed 100644 --- a/paddle/ir/core/ir_printer.cc +++ b/paddle/ir/core/ir_printer.cc @@ -59,6 +59,8 @@ void BasicIrPrinter::PrintType(Type type) { os << "i32"; } else if (type.isa()) { os << "i64"; + } else if (type.isa()) { + os << "index"; } else if (type.isa()) { os << "c64"; } else if (type.isa()) { diff --git a/paddle/ir/core/type.h b/paddle/ir/core/type.h index 62dcefdf3ba65182272edf2987aa1ef2ba55a257..df148f17a2350661888edee8e3c7f16517a23e1f 100644 --- a/paddle/ir/core/type.h +++ b/paddle/ir/core/type.h @@ -97,6 +97,27 @@ IR_API std::ostream &operator<<(std::ostream &os, Type type); } // namespace ir +/// +/// \brief This class represents the base of a type interface. +/// + +// template +// class TypeInterface : public ir::DialectInterface { +// public: +// using Base = TypeInterface; +// using DialectInterfaceBase = ir::DialectInterface; +// using DialectInterfaceBase::Base; + +// private: +// /// Returns the impl interface instance for the given type. +// static typename InterfaceBase::Concept *getInterfaceFor(Type type) { +// return type.getAbstractType().getInterface(); +// } + +// /// Allow access to 'getInterfaceFor'. +// friend InterfaceBase; +// }; + namespace std { /// /// \brief Enable hashing Type. diff --git a/test/cpp/ir/core/ir_builder_test.cc b/test/cpp/ir/core/ir_builder_test.cc index 3b70220a8d309656924cf8b2165cf9615db9c742..863bac72da9c2600051c3dfb41b17c3689517058 100644 --- a/test/cpp/ir/core/ir_builder_test.cc +++ b/test/cpp/ir/core/ir_builder_test.cc @@ -31,6 +31,7 @@ TEST(builder_test, type_api) { EXPECT_EQ(ir::BFloat16Type::get(&ctx), builder.bfloat16_type()); EXPECT_EQ(ir::Float32Type::get(&ctx), builder.float32_type()); EXPECT_EQ(ir::Float64Type::get(&ctx), builder.float64_type()); + EXPECT_EQ(ir::IndexType::get(&ctx), builder.index_type()); EXPECT_EQ(ir::Int16Type::get(&ctx), builder.int16_type()); EXPECT_EQ(ir::BoolType::get(&ctx), builder.bool_type()); EXPECT_EQ(ir::Complex64Type::get(&ctx), builder.complex64_type()); diff --git a/test/cpp/ir/core/ir_type_converter_test.cc b/test/cpp/ir/core/ir_type_converter_test.cc index 896c1059dc664496467671f3520a54ca1f2e318a..26f4cde5891719d983d7e24b061154661a1a9a7e 100644 --- a/test/cpp/ir/core/ir_type_converter_test.cc +++ b/test/cpp/ir/core/ir_type_converter_test.cc @@ -65,3 +65,23 @@ TEST(TypeConverterTest, paramterless_type) { ir::Complex64Type, ir::Complex128Type>(); } + +void test_index_type() { + ir::IrContext* ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + + ir::Type type = ir::IndexType::get(ctx); + std::stringstream ss; + ss << type; + EXPECT_GT(ss.str().size(), 0u); + EXPECT_EQ(ss.str(), "index"); + EXPECT_NE(ss.str(), "<>"); + phi::DataType phi_type = paddle::dialect::TransToPhiDataType(type); + auto& type_translator = paddle::translator::TypeTranslator::instance(); + paddle::framework::VarDesc empty_var_desc("empty"); + auto proto_type = paddle::framework::TransToProtoVarType(phi_type); + ir::Type final_type = type_translator[proto_type](ctx, empty_var_desc); + EXPECT_EQ(paddle::dialect::TransToIrDataType(phi_type), final_type); +} + +TEST(IndexTypeConverterTest, index_type) { test_index_type(); } diff --git a/test/cpp/ir/core/type_test.cc b/test/cpp/ir/core/type_test.cc index a748e1d5db88b8bb618208e4f038de260aba57e4..24bf92446c2a043a54b41bce7c1431dc57b90d0e 100644 --- a/test/cpp/ir/core/type_test.cc +++ b/test/cpp/ir/core/type_test.cc @@ -89,6 +89,14 @@ TEST(type_test, built_in_type) { &ir::AbstractType::lookup(bfp16_1.type_id(), ctx)); EXPECT_EQ(ir::BFloat16Type::classof(bfp16_1), 1); + ir::Type index_1 = ir::IndexType::get(ctx); + ir::Type index_2 = ir::IndexType::get(ctx); + EXPECT_EQ(index_1, index_2); + EXPECT_EQ(index_1.type_id(), index_2.type_id()); + EXPECT_EQ(&index_1.abstract_type(), + &ir::AbstractType::lookup(index_1.type_id(), ctx)); + EXPECT_EQ(ir::IndexType::classof(index_1), 1); + ir::Type fp16_1 = ir::Float16Type::get(ctx); ir::Type fp16_2 = ir::Float16Type::get(ctx); EXPECT_EQ(fp16_1, fp16_2);