diff --git a/src/jit/impl/mlir/ir/each_mode.cpp b/src/jit/impl/mlir/ir/each_mode.cpp index f48d1546192deddc0d87c8fd3a3c3e418c11ec9e..f2d8c12c6e208d1b76c0f5c89e31a08caa268bfa 100644 --- a/src/jit/impl/mlir/ir/each_mode.cpp +++ b/src/jit/impl/mlir/ir/each_mode.cpp @@ -22,6 +22,7 @@ #include "megbrain/exception.h" #include "megbrain/jit/mlir/ir/dialect.h" +#include #include namespace mgb { @@ -442,31 +443,35 @@ mlir::Value lower_elemwise_to_std(mlir::Operation* op, mlir::OpBuilder& builder, mlir::Value lower_typecvt_to_std(mlir::Operation* op, mlir::OpBuilder& builder, mlir::Location loc, mlir::Value input) { auto&& typecvt = llvm::dyn_cast(op); - megdnn::DType idtype = typecvt.idtype(); - megdnn::DType odtype = typecvt.odtype(); + mlir::Type idtype = typecvt.idtype(); + mlir::Type odtype = + megdnn_dtype_to_mlir_type(typecvt.dtype(), builder.getContext()); mlir::Type itype = input.getType(); - mlir::Type otype = megdnn_dtype_to_mlir_type(odtype, builder.getContext()); + mlir::Type otype = signless(odtype); + mgb_assert(signless(idtype) == itype); if (mlir::FPExtOp::areCastCompatible(itype, otype)) { return builder.create(loc, otype, input); } else if (mlir::FPTruncOp::areCastCompatible(itype, otype)) { return builder.create(loc, otype, input); } else if (mlir::FPToSIOp::areCastCompatible(itype, otype) and - is_signed_int_dtype(odtype)) { + odtype.isSignedInteger()) { return builder.create(loc, otype, input); } else if (mlir::FPToUIOp::areCastCompatible(itype, otype) and - is_unsigned_int_dtype(odtype)) { + odtype.isUnsignedInteger()) { return builder.create(loc, otype, input); } else if (mlir::SIToFPOp::areCastCompatible(itype, otype) and - is_signed_int_dtype(idtype)) { + idtype.isSignedInteger()) { return builder.create(loc, otype, input); } else if (mlir::UIToFPOp::areCastCompatible(itype, otype) and - is_unsigned_int_dtype(idtype)) { + idtype.isUnsignedInteger()) { return builder.create(loc, otype, input); } else { - mgb_throw(InternalError, "cannot convert from %s to %s", idtype.name(), - odtype.name()); + std::string tmp; + llvm::raw_string_ostream os(tmp); + os << "cannot convert from " << idtype << " to " << odtype; + mgb_throw_raw(InternalError{tmp}); } return nullptr; diff --git a/src/jit/impl/mlir/ir/types.cpp b/src/jit/impl/mlir/ir/types.cpp index c353e6d85f8f01d2fe76d3b8bbfe15971dd0a2a2..a04755c50ed8c6c0cb09f6b23f07d124cc105303 100644 --- a/src/jit/impl/mlir/ir/types.cpp +++ b/src/jit/impl/mlir/ir/types.cpp @@ -28,13 +28,13 @@ mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type, case megdnn::DTypeEnum::Float32: return mlir::FloatType::getF32(ctx); case megdnn::DTypeEnum::Uint8: - return mlir::IntegerType::get(8, ctx); + return mlir::IntegerType::get(8, mlir::IntegerType::Unsigned, ctx); case megdnn::DTypeEnum::Int8: - return mlir::IntegerType::get(8, ctx); + return mlir::IntegerType::get(8, mlir::IntegerType::Signed, ctx); case megdnn::DTypeEnum::Int16: - return mlir::IntegerType::get(16, ctx); + return mlir::IntegerType::get(16, mlir::IntegerType::Signed, ctx); case megdnn::DTypeEnum::Int32: - return mlir::IntegerType::get(32, ctx); + return mlir::IntegerType::get(32, mlir::IntegerType::Signed, ctx); case megdnn::DTypeEnum::IntB1: return mlir::IntegerType::get(1, ctx); case megdnn::DTypeEnum::IntB2: @@ -57,6 +57,13 @@ mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type, } } +mlir::Type signless(mlir::Type type) { + if (auto intty = type.dyn_cast()) { + return mlir::IntegerType::get(intty.getWidth(), type.getContext()); + } + return type; +} + megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type) { mlir::Type element_type = type; if (auto cast = type.dyn_cast_or_null()) { @@ -91,22 +98,6 @@ megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type) { return megdnn::DType::from_enum(enumv); } -bool is_signed_int_dtype(megdnn::DType type) { - auto enumv = type.enumv(); - return enumv == megdnn::DTypeEnum::Int8 or - enumv == megdnn::DTypeEnum::Int16 or - enumv == megdnn::DTypeEnum::Int32 or - enumv == megdnn::DTypeEnum::IntB1 or - enumv == megdnn::DTypeEnum::IntB2 or - enumv == megdnn::DTypeEnum::IntB4; -} - -bool is_unsigned_int_dtype(megdnn::DType type) { - auto enumv = type.enumv(); - return enumv == megdnn::DTypeEnum::Uint8 or - enumv == megdnn::DTypeEnum::UintB4; -} - } // namespace jit } // namespace mgb diff --git a/src/jit/impl/mlir/ir/types.h b/src/jit/impl/mlir/ir/types.h index 2dff735e7a5a632f01e0c16f501dd495c5506d55..dacf2dea0d7ccae55e65d3844654667790893f30 100644 --- a/src/jit/impl/mlir/ir/types.h +++ b/src/jit/impl/mlir/ir/types.h @@ -35,13 +35,10 @@ namespace jit { mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type, mlir::MLIRContext* ctx); +mlir::Type signless(mlir::Type type); megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type); -bool is_signed_int_dtype(megdnn::DType type); - -bool is_unsigned_int_dtype(megdnn::DType type); - } // namespace jit } // namespace mgb diff --git a/src/jit/impl/mlir/ir/utils.cpp b/src/jit/impl/mlir/ir/utils.cpp index 24e8d070cedbd517f9ab385ef80c916c53abf2b1..3b40fc90216e8c780f33c95325cd07392731a937 100644 --- a/src/jit/impl/mlir/ir/utils.cpp +++ b/src/jit/impl/mlir/ir/utils.cpp @@ -87,7 +87,7 @@ mlir::MemRefType jit::layout_to_mlir_type(const megdnn::TensorLayout& layout, shape.push_back(layout[i]); } mlir::Type type = megdnn_dtype_to_mlir_type(layout.dtype, builder.getContext()); - return mlir::MemRefType::get(shape, type); + return mlir::MemRefType::get(shape, signless(type)); } #endif // MGB_JIT && MGB_JIT_MLIR diff --git a/src/jit/impl/mlir/mlir_gen.cpp b/src/jit/impl/mlir/mlir_gen.cpp index 58205a65e8a1102fc3f417877b69c15c5e46ba86..234d289d7ad3165e234fc60495b3c5ae796554af 100644 --- a/src/jit/impl/mlir/mlir_gen.cpp +++ b/src/jit/impl/mlir/mlir_gen.cpp @@ -197,12 +197,15 @@ private: .getType() .dyn_cast_or_null(); mgb_assert(itype, "currently only support MemRefType"); + auto output_type = megdnn_dtype_to_mlir_type(opr.param(), + m_builder.getContext()); auto res_type = mlir::MemRefType::get( - itype.getShape(), - megdnn_dtype_to_mlir_type(opr.param(), m_builder.getContext())); + itype.getShape(), signless(output_type)); + auto inp_type = megdnn_dtype_to_mlir_type(opr.input(0)->dtype(), + m_builder.getContext()); return m_builder.create( m_builder.getUnknownLoc(), res_type, get(opr.input(0)), - opr.input(0)->dtype(), opr.param()); + mlir::TypeAttr::get(inp_type), opr.param()); } mlir::Value gen_dimshuffle(const opr::Dimshuffle& opr) { diff --git a/src/jit/include/megbrain/jit/mlir/ir/dialect.h b/src/jit/include/megbrain/jit/mlir/ir/dialect.h index 4ce39f895b94722f88ae5aeded2f63b246168d95..12c5ca73bd266f84c8f9eb49f19b9c67b1253e42 100644 --- a/src/jit/include/megbrain/jit/mlir/ir/dialect.h +++ b/src/jit/include/megbrain/jit/mlir/ir/dialect.h @@ -15,7 +15,10 @@ #include "megbrain_build_config.h" #if MGB_JIT && MGB_JIT_MLIR +#include "megdnn/basic_types.h" #include "megdnn/opr_param_defs.h" +#include "megbrain/opr/param_defs.h" +#include "megbrain/comp_node.h" #include #include diff --git a/src/jit/include/megbrain/jit/mlir/ir/mgb_dialect.td b/src/jit/include/megbrain/jit/mlir/ir/mgb_dialect.td index 7ae37ffbea9c107669b86d0ed98c169d534b9ddc..01706b7066ab6452d48c04ee1418edd1fedbcb2a 100644 --- a/src/jit/include/megbrain/jit/mlir/ir/mgb_dialect.td +++ b/src/jit/include/megbrain/jit/mlir/ir/mgb_dialect.td @@ -15,6 +15,8 @@ include "ops.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + class GenericOp traits = []> : Op;