diff --git a/paddle/ir/core/builtin_type.h b/paddle/ir/core/builtin_type.h index 6973b9a92d34cb1d2b270bd999514fbb5b1caad0..1e8208708305f3e0751729ff0f5346197ba79f87 100644 --- a/paddle/ir/core/builtin_type.h +++ b/paddle/ir/core/builtin_type.h @@ -35,16 +35,8 @@ namespace ir { /// \endcode /// -// NOTE(dev): Currently BF16 and Int8 are not considered as a cached member -// in IrContextImpl because they are not widely used. -class BFloat16Type : public Type { - public: - using Type::Type; - - DECLARE_TYPE_UTILITY_FUNCTOR(BFloat16Type, TypeStorage); - - static BFloat16Type get(IrContext *context); -}; +// NOTE(dev): Currently Int8 are not considered as a cached member +// in IrContextImpl because it is not widely used. class Int8Type : public Type { public: @@ -79,6 +71,7 @@ class VectorType : public Type { }; #define FOREACH_BUILTIN_TYPE(__macro) \ + __macro(BFloat16); \ __macro(Float16); \ __macro(Float32); \ __macro(Float64); \ diff --git a/test/cpp/ir/core/type_test.cc b/test/cpp/ir/core/type_test.cc index 8a7d4a9039d4c40b2bc11eb5d3d264b89ffc8774..e3c9bcd20dd63659bd4b194ed3b325845ce5f37c 100644 --- a/test/cpp/ir/core/type_test.cc +++ b/test/cpp/ir/core/type_test.cc @@ -77,6 +77,14 @@ TEST(type_test, built_in_type) { ir::IrContext *ctx = ir::IrContext::Instance(); // Test 1: Test the parameterless built-in type of IrContext. + ir::Type bfp16_1 = ir::BFloat16Type::get(ctx); + ir::Type bfp16_2 = ir::BFloat16Type::get(ctx); + EXPECT_EQ(bfp16_1, bfp16_2); + EXPECT_EQ(bfp16_1.type_id(), bfp16_2.type_id()); + EXPECT_EQ(&bfp16_1.abstract_type(), + &ir::AbstractType::lookup(bfp16_1.type_id(), ctx)); + EXPECT_EQ(ir::BFloat16Type::classof(bfp16_1), 1); + ir::Type fp16_1 = ir::Float16Type::get(ctx); ir::Type fp16_2 = ir::Float16Type::get(ctx); EXPECT_EQ(fp16_1, fp16_2);