From ffc1b027d1b546c08260186685780a2cbfbdb240 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 28 Jun 2023 10:51:55 +0800 Subject: [PATCH] [IR] complement ir type (#54911) * complement ir type * fix ir_printer --- paddle/fluid/ir/dialect/utils.h | 26 ++++++- .../ir_adaptor/translator/op_translator.cc | 2 + .../ir_adaptor/translator/type_translator.cc | 38 ++++++++++- paddle/ir/core/builtin_dialect.cc | 3 + paddle/ir/core/builtin_type.cc | 3 + paddle/ir/core/builtin_type.h | 16 ++--- paddle/ir/core/ir_context.cc | 24 +++++++ paddle/ir/core/ir_printer.cc | 14 +++- test/cpp/ir/core/CMakeLists.txt | 10 +++ test/cpp/ir/core/ir_type_converter_test.cc | 67 +++++++++++++++++++ 10 files changed, 192 insertions(+), 11 deletions(-) create mode 100644 test/cpp/ir/core/ir_type_converter_test.cc diff --git a/paddle/fluid/ir/dialect/utils.h b/paddle/fluid/ir/dialect/utils.h index bf666ad01b6..0cdf4ef4962 100644 --- a/paddle/fluid/ir/dialect/utils.h +++ b/paddle/fluid/ir/dialect/utils.h @@ -26,18 +26,30 @@ namespace dialect { // TODO(zhangbo): The builtin type needs to cover all data types of // phi::DataType. static inline phi::DataType TransToPhiDataType(ir::Type dtype) { - if (dtype.isa()) { + if (dtype.isa()) { + return phi::DataType::BFLOAT16; + } else if (dtype.isa()) { return phi::DataType::FLOAT16; } else if (dtype.isa()) { return phi::DataType::FLOAT32; } else if (dtype.isa()) { return phi::DataType::FLOAT64; + } else if (dtype.isa()) { + return phi::DataType::UINT8; + } else if (dtype.isa()) { + return phi::DataType::INT8; } else if (dtype.isa()) { return phi::DataType::INT16; } else if (dtype.isa()) { return phi::DataType::INT32; } else if (dtype.isa()) { return phi::DataType::INT64; + } else if (dtype.isa()) { + return phi::DataType::BOOL; + } else if (dtype.isa()) { + return phi::DataType::COMPLEX64; + } else if (dtype.isa()) { + return phi::DataType::COMPLEX128; } else { PADDLE_THROW(phi::errors::Unimplemented( "Unsupported ir data type when casting it into " @@ -51,18 +63,30 @@ static inline ir::Type TransToIrDataType(phi::DataType dtype, ctx = ir::IrContext::Instance(); } switch (dtype) { + case phi::DataType::BFLOAT16: + return ir::BFloat16Type::get(ctx); case phi::DataType::FLOAT16: return ir::Float16Type::get(ctx); case phi::DataType::FLOAT32: return ir::Float32Type::get(ctx); case phi::DataType::FLOAT64: return ir::Float64Type::get(ctx); + case phi::DataType::UINT8: + return ir::UInt8Type::get(ctx); + case phi::DataType::INT8: + return ir::Int8Type::get(ctx); case phi::DataType::INT16: return ir::Int16Type::get(ctx); case phi::DataType::INT32: return ir::Int32Type::get(ctx); case phi::DataType::INT64: return ir::Int64Type::get(ctx); + case phi::DataType::BOOL: + return ir::BoolType::get(ctx); + case phi::DataType::COMPLEX64: + return ir::Complex64Type::get(ctx); + case phi::DataType::COMPLEX128: + return ir::Complex128Type::get(ctx); default: PADDLE_THROW(phi::errors::Unimplemented( "Unsupported phi data type `%s` when casting it into " diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 58c27c89ebc..a94abc9a81f 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -327,6 +327,8 @@ inline std::vector GenerateOperationInput( } bool is_vector = (info.type_name.find("VectorType") != std::string::npos); + is_vector |= + (info.type_name.find("IntArrayAttribute") != std::string::npos); VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " " << is_vector << " " << info.type_name; diff --git a/paddle/fluid/ir_adaptor/translator/type_translator.cc b/paddle/fluid/ir_adaptor/translator/type_translator.cc index 7e57216533a..231eeefbe0c 100644 --- a/paddle/fluid/ir_adaptor/translator/type_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/type_translator.cc @@ -31,10 +31,34 @@ using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage; TypeTranslator::TypeTranslator() { handlers = { + {VarType::BOOL, + [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { + return ir::BoolType::get(ctx); + }}, + {VarType::UINT8, + [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { + return ir::UInt8Type::get(ctx); + }}, + {VarType::INT8, + [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { + return ir::Int8Type::get(ctx); + }}, + {VarType::INT16, + [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { + return ir::Int16Type::get(ctx); + }}, + {VarType::INT32, + [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { + return ir::Int32Type::get(ctx); + }}, {VarType::INT64, [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { return ir::Int64Type::get(ctx); }}, + {VarType::FP16, + [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { + return ir::Float16Type::get(ctx); + }}, {VarType::FP32, [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { return ir::Float32Type::get(ctx); @@ -43,10 +67,22 @@ TypeTranslator::TypeTranslator() { [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { return ir::Float64Type::get(ctx); }}, + {VarType::BF16, + [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { + return ir::BFloat16Type::get(ctx); + }}, + {VarType::COMPLEX64, + [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { + return ir::Complex64Type::get(ctx); + }}, + {VarType::COMPLEX128, + [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { + return ir::Complex128Type::get(ctx); + }}, {VarType::LOD_TENSOR, [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { VLOG(10) << "[vartype translating]" - << "[" << var_desc.Name() << "]" << var_desc.GetDataType(); + << "[" << var_desc.Name() << "] from LOD_TENSOR"; ir::Type dtype = this->operator[](var_desc.GetDataType())(ctx, var_desc); diff --git a/paddle/ir/core/builtin_dialect.cc b/paddle/ir/core/builtin_dialect.cc index 2766be29f91..2dc4438564b 100644 --- a/paddle/ir/core/builtin_dialect.cc +++ b/paddle/ir/core/builtin_dialect.cc @@ -30,10 +30,13 @@ void BuiltinDialect::initialize() { Float32Type, Float64Type, Int8Type, + UInt8Type, Int16Type, Int32Type, Int64Type, BoolType, + Complex64Type, + Complex128Type, VectorType>(); RegisterAttributes VectorType::data() const { return storage()->GetAsKey(); } } // namespace ir +IR_DEFINE_EXPLICIT_TYPE_ID(ir::UInt8Type) IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int8Type) IR_DEFINE_EXPLICIT_TYPE_ID(ir::VectorType) IR_DEFINE_EXPLICIT_TYPE_ID(ir::BFloat16Type) @@ -29,3 +30,5 @@ IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int16Type) IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int32Type) IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int64Type) IR_DEFINE_EXPLICIT_TYPE_ID(ir::BoolType) +IR_DEFINE_EXPLICIT_TYPE_ID(ir::Complex64Type) +IR_DEFINE_EXPLICIT_TYPE_ID(ir::Complex128Type) diff --git a/paddle/ir/core/builtin_type.h b/paddle/ir/core/builtin_type.h index ed09254f510..aa043f206d2 100644 --- a/paddle/ir/core/builtin_type.h +++ b/paddle/ir/core/builtin_type.h @@ -38,13 +38,6 @@ namespace ir { // NOTE(dev): Currently Int8 are not considered as a cached member // in IrContextImpl because it is not widely used. -class IR_API Int8Type : public Type { - public: - using Type::Type; - - DECLARE_TYPE_UTILITY_FUNCTOR(Int8Type, TypeStorage); -}; - class IR_API VectorType : public Type { public: using Type::Type; @@ -75,10 +68,14 @@ class IR_API VectorType : public Type { __macro(Float16Type); \ __macro(Float32Type); \ __macro(Float64Type); \ + __macro(Int8Type); \ + __macro(UInt8Type); \ __macro(Int16Type); \ __macro(Int32Type); \ __macro(Int64Type); \ - __macro(BoolType); + __macro(BoolType); \ + __macro(Complex64Type); \ + __macro(Complex128Type); FOREACH_BUILTIN_TYPE(DECLARE_BUILTIN_TYPE) @@ -87,6 +84,7 @@ FOREACH_BUILTIN_TYPE(DECLARE_BUILTIN_TYPE) } // namespace ir +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::UInt8Type) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int8Type) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::VectorType) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::BFloat16Type) @@ -97,3 +95,5 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int16Type) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int32Type) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int64Type) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::BoolType) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Complex64Type) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Complex128Type) diff --git a/paddle/ir/core/ir_context.cc b/paddle/ir/core/ir_context.cc index 583eb0a19e1..6f4399ca8dc 100644 --- a/paddle/ir/core/ir_context.cc +++ b/paddle/ir/core/ir_context.cc @@ -156,9 +156,14 @@ class IrContextImpl { Float16Type fp16_type; Float32Type fp32_type; Float64Type fp64_type; + UInt8Type uint8_type; + Int8Type int8_type; Int16Type int16_type; Int32Type int32_type; Int64Type int64_type; + BoolType bool_type; + Complex64Type complex64_type; + Complex128Type complex128_type; // Cached AbstractAttribute instances. std::unordered_map registed_abstract_attributes_; @@ -193,9 +198,14 @@ IrContext::IrContext() : impl_(new IrContextImpl()) { impl_->fp16_type = TypeManager::get(this); impl_->fp32_type = TypeManager::get(this); impl_->fp64_type = TypeManager::get(this); + impl_->uint8_type = TypeManager::get(this); + impl_->int8_type = TypeManager::get(this); impl_->int16_type = TypeManager::get(this); impl_->int32_type = TypeManager::get(this); impl_->int64_type = TypeManager::get(this); + impl_->bool_type = TypeManager::get(this); + impl_->complex64_type = TypeManager::get(this); + impl_->complex128_type = TypeManager::get(this); } StorageManager &IrContext::type_storage_manager() { @@ -336,4 +346,18 @@ Int32Type Int32Type::get(IrContext *ctx) { return ctx->impl().int32_type; } Int64Type Int64Type::get(IrContext *ctx) { return ctx->impl().int64_type; } +Int8Type Int8Type::get(IrContext *ctx) { return ctx->impl().int8_type; } + +UInt8Type UInt8Type::get(IrContext *ctx) { return ctx->impl().uint8_type; } + +BoolType BoolType::get(IrContext *ctx) { return ctx->impl().bool_type; } + +Complex64Type Complex64Type::get(IrContext *ctx) { + return ctx->impl().complex64_type; +} + +Complex128Type Complex128Type::get(IrContext *ctx) { + return ctx->impl().complex128_type; +} + } // namespace ir diff --git a/paddle/ir/core/ir_printer.cc b/paddle/ir/core/ir_printer.cc index bb7a0c9e825..5ddb7abc1b5 100644 --- a/paddle/ir/core/ir_printer.cc +++ b/paddle/ir/core/ir_printer.cc @@ -39,18 +39,30 @@ void BasicIrPrinter::PrintType(Type type) { return; } - if (type.isa()) { + if (type.isa()) { + os << "bf16"; + } else if (type.isa()) { os << "f16"; } else if (type.isa()) { os << "f32"; } else if (type.isa()) { os << "f64"; + } else if (type.isa()) { + os << "b"; + } else if (type.isa()) { + os << "i8"; + } else if (type.isa()) { + os << "u8"; } else if (type.isa()) { os << "i16"; } else if (type.isa()) { os << "i32"; } else if (type.isa()) { os << "i64"; + } else if (type.isa()) { + os << "c64"; + } else if (type.isa()) { + os << "c128"; } else if (type.isa()) { os << "vec["; auto inner_types = type.dyn_cast().data(); diff --git a/test/cpp/ir/core/CMakeLists.txt b/test/cpp/ir/core/CMakeLists.txt index 4987348bf82..4a85007a623 100644 --- a/test/cpp/ir/core/CMakeLists.txt +++ b/test/cpp/ir/core/CMakeLists.txt @@ -84,3 +84,13 @@ cc_test_old( pd_dialect pd_interface ir) + +cc_test_old( + ir_type_converter_test + SRCS + ir_type_converter_test.cc + DEPS + gtest + program_translator + pd_dialect + ir) diff --git a/test/cpp/ir/core/ir_type_converter_test.cc b/test/cpp/ir/core/ir_type_converter_test.cc new file mode 100644 index 00000000000..896c1059dc6 --- /dev/null +++ b/test/cpp/ir/core/ir_type_converter_test.cc @@ -0,0 +1,67 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/var_desc.h" +#include "paddle/fluid/ir/dialect/utils.h" +#include "paddle/fluid/ir_adaptor/translator/type_translator.h" +#include "paddle/ir/core/builtin_dialect.h" +#include "paddle/ir/core/builtin_type.h" +#include "paddle/ir/core/ir_context.h" +#include "paddle/ir/core/type.h" + +template +void test_parameterless_type() { + ir::IrContext* ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + + ir::Type type = IR_TYPE::get(ctx); + std::stringstream ss; + ss << type; + EXPECT_GT(ss.str().size(), 0u); + EXPECT_NE(ss.str(), "<>"); + phi::DataType phi_type = paddle::dialect::TransToPhiDataType(type); + EXPECT_EQ(type, paddle::dialect::TransToIrDataType(phi_type)); + + auto& type_translator = paddle::translator::TypeTranslator::instance(); + paddle::framework::VarDesc empty_var_desc("empty"); + auto proto_type = paddle::framework::TransToProtoVarType(phi_type); + ir::Type final_type = type_translator[proto_type](ctx, empty_var_desc); + EXPECT_EQ(type, final_type); +} + +template +void test_parameterless_type_helper() { + (void)std::initializer_list{0, + (test_parameterless_type(), 0)...}; +} + +TEST(TypeConverterTest, paramterless_type) { + test_parameterless_type_helper(); +} -- GitLab