未验证 提交 ffc1b027 编写于 作者: K kangguangli 提交者: GitHub

[IR] complement ir type (#54911)

* complement ir type

* fix ir_printer
上级 4588892a
...@@ -26,18 +26,30 @@ namespace dialect { ...@@ -26,18 +26,30 @@ namespace dialect {
// TODO(zhangbo): The builtin type needs to cover all data types of // TODO(zhangbo): The builtin type needs to cover all data types of
// phi::DataType. // phi::DataType.
static inline phi::DataType TransToPhiDataType(ir::Type dtype) { static inline phi::DataType TransToPhiDataType(ir::Type dtype) {
if (dtype.isa<ir::Float16Type>()) { if (dtype.isa<ir::BFloat16Type>()) {
return phi::DataType::BFLOAT16;
} else if (dtype.isa<ir::Float16Type>()) {
return phi::DataType::FLOAT16; return phi::DataType::FLOAT16;
} else if (dtype.isa<ir::Float32Type>()) { } else if (dtype.isa<ir::Float32Type>()) {
return phi::DataType::FLOAT32; return phi::DataType::FLOAT32;
} else if (dtype.isa<ir::Float64Type>()) { } else if (dtype.isa<ir::Float64Type>()) {
return phi::DataType::FLOAT64; return phi::DataType::FLOAT64;
} else if (dtype.isa<ir::UInt8Type>()) {
return phi::DataType::UINT8;
} else if (dtype.isa<ir::Int8Type>()) {
return phi::DataType::INT8;
} else if (dtype.isa<ir::Int16Type>()) { } else if (dtype.isa<ir::Int16Type>()) {
return phi::DataType::INT16; return phi::DataType::INT16;
} else if (dtype.isa<ir::Int32Type>()) { } else if (dtype.isa<ir::Int32Type>()) {
return phi::DataType::INT32; return phi::DataType::INT32;
} else if (dtype.isa<ir::Int64Type>()) { } else if (dtype.isa<ir::Int64Type>()) {
return phi::DataType::INT64; return phi::DataType::INT64;
} else if (dtype.isa<ir::BoolType>()) {
return phi::DataType::BOOL;
} else if (dtype.isa<ir::Complex64Type>()) {
return phi::DataType::COMPLEX64;
} else if (dtype.isa<ir::Complex128Type>()) {
return phi::DataType::COMPLEX128;
} else { } else {
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported ir data type when casting it into " "Unsupported ir data type when casting it into "
...@@ -51,18 +63,30 @@ static inline ir::Type TransToIrDataType(phi::DataType dtype, ...@@ -51,18 +63,30 @@ static inline ir::Type TransToIrDataType(phi::DataType dtype,
ctx = ir::IrContext::Instance(); ctx = ir::IrContext::Instance();
} }
switch (dtype) { switch (dtype) {
case phi::DataType::BFLOAT16:
return ir::BFloat16Type::get(ctx);
case phi::DataType::FLOAT16: case phi::DataType::FLOAT16:
return ir::Float16Type::get(ctx); return ir::Float16Type::get(ctx);
case phi::DataType::FLOAT32: case phi::DataType::FLOAT32:
return ir::Float32Type::get(ctx); return ir::Float32Type::get(ctx);
case phi::DataType::FLOAT64: case phi::DataType::FLOAT64:
return ir::Float64Type::get(ctx); return ir::Float64Type::get(ctx);
case phi::DataType::UINT8:
return ir::UInt8Type::get(ctx);
case phi::DataType::INT8:
return ir::Int8Type::get(ctx);
case phi::DataType::INT16: case phi::DataType::INT16:
return ir::Int16Type::get(ctx); return ir::Int16Type::get(ctx);
case phi::DataType::INT32: case phi::DataType::INT32:
return ir::Int32Type::get(ctx); return ir::Int32Type::get(ctx);
case phi::DataType::INT64: case phi::DataType::INT64:
return ir::Int64Type::get(ctx); return ir::Int64Type::get(ctx);
case phi::DataType::BOOL:
return ir::BoolType::get(ctx);
case phi::DataType::COMPLEX64:
return ir::Complex64Type::get(ctx);
case phi::DataType::COMPLEX128:
return ir::Complex128Type::get(ctx);
default: default:
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported phi data type `%s` when casting it into " "Unsupported phi data type `%s` when casting it into "
......
...@@ -327,6 +327,8 @@ inline std::vector<ir::OpResult> GenerateOperationInput( ...@@ -327,6 +327,8 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
} }
bool is_vector = (info.type_name.find("VectorType") != std::string::npos); bool is_vector = (info.type_name.find("VectorType") != std::string::npos);
is_vector |=
(info.type_name.find("IntArrayAttribute") != std::string::npos);
VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " " VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " "
<< is_vector << " " << info.type_name; << is_vector << " " << info.type_name;
......
...@@ -31,10 +31,34 @@ using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage; ...@@ -31,10 +31,34 @@ using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage;
TypeTranslator::TypeTranslator() { TypeTranslator::TypeTranslator() {
handlers = { handlers = {
{VarType::BOOL,
[&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
return ir::BoolType::get(ctx);
}},
{VarType::UINT8,
[&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
return ir::UInt8Type::get(ctx);
}},
{VarType::INT8,
[&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
return ir::Int8Type::get(ctx);
}},
{VarType::INT16,
[&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
return ir::Int16Type::get(ctx);
}},
{VarType::INT32,
[&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
return ir::Int32Type::get(ctx);
}},
{VarType::INT64, {VarType::INT64,
[&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
return ir::Int64Type::get(ctx); return ir::Int64Type::get(ctx);
}}, }},
{VarType::FP16,
[&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
return ir::Float16Type::get(ctx);
}},
{VarType::FP32, {VarType::FP32,
[&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
return ir::Float32Type::get(ctx); return ir::Float32Type::get(ctx);
...@@ -43,10 +67,22 @@ TypeTranslator::TypeTranslator() { ...@@ -43,10 +67,22 @@ TypeTranslator::TypeTranslator() {
[&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
return ir::Float64Type::get(ctx); return ir::Float64Type::get(ctx);
}}, }},
{VarType::BF16,
[&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
return ir::BFloat16Type::get(ctx);
}},
{VarType::COMPLEX64,
[&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
return ir::Complex64Type::get(ctx);
}},
{VarType::COMPLEX128,
[&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
return ir::Complex128Type::get(ctx);
}},
{VarType::LOD_TENSOR, {VarType::LOD_TENSOR,
[&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
VLOG(10) << "[vartype translating]" VLOG(10) << "[vartype translating]"
<< "[" << var_desc.Name() << "]" << var_desc.GetDataType(); << "[" << var_desc.Name() << "] from LOD_TENSOR";
ir::Type dtype = ir::Type dtype =
this->operator[](var_desc.GetDataType())(ctx, var_desc); this->operator[](var_desc.GetDataType())(ctx, var_desc);
......
...@@ -30,10 +30,13 @@ void BuiltinDialect::initialize() { ...@@ -30,10 +30,13 @@ void BuiltinDialect::initialize() {
Float32Type, Float32Type,
Float64Type, Float64Type,
Int8Type, Int8Type,
UInt8Type,
Int16Type, Int16Type,
Int32Type, Int32Type,
Int64Type, Int64Type,
BoolType, BoolType,
Complex64Type,
Complex128Type,
VectorType>(); VectorType>();
RegisterAttributes<StrAttribute, RegisterAttributes<StrAttribute,
......
...@@ -19,6 +19,7 @@ std::vector<Type> VectorType::data() const { return storage()->GetAsKey(); } ...@@ -19,6 +19,7 @@ std::vector<Type> VectorType::data() const { return storage()->GetAsKey(); }
} // namespace ir } // namespace ir
IR_DEFINE_EXPLICIT_TYPE_ID(ir::UInt8Type)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int8Type) IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int8Type)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::VectorType) IR_DEFINE_EXPLICIT_TYPE_ID(ir::VectorType)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::BFloat16Type) IR_DEFINE_EXPLICIT_TYPE_ID(ir::BFloat16Type)
...@@ -29,3 +30,5 @@ IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int16Type) ...@@ -29,3 +30,5 @@ IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int16Type)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int32Type) IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int32Type)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int64Type) IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int64Type)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::BoolType) IR_DEFINE_EXPLICIT_TYPE_ID(ir::BoolType)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Complex64Type)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Complex128Type)
...@@ -38,13 +38,6 @@ namespace ir { ...@@ -38,13 +38,6 @@ namespace ir {
// NOTE(dev): Currently Int8 are not considered as a cached member // NOTE(dev): Currently Int8 are not considered as a cached member
// in IrContextImpl because it is not widely used. // in IrContextImpl because it is not widely used.
class IR_API Int8Type : public Type {
public:
using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(Int8Type, TypeStorage);
};
class IR_API VectorType : public Type { class IR_API VectorType : public Type {
public: public:
using Type::Type; using Type::Type;
...@@ -75,10 +68,14 @@ class IR_API VectorType : public Type { ...@@ -75,10 +68,14 @@ class IR_API VectorType : public Type {
__macro(Float16Type); \ __macro(Float16Type); \
__macro(Float32Type); \ __macro(Float32Type); \
__macro(Float64Type); \ __macro(Float64Type); \
__macro(Int8Type); \
__macro(UInt8Type); \
__macro(Int16Type); \ __macro(Int16Type); \
__macro(Int32Type); \ __macro(Int32Type); \
__macro(Int64Type); \ __macro(Int64Type); \
__macro(BoolType); __macro(BoolType); \
__macro(Complex64Type); \
__macro(Complex128Type);
FOREACH_BUILTIN_TYPE(DECLARE_BUILTIN_TYPE) FOREACH_BUILTIN_TYPE(DECLARE_BUILTIN_TYPE)
...@@ -87,6 +84,7 @@ FOREACH_BUILTIN_TYPE(DECLARE_BUILTIN_TYPE) ...@@ -87,6 +84,7 @@ FOREACH_BUILTIN_TYPE(DECLARE_BUILTIN_TYPE)
} // namespace ir } // namespace ir
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::UInt8Type)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int8Type) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int8Type)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::VectorType) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::VectorType)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::BFloat16Type) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::BFloat16Type)
...@@ -97,3 +95,5 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int16Type) ...@@ -97,3 +95,5 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int16Type)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int32Type) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int32Type)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int64Type) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int64Type)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::BoolType) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::BoolType)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Complex64Type)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Complex128Type)
...@@ -156,9 +156,14 @@ class IrContextImpl { ...@@ -156,9 +156,14 @@ class IrContextImpl {
Float16Type fp16_type; Float16Type fp16_type;
Float32Type fp32_type; Float32Type fp32_type;
Float64Type fp64_type; Float64Type fp64_type;
UInt8Type uint8_type;
Int8Type int8_type;
Int16Type int16_type; Int16Type int16_type;
Int32Type int32_type; Int32Type int32_type;
Int64Type int64_type; Int64Type int64_type;
BoolType bool_type;
Complex64Type complex64_type;
Complex128Type complex128_type;
// Cached AbstractAttribute instances. // Cached AbstractAttribute instances.
std::unordered_map<TypeId, AbstractAttribute *> registed_abstract_attributes_; std::unordered_map<TypeId, AbstractAttribute *> registed_abstract_attributes_;
...@@ -193,9 +198,14 @@ IrContext::IrContext() : impl_(new IrContextImpl()) { ...@@ -193,9 +198,14 @@ IrContext::IrContext() : impl_(new IrContextImpl()) {
impl_->fp16_type = TypeManager::get<Float16Type>(this); impl_->fp16_type = TypeManager::get<Float16Type>(this);
impl_->fp32_type = TypeManager::get<Float32Type>(this); impl_->fp32_type = TypeManager::get<Float32Type>(this);
impl_->fp64_type = TypeManager::get<Float64Type>(this); impl_->fp64_type = TypeManager::get<Float64Type>(this);
impl_->uint8_type = TypeManager::get<UInt8Type>(this);
impl_->int8_type = TypeManager::get<Int8Type>(this);
impl_->int16_type = TypeManager::get<Int16Type>(this); impl_->int16_type = TypeManager::get<Int16Type>(this);
impl_->int32_type = TypeManager::get<Int32Type>(this); impl_->int32_type = TypeManager::get<Int32Type>(this);
impl_->int64_type = TypeManager::get<Int64Type>(this); impl_->int64_type = TypeManager::get<Int64Type>(this);
impl_->bool_type = TypeManager::get<BoolType>(this);
impl_->complex64_type = TypeManager::get<Complex64Type>(this);
impl_->complex128_type = TypeManager::get<Complex128Type>(this);
} }
StorageManager &IrContext::type_storage_manager() { StorageManager &IrContext::type_storage_manager() {
...@@ -336,4 +346,18 @@ Int32Type Int32Type::get(IrContext *ctx) { return ctx->impl().int32_type; } ...@@ -336,4 +346,18 @@ Int32Type Int32Type::get(IrContext *ctx) { return ctx->impl().int32_type; }
Int64Type Int64Type::get(IrContext *ctx) { return ctx->impl().int64_type; } Int64Type Int64Type::get(IrContext *ctx) { return ctx->impl().int64_type; }
Int8Type Int8Type::get(IrContext *ctx) { return ctx->impl().int8_type; }
UInt8Type UInt8Type::get(IrContext *ctx) { return ctx->impl().uint8_type; }
BoolType BoolType::get(IrContext *ctx) { return ctx->impl().bool_type; }
Complex64Type Complex64Type::get(IrContext *ctx) {
return ctx->impl().complex64_type;
}
Complex128Type Complex128Type::get(IrContext *ctx) {
return ctx->impl().complex128_type;
}
} // namespace ir } // namespace ir
...@@ -39,18 +39,30 @@ void BasicIrPrinter::PrintType(Type type) { ...@@ -39,18 +39,30 @@ void BasicIrPrinter::PrintType(Type type) {
return; return;
} }
if (type.isa<Float16Type>()) { if (type.isa<BFloat16Type>()) {
os << "bf16";
} else if (type.isa<Float16Type>()) {
os << "f16"; os << "f16";
} else if (type.isa<Float32Type>()) { } else if (type.isa<Float32Type>()) {
os << "f32"; os << "f32";
} else if (type.isa<Float64Type>()) { } else if (type.isa<Float64Type>()) {
os << "f64"; os << "f64";
} else if (type.isa<BoolType>()) {
os << "b";
} else if (type.isa<Int8Type>()) {
os << "i8";
} else if (type.isa<UInt8Type>()) {
os << "u8";
} else if (type.isa<Int16Type>()) { } else if (type.isa<Int16Type>()) {
os << "i16"; os << "i16";
} else if (type.isa<Int32Type>()) { } else if (type.isa<Int32Type>()) {
os << "i32"; os << "i32";
} else if (type.isa<Int64Type>()) { } else if (type.isa<Int64Type>()) {
os << "i64"; os << "i64";
} else if (type.isa<Complex64Type>()) {
os << "c64";
} else if (type.isa<Complex128Type>()) {
os << "c128";
} else if (type.isa<VectorType>()) { } else if (type.isa<VectorType>()) {
os << "vec["; os << "vec[";
auto inner_types = type.dyn_cast<VectorType>().data(); auto inner_types = type.dyn_cast<VectorType>().data();
......
...@@ -84,3 +84,13 @@ cc_test_old( ...@@ -84,3 +84,13 @@ cc_test_old(
pd_dialect pd_dialect
pd_interface pd_interface
ir) ir)
cc_test_old(
ir_type_converter_test
SRCS
ir_type_converter_test.cc
DEPS
gtest
program_translator
pd_dialect
ir)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <sstream>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/fluid/ir_adaptor/translator/type_translator.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/type.h"
template <typename IR_TYPE>
void test_parameterless_type() {
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
ir::Type type = IR_TYPE::get(ctx);
std::stringstream ss;
ss << type;
EXPECT_GT(ss.str().size(), 0u);
EXPECT_NE(ss.str(), "<<NULL TYPE>>");
phi::DataType phi_type = paddle::dialect::TransToPhiDataType(type);
EXPECT_EQ(type, paddle::dialect::TransToIrDataType(phi_type));
auto& type_translator = paddle::translator::TypeTranslator::instance();
paddle::framework::VarDesc empty_var_desc("empty");
auto proto_type = paddle::framework::TransToProtoVarType(phi_type);
ir::Type final_type = type_translator[proto_type](ctx, empty_var_desc);
EXPECT_EQ(type, final_type);
}
template <typename... IR_TYPE>
void test_parameterless_type_helper() {
(void)std::initializer_list<int>{0,
(test_parameterless_type<IR_TYPE>(), 0)...};
}
TEST(TypeConverterTest, paramterless_type) {
test_parameterless_type_helper<ir::UInt8Type,
ir::Int8Type,
ir::BFloat16Type,
ir::Float16Type,
ir::Float32Type,
ir::Float64Type,
ir::Int16Type,
ir::Int32Type,
ir::Int64Type,
ir::BoolType,
ir::Complex64Type,
ir::Complex128Type>();
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册