From cb59c27835b1cfd3b776e2642c33513e519a9bcd Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 27 Nov 2020 18:06:08 +0800 Subject: [PATCH] feat(mlir/ir): add more op definitions GitOrigin-RevId: 1e1285ef41142955ba7c0c0d5522e59f623f83ed --- src/jit/impl/mlir/ir/each_mode.cpp | 23 ++++++++------ src/jit/impl/mlir/ir/types.cpp | 31 +++++++------------ src/jit/impl/mlir/ir/types.h | 5 +-- src/jit/impl/mlir/ir/utils.cpp | 2 +- src/jit/impl/mlir/mlir_gen.cpp | 9 ++++-- .../include/megbrain/jit/mlir/ir/dialect.h | 3 ++ .../megbrain/jit/mlir/ir/mgb_dialect.td | 2 ++ 7 files changed, 38 insertions(+), 37 deletions(-) diff --git a/src/jit/impl/mlir/ir/each_mode.cpp b/src/jit/impl/mlir/ir/each_mode.cpp index f48d15461..f2d8c12c6 100644 --- a/src/jit/impl/mlir/ir/each_mode.cpp +++ b/src/jit/impl/mlir/ir/each_mode.cpp @@ -22,6 +22,7 @@ #include "megbrain/exception.h" #include "megbrain/jit/mlir/ir/dialect.h" +#include #include namespace mgb { @@ -442,31 +443,35 @@ mlir::Value lower_elemwise_to_std(mlir::Operation* op, mlir::OpBuilder& builder, mlir::Value lower_typecvt_to_std(mlir::Operation* op, mlir::OpBuilder& builder, mlir::Location loc, mlir::Value input) { auto&& typecvt = llvm::dyn_cast(op); - megdnn::DType idtype = typecvt.idtype(); - megdnn::DType odtype = typecvt.odtype(); + mlir::Type idtype = typecvt.idtype(); + mlir::Type odtype = + megdnn_dtype_to_mlir_type(typecvt.dtype(), builder.getContext()); mlir::Type itype = input.getType(); - mlir::Type otype = megdnn_dtype_to_mlir_type(odtype, builder.getContext()); + mlir::Type otype = signless(odtype); + mgb_assert(signless(idtype) == itype); if (mlir::FPExtOp::areCastCompatible(itype, otype)) { return builder.create(loc, otype, input); } else if (mlir::FPTruncOp::areCastCompatible(itype, otype)) { return builder.create(loc, otype, input); } else if (mlir::FPToSIOp::areCastCompatible(itype, otype) and - is_signed_int_dtype(odtype)) { + odtype.isSignedInteger()) { return builder.create(loc, otype, input); } else if (mlir::FPToUIOp::areCastCompatible(itype, otype) and - is_unsigned_int_dtype(odtype)) { + odtype.isUnsignedInteger()) { return builder.create(loc, otype, input); } else if (mlir::SIToFPOp::areCastCompatible(itype, otype) and - is_signed_int_dtype(idtype)) { + idtype.isSignedInteger()) { return builder.create(loc, otype, input); } else if (mlir::UIToFPOp::areCastCompatible(itype, otype) and - is_unsigned_int_dtype(idtype)) { + idtype.isUnsignedInteger()) { return builder.create(loc, otype, input); } else { - mgb_throw(InternalError, "cannot convert from %s to %s", idtype.name(), - odtype.name()); + std::string tmp; + llvm::raw_string_ostream os(tmp); + os << "cannot convert from " << idtype << " to " << odtype; + mgb_throw_raw(InternalError{tmp}); } return nullptr; diff --git a/src/jit/impl/mlir/ir/types.cpp b/src/jit/impl/mlir/ir/types.cpp index c353e6d85..a04755c50 100644 --- a/src/jit/impl/mlir/ir/types.cpp +++ b/src/jit/impl/mlir/ir/types.cpp @@ -28,13 +28,13 @@ mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type, case megdnn::DTypeEnum::Float32: return mlir::FloatType::getF32(ctx); case megdnn::DTypeEnum::Uint8: - return mlir::IntegerType::get(8, ctx); + return mlir::IntegerType::get(8, mlir::IntegerType::Unsigned, ctx); case megdnn::DTypeEnum::Int8: - return mlir::IntegerType::get(8, ctx); + return mlir::IntegerType::get(8, mlir::IntegerType::Signed, ctx); case megdnn::DTypeEnum::Int16: - return mlir::IntegerType::get(16, ctx); + return mlir::IntegerType::get(16, mlir::IntegerType::Signed, ctx); case megdnn::DTypeEnum::Int32: - return mlir::IntegerType::get(32, ctx); + return mlir::IntegerType::get(32, mlir::IntegerType::Signed, ctx); case megdnn::DTypeEnum::IntB1: return mlir::IntegerType::get(1, ctx); case megdnn::DTypeEnum::IntB2: @@ -57,6 +57,13 @@ mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type, } } +mlir::Type signless(mlir::Type type) { + if (auto intty = type.dyn_cast()) { + return mlir::IntegerType::get(intty.getWidth(), type.getContext()); + } + return type; +} + megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type) { mlir::Type element_type = type; if (auto cast = type.dyn_cast_or_null()) { @@ -91,22 +98,6 @@ megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type) { return megdnn::DType::from_enum(enumv); } -bool is_signed_int_dtype(megdnn::DType type) { - auto enumv = type.enumv(); - return enumv == megdnn::DTypeEnum::Int8 or - enumv == megdnn::DTypeEnum::Int16 or - enumv == megdnn::DTypeEnum::Int32 or - enumv == megdnn::DTypeEnum::IntB1 or - enumv == megdnn::DTypeEnum::IntB2 or - enumv == megdnn::DTypeEnum::IntB4; -} - -bool is_unsigned_int_dtype(megdnn::DType type) { - auto enumv = type.enumv(); - return enumv == megdnn::DTypeEnum::Uint8 or - enumv == megdnn::DTypeEnum::UintB4; -} - } // namespace jit } // namespace mgb diff --git a/src/jit/impl/mlir/ir/types.h b/src/jit/impl/mlir/ir/types.h index 2dff735e7..dacf2dea0 100644 --- a/src/jit/impl/mlir/ir/types.h +++ b/src/jit/impl/mlir/ir/types.h @@ -35,13 +35,10 @@ namespace jit { mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type, mlir::MLIRContext* ctx); +mlir::Type signless(mlir::Type type); megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type); -bool is_signed_int_dtype(megdnn::DType type); - -bool is_unsigned_int_dtype(megdnn::DType type); - } // namespace jit } // namespace mgb diff --git a/src/jit/impl/mlir/ir/utils.cpp b/src/jit/impl/mlir/ir/utils.cpp index 24e8d070c..3b40fc902 100644 --- a/src/jit/impl/mlir/ir/utils.cpp +++ b/src/jit/impl/mlir/ir/utils.cpp @@ -87,7 +87,7 @@ mlir::MemRefType jit::layout_to_mlir_type(const megdnn::TensorLayout& layout, shape.push_back(layout[i]); } mlir::Type type = megdnn_dtype_to_mlir_type(layout.dtype, builder.getContext()); - return mlir::MemRefType::get(shape, type); + return mlir::MemRefType::get(shape, signless(type)); } #endif // MGB_JIT && MGB_JIT_MLIR diff --git a/src/jit/impl/mlir/mlir_gen.cpp b/src/jit/impl/mlir/mlir_gen.cpp index 58205a65e..234d289d7 100644 --- a/src/jit/impl/mlir/mlir_gen.cpp +++ b/src/jit/impl/mlir/mlir_gen.cpp @@ -197,12 +197,15 @@ private: .getType() .dyn_cast_or_null(); mgb_assert(itype, "currently only support MemRefType"); + auto output_type = megdnn_dtype_to_mlir_type(opr.param(), + m_builder.getContext()); auto res_type = mlir::MemRefType::get( - itype.getShape(), - megdnn_dtype_to_mlir_type(opr.param(), m_builder.getContext())); + itype.getShape(), signless(output_type)); + auto inp_type = megdnn_dtype_to_mlir_type(opr.input(0)->dtype(), + m_builder.getContext()); return m_builder.create( m_builder.getUnknownLoc(), res_type, get(opr.input(0)), - opr.input(0)->dtype(), opr.param()); + mlir::TypeAttr::get(inp_type), opr.param()); } mlir::Value gen_dimshuffle(const opr::Dimshuffle& opr) { diff --git a/src/jit/include/megbrain/jit/mlir/ir/dialect.h b/src/jit/include/megbrain/jit/mlir/ir/dialect.h index 4ce39f895..12c5ca73b 100644 --- a/src/jit/include/megbrain/jit/mlir/ir/dialect.h +++ b/src/jit/include/megbrain/jit/mlir/ir/dialect.h @@ -15,7 +15,10 @@ #include "megbrain_build_config.h" #if MGB_JIT && MGB_JIT_MLIR +#include "megdnn/basic_types.h" #include "megdnn/opr_param_defs.h" +#include "megbrain/opr/param_defs.h" +#include "megbrain/comp_node.h" #include #include diff --git a/src/jit/include/megbrain/jit/mlir/ir/mgb_dialect.td b/src/jit/include/megbrain/jit/mlir/ir/mgb_dialect.td index 7ae37ffbe..01706b706 100644 --- a/src/jit/include/megbrain/jit/mlir/ir/mgb_dialect.td +++ b/src/jit/include/megbrain/jit/mlir/ir/mgb_dialect.td @@ -15,6 +15,8 @@ include "ops.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + class GenericOp traits = []> : Op; -- GitLab