未验证 提交 34e90574 编写于 作者: B Bo Zhang 提交者: GitHub

add indextype (#56112)

IR 的 builtin dialect 中加入 IndexType
上级 74eb3093
......@@ -87,6 +87,7 @@ common::Type ConvertIRType(::ir::Type type) {
CASE_TYPE(Int16Type, I16)
CASE_TYPE(Int32Type, I32)
CASE_TYPE(Int64Type, I64)
CASE_TYPE(IndexType, I32)
CASE_TYPE(BoolType, UI1)
LOG(FATAL) << "unknown ir::Type " << type;
......
......@@ -66,6 +66,8 @@ static inline phi::DataType TransToPhiDataType(ir::Type dtype) {
return phi::DataType::INT32;
} else if (dtype.isa<ir::Int64Type>()) {
return phi::DataType::INT64;
} else if (dtype.isa<ir::IndexType>()) {
return phi::DataType::INT32;
} else if (dtype.isa<ir::BoolType>()) {
return phi::DataType::BOOL;
} else if (dtype.isa<ir::Complex64Type>()) {
......@@ -79,6 +81,8 @@ static inline phi::DataType TransToPhiDataType(ir::Type dtype) {
}
}
// use phi::DataType::INT32 for IndexType from builtin type to phi::DataType,
// but only use INT32 not IndexType from phi::DataType type to builtin type.
static inline ir::Type TransToIrDataType(phi::DataType dtype,
ir::IrContext* ctx = nullptr) {
if (ctx == nullptr) {
......
......@@ -49,6 +49,7 @@ BFloat16Type Builder::bfloat16_type() { return BFloat16Type::get(context_); }
Float32Type Builder::float32_type() { return Float32Type::get(context_); }
Float64Type Builder::float64_type() { return Float64Type::get(context_); }
IndexType Builder::index_type() { return IndexType::get(context_); }
Int16Type Builder::int16_type() { return Int16Type::get(context_); }
BoolType Builder::bool_type() { return BoolType::get(context_); }
Complex64Type Builder::complex64_type() { return Complex64Type::get(context_); }
......
......@@ -29,6 +29,7 @@ class BFloat16Type;
class Float32Type;
class Float64Type;
class Int16Type;
class IndexType;
class BoolType;
class Complex64Type;
class Complex128Type;
......@@ -114,6 +115,7 @@ class Builder {
IR_API Int8Type int8_type();
IR_API VectorType vec_type(const std::vector<Type> &);
IR_API BFloat16Type bfloat16_type();
IR_API IndexType index_type();
IR_API Float32Type float32_type();
IR_API Float64Type float64_type();
IR_API Int16Type int16_type();
......
......@@ -34,6 +34,7 @@ void BuiltinDialect::initialize() {
Int16Type,
Int32Type,
Int64Type,
IndexType,
BoolType,
Complex64Type,
Complex128Type,
......
......@@ -29,6 +29,7 @@ IR_DEFINE_EXPLICIT_TYPE_ID(ir::Float64Type)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int16Type)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int32Type)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int64Type)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::IndexType)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::BoolType)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Complex64Type)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Complex128Type)
......@@ -73,6 +73,7 @@ class IR_API VectorType : public Type {
__macro(Int16Type); \
__macro(Int32Type); \
__macro(Int64Type); \
__macro(IndexType); \
__macro(BoolType); \
__macro(Complex64Type); \
__macro(Complex128Type);
......@@ -95,5 +96,6 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int16Type)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int32Type)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int64Type)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::BoolType)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::IndexType)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Complex64Type)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Complex128Type)
......@@ -156,6 +156,7 @@ class IrContextImpl {
Float16Type fp16_type;
Float32Type fp32_type;
Float64Type fp64_type;
IndexType index_type;
UInt8Type uint8_type;
Int8Type int8_type;
Int16Type int16_type;
......@@ -203,6 +204,7 @@ IrContext::IrContext() : impl_(new IrContextImpl()) {
impl_->int16_type = TypeManager::get<Int16Type>(this);
impl_->int32_type = TypeManager::get<Int32Type>(this);
impl_->int64_type = TypeManager::get<Int64Type>(this);
impl_->index_type = TypeManager::get<IndexType>(this);
impl_->bool_type = TypeManager::get<BoolType>(this);
impl_->complex64_type = TypeManager::get<Complex64Type>(this);
impl_->complex128_type = TypeManager::get<Complex128Type>(this);
......@@ -343,6 +345,8 @@ Int32Type Int32Type::get(IrContext *ctx) { return ctx->impl().int32_type; }
Int64Type Int64Type::get(IrContext *ctx) { return ctx->impl().int64_type; }
IndexType IndexType::get(IrContext *ctx) { return ctx->impl().index_type; }
Int8Type Int8Type::get(IrContext *ctx) { return ctx->impl().int8_type; }
UInt8Type UInt8Type::get(IrContext *ctx) { return ctx->impl().uint8_type; }
......
......@@ -59,6 +59,8 @@ void BasicIrPrinter::PrintType(Type type) {
os << "i32";
} else if (type.isa<Int64Type>()) {
os << "i64";
} else if (type.isa<IndexType>()) {
os << "index";
} else if (type.isa<Complex64Type>()) {
os << "c64";
} else if (type.isa<Complex128Type>()) {
......
......@@ -97,6 +97,27 @@ IR_API std::ostream &operator<<(std::ostream &os, Type type);
} // namespace ir
///
/// \brief This class represents the base of a type interface.
///
// template <typename ConcreteType>
// class TypeInterface : public ir::DialectInterface<ConcreteType, Type> {
// public:
// using Base = TypeInterface<ConcreteType>;
// using DialectInterfaceBase = ir::DialectInterface<ConcreteType, Type>;
// using DialectInterfaceBase::Base;
// private:
// /// Returns the impl interface instance for the given type.
// static typename InterfaceBase::Concept *getInterfaceFor(Type type) {
// return type.getAbstractType().getInterface<ConcreteType>();
// }
// /// Allow access to 'getInterfaceFor'.
// friend InterfaceBase;
// };
namespace std {
///
/// \brief Enable hashing Type.
......
......@@ -31,6 +31,7 @@ TEST(builder_test, type_api) {
EXPECT_EQ(ir::BFloat16Type::get(&ctx), builder.bfloat16_type());
EXPECT_EQ(ir::Float32Type::get(&ctx), builder.float32_type());
EXPECT_EQ(ir::Float64Type::get(&ctx), builder.float64_type());
EXPECT_EQ(ir::IndexType::get(&ctx), builder.index_type());
EXPECT_EQ(ir::Int16Type::get(&ctx), builder.int16_type());
EXPECT_EQ(ir::BoolType::get(&ctx), builder.bool_type());
EXPECT_EQ(ir::Complex64Type::get(&ctx), builder.complex64_type());
......
......@@ -65,3 +65,23 @@ TEST(TypeConverterTest, paramterless_type) {
ir::Complex64Type,
ir::Complex128Type>();
}
void test_index_type() {
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
ir::Type type = ir::IndexType::get(ctx);
std::stringstream ss;
ss << type;
EXPECT_GT(ss.str().size(), 0u);
EXPECT_EQ(ss.str(), "index");
EXPECT_NE(ss.str(), "<<NULL TYPE>>");
phi::DataType phi_type = paddle::dialect::TransToPhiDataType(type);
auto& type_translator = paddle::translator::TypeTranslator::instance();
paddle::framework::VarDesc empty_var_desc("empty");
auto proto_type = paddle::framework::TransToProtoVarType(phi_type);
ir::Type final_type = type_translator[proto_type](ctx, empty_var_desc);
EXPECT_EQ(paddle::dialect::TransToIrDataType(phi_type), final_type);
}
TEST(IndexTypeConverterTest, index_type) { test_index_type(); }
......@@ -89,6 +89,14 @@ TEST(type_test, built_in_type) {
&ir::AbstractType::lookup(bfp16_1.type_id(), ctx));
EXPECT_EQ(ir::BFloat16Type::classof(bfp16_1), 1);
ir::Type index_1 = ir::IndexType::get(ctx);
ir::Type index_2 = ir::IndexType::get(ctx);
EXPECT_EQ(index_1, index_2);
EXPECT_EQ(index_1.type_id(), index_2.type_id());
EXPECT_EQ(&index_1.abstract_type(),
&ir::AbstractType::lookup(index_1.type_id(), ctx));
EXPECT_EQ(ir::IndexType::classof(index_1), 1);
ir::Type fp16_1 = ir::Float16Type::get(ctx);
ir::Type fp16_2 = ir::Float16Type::get(ctx);
EXPECT_EQ(fp16_1, fp16_2);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册