diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1072cd9766080409a44ec50ea57d784f982ae111..155a10f90aa1eef356b9007e480c05e31c1bfd91 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -753,12 +753,14 @@ install(TARGETS mgb_opr_param_defs EXPORT ${MGE_EXPORT_TARGETS})
 if(MGE_WITH_JIT_MLIR)
     # generate param_defs.td
     set(MGE_GENFILE_DIR ${PROJECT_BINARY_DIR}/src/genfiles)
+    set(MGE_GEN_IR_DIR ${PROJECT_BINARY_DIR}/src/core/include/megbrain/ir)
     set(OPR_PARAM_DEFS_SRCS ${MGE_GENFILE_DIR}/opr_param_defs.py)
     set(OPR_PARAM_DEFS_SCRIPT ${PROJECT_SOURCE_DIR}/dnn/scripts/gen_tablegen.py)
-    set(OPR_PARAM_DEFS_OUT ${MGE_GENFILE_DIR}/param_defs.td)
+    set(OPR_PARAM_DEFS_OUT ${MGE_GEN_IR_DIR}/param_defs.td)
     file(COPY ${PROJECT_SOURCE_DIR}/dnn/scripts/opr_param_defs.py DESTINATION ${MGE_GENFILE_DIR})
     file(READ ${PROJECT_SOURCE_DIR}/tools/param_defs/mgb_opr_param_defs.py CONTENTS)
     file(APPEND ${OPR_PARAM_DEFS_SRCS} ${CONTENTS})
+    file(MAKE_DIRECTORY ${MGE_GEN_IR_DIR})
     add_custom_target(param_defs_tblgen
         COMMAND ${PYTHON_EXECUTABLE} ${OPR_PARAM_DEFS_SCRIPT} ${OPR_PARAM_DEFS_SRCS} ${OPR_PARAM_DEFS_OUT}
         DEPENDS ${OPR_PARAM_DEFS_SRCS} ${OPR_PARAM_DEFS_SCRIPT}
@@ -766,7 +768,7 @@ if(MGE_WITH_JIT_MLIR)
     )
     # mlir tblgen sources
     set(MGE_IR_DIR ${PROJECT_SOURCE_DIR}/src/core/include/megbrain/ir)
-    set(MGE_IR_INCLUDE_DIRS ${MLIR_LLVM_INCLUDE_DIR} ${MGE_GENFILE_DIR} ${MGE_IR_DIR})
+    set(MGE_IR_INCLUDE_DIRS ${MLIR_LLVM_INCLUDE_DIR} ${MGE_IR_DIR} ${MGE_GEN_IR_DIR})
     list(TRANSFORM MGE_IR_INCLUDE_DIRS PREPEND "-I")
     file(GLOB_RECURSE MGE_IR_TDS ${MGE_IR_DIR}/*.td)
 endif()
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index bc27e8fa01b05f542077b722ac8f786533b7a57f..a6000392f8c24029e268c28973385825651b4a37 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -1,5 +1,5 @@
 if(MGE_WITH_JIT_MLIR)
-    add_subdirectory(jit/impl/mlir/ir)
+    add_subdirectory(jit/include/megbrain/jit/mlir/ir)
 endif()
 
 file(GLOB_RECURSE SOURCES core/impl/*.cpp gopt/impl/*.cpp opr/impl/*.cpp opr/impl/nvof/*.cpp plugin/impl/*.cpp serialization/impl/*.cpp core/impl/*.inl gopt/impl/*.inl opr/impl/*.inl plugin/impl/*.inl serialization/impl/*.inl)
@@ -100,9 +100,10 @@ if(MGE_WITH_JIT AND MGE_WITH_HALIDE)
     target_link_libraries(megbrain PRIVATE ${HALIDE_LLVM_LIBS})
 endif()
 if(MGE_WITH_JIT_MLIR)
-    target_link_libraries(megbrain PRIVATE mlir_op_def)
-    target_link_libraries(megbrain PRIVATE mlir_shape_inference)
+    target_include_directories(megbrain PRIVATE ${MLIR_LLVM_INCLUDE_DIR})
     target_link_libraries(megbrain PRIVATE ${MLIR_LLVM_LIBS})
+    add_dependencies(megbrain mgb_dialect)
+    target_include_directories(megbrain PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/include)
 endif()
 if (MGB_WITH_FLATBUFFERS)
     set (GEN_FLATBUFFERS_SCHEMA_PY ${PROJECT_SOURCE_DIR}/dnn/scripts/gen_flatbuffers_schema.py)
diff --git a/src/jit/impl/mlir/compiler.cpp b/src/jit/impl/mlir/compiler.cpp
index 6f4588d15438450a67504b707cd81e6581c2a355..b6b4809c0620f5679490191eac49ced7ce005be0 100644
--- a/src/jit/impl/mlir/compiler.cpp
+++ b/src/jit/impl/mlir/compiler.cpp
@@ -17,6 +17,7 @@
 #include "./executable_cpu.h"
 #include "./executable_cuda.h"
 #include "./mlir_gen.h"
+
 #include "megbrain/common.h"
 #include "megbrain/comp_node_env.h"
 #include "megbrain/jit/mlir/ir/dialect.h"
diff --git a/src/jit/impl/mlir/executable_cpu.cpp b/src/jit/impl/mlir/executable_cpu.cpp
index 35c8368e4de55bb099fed3aca6f444384fd30ead..98c025b4b2460fceed51a26c0fc964281c24c92c 100644
--- a/src/jit/impl/mlir/executable_cpu.cpp
+++ b/src/jit/impl/mlir/executable_cpu.cpp
@@ -14,37 +14,44 @@
 #if MGB_JIT && MGB_JIT_MLIR
 
 #include "./executable_cpu.h"
+#include "./ir/types.h"
+
 #include "megbrain/jit/mlir/ir/utils.h"
 
-#include 
 #include 
+#include 
 
 using namespace mgb;
 using namespace jit;
 
 namespace {
+template <typename T, int N>
+StridedMemRefType<T, N>* get_strided_memref_type(
+        const megdnn::TensorND& tensor) {
+    using DescType = StridedMemRefType<T, N>;
+    DescType* desc = static_cast<DescType*>(malloc(sizeof(DescType)));
+    desc->basePtr = tensor.ptr<T>();
+    desc->data = tensor.ptr<T>();
+    desc->offset = 0;
+    for (size_t i = 0; i < tensor.layout.ndim; i++) {
+        desc->sizes[i] = tensor.layout.shape[i];
+        desc->strides[i] = tensor.layout.stride[i];
+    }
+    return desc;
+}
+
 template <int N>
 void* tensor2memref_dim(const megdnn::TensorND& tensor) {
     switch (tensor.layout.dtype.enumv()) {
-        case megdnn::DTypeEnum::Float32: {
-            StridedMemRefType<float, N>* desc =
-                    static_cast<StridedMemRefType<float, N>*>(
-                            malloc(sizeof(StridedMemRefType<float, N>)));
-            desc->basePtr = tensor.ptr<float>();
-            desc->data = tensor.ptr<float>();
-            desc->offset = 0;
-            for (size_t i = 0; i < tensor.layout.ndim; i++) {
-                desc->sizes[i] = tensor.layout.shape[i];
-                desc->strides[i] = tensor.layout.stride[i];
-            }
-            return desc;
-            break;
-        }
+#define cb(_dtype, _type)                                 \
+    case megdnn::DTypeEnum::_dtype:                       \
+        return get_strided_memref_type<_type, N>(tensor);
+        FOR_EACH_DNN_DTYPE(cb)
+#undef cb
         default:
-            mgb_throw(InternalError, "Unsupport dtype, got %s",
+            mgb_throw(InternalError, "Unsupported dtype: %s",
                       tensor.layout.dtype.name());
-            break;
     }
     return nullptr;
 }
mgb_throw(InternalError, "unsupport dtype: %s", dtype.name()); - } + switch (out_dim) { +#define cb(_ndim) \ + case _ndim: \ + setup_and_launch_dim<_ndim>(dtype, fusion_opr, func->func, \ + func->block_size); \ + break; + cb(1); + cb(2); + cb(3); + cb(4); #undef cb -#undef cb_outdim + } } #endif // MGB_CUDA diff --git a/src/jit/impl/mlir/ir/CMakeLists.txt b/src/jit/impl/mlir/ir/CMakeLists.txt deleted file mode 100644 index bd340af6e74231dcc9d174f40801c44abd9fd914..0000000000000000000000000000000000000000 --- a/src/jit/impl/mlir/ir/CMakeLists.txt +++ /dev/null @@ -1,39 +0,0 @@ -set(MGB_MLIR_TABLEGEN_INC_BASE ${CMAKE_CURRENT_BINARY_DIR}/include/) -file(MAKE_DIRECTORY ${MGB_MLIR_TABLEGEN_INC_BASE}/megbrain/jit/mlir/ir/) -list(APPEND MGB_MLIR_TABLEGEN_INC ${MGB_MLIR_TABLEGEN_INC_BASE}) - -external_tablegen_library( - NAME - mlir_shape_inference - TBLGEN - MLIR - SRCS - "interfaces.td" - INCLUDES - ${MGB_MLIR_TABLEGEN_INC} ${MLIR_LLVM_INCLUDE_DIR} - OUTS - -gen-op-interface-decls include/megbrain/jit/mlir/ir/interfaces.h.inc - -gen-op-interface-defs include/megbrain/jit/mlir/ir/interfaces.cpp.inc -) - -external_tablegen_library( - NAME - mlir_op_def - TBLGEN - MLIR - SRCS - "ops.td" - INCLUDES - ${MGB_MLIR_TABLEGEN_INC} ${MLIR_LLVM_INCLUDE_DIR} - OUTS - -gen-op-decls include/megbrain/jit/mlir/ir/ops.h.inc - -gen-op-defs include/megbrain/jit/mlir/ir/ops.cpp.inc -) - -# mgb_dialect -set(MGB_DIALECT_TD ${PROJECT_SOURCE_DIR}/src/jit/include/megbrain/jit/mlir/ir/mgb_dialect.td) -set(LLVM_TARGET_DEFINITIONS ${MGB_DIALECT_TD}) -tablegen(MLIR mgb_dialect.h.inc ${MGE_IR_INCLUDE_DIRS} "--gen-op-decls") -tablegen(MLIR mgb_dialect.cpp.inc ${MGE_IR_INCLUDE_DIRS} "--gen-op-defs") -add_custom_target(mgb_dialect DEPENDS mgb_dialect.h.inc mgb_dialect.cpp.inc ${MGB_DIALECT_TD} ${MGE_IR_TDS}) -add_dependencies(mgb_dialect param_defs_tblgen) diff --git a/src/jit/impl/mlir/ir/common.cpp b/src/jit/impl/mlir/ir/common.cpp index 085bbff40196527eec59bb328b143b7f66a6c717..64061d9b1fe86d186c6bec26a57dddffbda539c3 100644 --- a/src/jit/impl/mlir/ir/common.cpp +++ b/src/jit/impl/mlir/ir/common.cpp @@ -14,91 +14,99 @@ #if MGB_JIT && MGB_JIT_MLIR #include "./common.h" + #include "megbrain/jit/mlir/ir/utils.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" #include +#include using namespace mgb; using namespace jit; +/* ===================== trivial unary functions ===================== */ + +#define cb(name, op) \ + mlir::Value ValueBuilderHelper::name(mlir::Value lhs) { \ + return m_builder.create(m_location, lhs); \ + } + +cb(abs, AbsFOp); +cb(ceil, CeilFOp); +cb(cos, CosOp); +cb(exp2, Exp2Op); +cb(exp, ExpOp); +cb(floor, FloorFOp); +cb(log10, Log10Op); +cb(log2, Log2Op); +cb(log, LogOp); +cb(neg, NegFOp); +cb(rsqrt, RsqrtOp); +cb(sin, SinOp); +cb(sqrt, SqrtOp); +cb(tanh, TanhOp); + +#undef cb + +/* ===================== trivial binary functions ===================== */ + #define cb(name, op) \ mlir::Value ValueBuilderHelper::name(mlir::Value lhs, mlir::Value rhs) { \ return m_builder.create(m_location, lhs, rhs); \ } + cb(add, AddFOp); -cb(sub, SubFOp); -cb(mul, MulFOp); -cb(div, DivFOp); -cb(divI, SignedDivIOp); -cb(mod, RemFOp); cb(bit_and, AndOp); cb(bit_or, OrOp); +cb(div, DivFOp); +cb(divI, SignedDivIOp); cb(modI, SignedRemIOp); +cb(mod, RemFOp); +cb(mul, MulFOp); +cb(sub, SubFOp); + #undef cb +/* ===================== compare functions ===================== */ + #define cb(name, mode) \ mlir::Value ValueBuilderHelper::name(mlir::Value lhs, mlir::Value rhs) { \ return m_builder.create( \ m_location, 
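The launcher is now a two-level dispatch: an outer `switch` turns the runtime `out_dim` into a template argument, and an inner `switch` (expanded from `FOR_EACH_DNN_DTYPE`) turns the runtime dtype into an element type. A minimal standalone sketch of that shape, with made-up names:

```cpp
#include <cstdio>
#include <stdexcept>

enum class DTypeEnum { Float32 };

// Innermost worker: rank and element type are both compile-time parameters.
template <int N, typename T>
void launch() {
    std::printf("launch rank=%d elem_size=%zu\n", N, sizeof(T));
}

// Inner switch: runtime dtype -> template type parameter.
template <int N>
void launch_dim(DTypeEnum dtype) {
    switch (dtype) {
        case DTypeEnum::Float32:
            launch<N, float>();
            return;
        default:
            throw std::runtime_error("unsupported dtype");
    }
}

// Outer switch: runtime rank -> template int parameter.
void launch_any(int ndim, DTypeEnum dtype) {
    switch (ndim) {
        case 1: launch_dim<1>(dtype); break;
        case 2: launch_dim<2>(dtype); break;
        case 3: launch_dim<3>(dtype); break;
        case 4: launch_dim<4>(dtype); break;
        default: throw std::runtime_error("unsupported ndim");
    }
}

int main() {
    launch_any(2, DTypeEnum::Float32);
    return 0;
}
```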
diff --git a/src/jit/impl/mlir/ir/CMakeLists.txt b/src/jit/impl/mlir/ir/CMakeLists.txt
deleted file mode 100644
index bd340af6e74231dcc9d174f40801c44abd9fd914..0000000000000000000000000000000000000000
--- a/src/jit/impl/mlir/ir/CMakeLists.txt
+++ /dev/null
@@ -1,39 +0,0 @@
-set(MGB_MLIR_TABLEGEN_INC_BASE ${CMAKE_CURRENT_BINARY_DIR}/include/)
-file(MAKE_DIRECTORY ${MGB_MLIR_TABLEGEN_INC_BASE}/megbrain/jit/mlir/ir/)
-list(APPEND MGB_MLIR_TABLEGEN_INC ${MGB_MLIR_TABLEGEN_INC_BASE})
-
-external_tablegen_library(
-  NAME
-    mlir_shape_inference
-  TBLGEN
-    MLIR
-  SRCS
-    "interfaces.td"
-  INCLUDES
-    ${MGB_MLIR_TABLEGEN_INC} ${MLIR_LLVM_INCLUDE_DIR}
-  OUTS
-    -gen-op-interface-decls include/megbrain/jit/mlir/ir/interfaces.h.inc
-    -gen-op-interface-defs include/megbrain/jit/mlir/ir/interfaces.cpp.inc
-)
-
-external_tablegen_library(
-  NAME
-    mlir_op_def
-  TBLGEN
-    MLIR
-  SRCS
-    "ops.td"
-  INCLUDES
-    ${MGB_MLIR_TABLEGEN_INC} ${MLIR_LLVM_INCLUDE_DIR}
-  OUTS
-    -gen-op-decls include/megbrain/jit/mlir/ir/ops.h.inc
-    -gen-op-defs include/megbrain/jit/mlir/ir/ops.cpp.inc
-)
-
-# mgb_dialect
-set(MGB_DIALECT_TD ${PROJECT_SOURCE_DIR}/src/jit/include/megbrain/jit/mlir/ir/mgb_dialect.td)
-set(LLVM_TARGET_DEFINITIONS ${MGB_DIALECT_TD})
-tablegen(MLIR mgb_dialect.h.inc ${MGE_IR_INCLUDE_DIRS} "--gen-op-decls")
-tablegen(MLIR mgb_dialect.cpp.inc ${MGE_IR_INCLUDE_DIRS} "--gen-op-defs")
-add_custom_target(mgb_dialect DEPENDS mgb_dialect.h.inc mgb_dialect.cpp.inc ${MGB_DIALECT_TD} ${MGE_IR_TDS})
-add_dependencies(mgb_dialect param_defs_tblgen)
diff --git a/src/jit/impl/mlir/ir/common.cpp b/src/jit/impl/mlir/ir/common.cpp
index 085bbff40196527eec59bb328b143b7f66a6c717..64061d9b1fe86d186c6bec26a57dddffbda539c3 100644
--- a/src/jit/impl/mlir/ir/common.cpp
+++ b/src/jit/impl/mlir/ir/common.cpp
@@ -14,91 +14,99 @@
 #if MGB_JIT && MGB_JIT_MLIR
 
 #include "./common.h"
+
 #include "megbrain/jit/mlir/ir/utils.h"
 
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include 
+#include <mlir/Dialect/StandardOps/IR/Ops.h>
 
 using namespace mgb;
 using namespace jit;
 
+/* ===================== trivial unary functions ===================== */
+
+#define cb(name, op)                                            \
+    mlir::Value ValueBuilderHelper::name(mlir::Value lhs) {     \
+        return m_builder.create<mlir::op>(m_location, lhs);     \
+    }
+
+cb(abs, AbsFOp);
+cb(ceil, CeilFOp);
+cb(cos, CosOp);
+cb(exp2, Exp2Op);
+cb(exp, ExpOp);
+cb(floor, FloorFOp);
+cb(log10, Log10Op);
+cb(log2, Log2Op);
+cb(log, LogOp);
+cb(neg, NegFOp);
+cb(rsqrt, RsqrtOp);
+cb(sin, SinOp);
+cb(sqrt, SqrtOp);
+cb(tanh, TanhOp);
+
+#undef cb
+
+/* ===================== trivial binary functions ===================== */
+
 #define cb(name, op)                                                         \
     mlir::Value ValueBuilderHelper::name(mlir::Value lhs, mlir::Value rhs) { \
         return m_builder.create<mlir::op>(m_location, lhs, rhs);             \
     }
+
 cb(add, AddFOp);
-cb(sub, SubFOp);
-cb(mul, MulFOp);
-cb(div, DivFOp);
-cb(divI, SignedDivIOp);
-cb(mod, RemFOp);
 cb(bit_and, AndOp);
 cb(bit_or, OrOp);
+cb(div, DivFOp);
+cb(divI, SignedDivIOp);
 cb(modI, SignedRemIOp);
+cb(mod, RemFOp);
+cb(mul, MulFOp);
+cb(sub, SubFOp);
+
 #undef cb
 
+/* ===================== compare functions ===================== */
+
 #define cb(name, mode)                                                       \
     mlir::Value ValueBuilderHelper::name(mlir::Value lhs, mlir::Value rhs) { \
         return m_builder.create<mlir::CmpFOp>(                               \
                 m_location, mlir::CmpFPredicate::mode, lhs, rhs);            \
     }
-cb(gt, OGT);
+
+cb(eq, OEQ);
 cb(ge, OGE);
-cb(lt, OLT);
+cb(gt, OGT);
 cb(le, OLE);
-cb(eq, OEQ);
+cb(lt, OLT);
+
 #undef cb
 
-mlir::Value ValueBuilderHelper::min(mlir::Value lhs, mlir::Value rhs) {
+mlir::Value ValueBuilderHelper::max(mlir::Value lhs, mlir::Value rhs) {
     mlir::Value cmp = m_builder.create<mlir::CmpFOp>(
-            m_location, mlir::CmpFPredicate::OLT, lhs, rhs);
+            m_location, mlir::CmpFPredicate::OGT, lhs, rhs);
     return m_builder.create<mlir::SelectOp>(m_location, cmp, lhs, rhs);
 }
 
-mlir::Value ValueBuilderHelper::max(mlir::Value lhs, mlir::Value rhs) {
+mlir::Value ValueBuilderHelper::min(mlir::Value lhs, mlir::Value rhs) {
     mlir::Value cmp = m_builder.create<mlir::CmpFOp>(
-            m_location, mlir::CmpFPredicate::OGT, lhs, rhs);
+            m_location, mlir::CmpFPredicate::OLT, lhs, rhs);
     return m_builder.create<mlir::SelectOp>(m_location, cmp, lhs, rhs);
 }
 
-mlir::Value ValueBuilderHelper::const_val(float val) {
+/* ===================== constant functions ===================== */
+
+mlir::Value ValueBuilderHelper::const_f32(float val) {
     return m_builder.create<mlir::ConstantOp>(m_location,
                                               m_builder.getF32FloatAttr(val));
 }
 
-mlir::Value ValueBuilderHelper::constI(int32_t val) {
+mlir::Value ValueBuilderHelper::const_i32(int32_t val) {
     return m_builder.create<mlir::ConstantOp>(m_location,
                                               m_builder.getIndexAttr(val));
 }
 
-#define cb(name, op)                                            \
-    mlir::Value ValueBuilderHelper::name(mlir::Value lhs) {     \
-        return m_builder.create<mlir::op>(m_location, lhs);     \
-    }
-
-cb(neg, NegFOp);
-cb(ceil, CeilFOp);
-cb(cos, CosOp);
-cb(exp, ExpOp);
-cb(exp2, Exp2Op);
-cb(log10, Log10Op);
-cb(log2, Log2Op);
-cb(log, LogOp);
-cb(rsqrt, RsqrtOp);
-cb(sin, SinOp);
-cb(sqrt, SqrtOp);
-cb(tanh, TanhOp);
-#undef cb
-
-mlir::Value ValueBuilderHelper::abs(mlir::Value lhs) {
-    auto zero = const_val(0.f);
-    return select(ge(lhs, zero), lhs, sub(zero, lhs));
-}
-
-mlir::Value ValueBuilderHelper::floor(mlir::Value lhs) {
-    //! FIXME use standard floor when upgrade llvm
-    return neg(ceil(neg(lhs)));
-}
+/* ===================== select function ===================== */
 
 mlir::Value ValueBuilderHelper::select(mlir::Value cond, mlir::Value true_val,
@@ -106,6 +114,8 @@ mlir::Value ValueBuilderHelper::select(mlir::Value cond, mlir::Value true_val,
                                        false_val);
 }
 
+/* ===================== helper functions ===================== */
+
 mlir::AffineMap jit::get_affinemap(mlir::OpBuilder& builder,
                                    const mlir::Value& val,
                                    const megdnn::TensorLayout& layout) {
@@ -125,10 +135,10 @@ mlir::AffineMap jit::get_affinemap(mlir::OpBuilder& builder,
 }
 
 mlir::Value jit::get_affine_load_op(mlir::OpBuilder& builder,
-                                   const mlir::Location& loc,
-                                   const mlir::Value& val,
-                                   const mlir::ValueRange& index,
-                                   const megdnn::TensorLayout& dst) {
+                                    const mlir::Location& loc,
+                                    const mlir::Value& val,
+                                    const mlir::ValueRange& index,
+                                    const megdnn::TensorLayout& dst) {
     if (val.getType().isa<mlir::MemRefType>()) {
         auto type = val.getType().cast<mlir::MemRefType>();
         megdnn::TensorLayout src_layout = mlir_type_to_layout(type);
diff --git a/src/jit/impl/mlir/ir/common.h b/src/jit/impl/mlir/ir/common.h
index edbba3302d95d86e1278791fbe4b2bc99372d024..b175c49bbe60967ee198ca20975e1191c5ad0b8c 100644
--- a/src/jit/impl/mlir/ir/common.h
+++ b/src/jit/impl/mlir/ir/common.h
@@ -14,7 +14,9 @@
 #include "megbrain_build_config.h"
 
 #if MGB_JIT && MGB_JIT_MLIR
+
 #include "megbrain/tensor.h"
+
 #include 
 #include 
 #include 
@@ -30,50 +32,59 @@ public:
     ValueBuilderHelper(mlir::OpBuilder& b, mlir::Location location)
             : m_builder{b}, m_location{location} {};
 
-#define cb(name)                                     \
-    mlir::Value name(mlir::ValueRange operands) {    \
-        return name(operands[0], operands[1]);       \
-    }                                                \
-    mlir::Value name(mlir::Value lhs, mlir::Value rhs)
-    cb(add);
-    cb(sub);
-    cb(mul);
-    cb(div);
-    cb(divI);
-    cb(max);
-    cb(min);
-    cb(mod);
-    cb(modI);
-    cb(gt);
-    cb(ge);
-    cb(lt);
-    cb(le);
-    cb(eq);
-    cb(bit_and);
-    cb(bit_or);
-#undef cb
-    mlir::Value const_val(float val);
-    mlir::Value constI(int32_t val);
-
 #define cb(name)                                                             \
     mlir::Value name(mlir::ValueRange operands) { return name(operands[0]); } \
     mlir::Value name(mlir::Value lhs)
-    cb(neg);
+
+    // unary functions
     cb(abs);
     cb(ceil);
-    cb(floor);
     cb(cos);
     cb(exp);
     cb(exp2);
+    cb(floor);
+    cb(log);
     cb(log10);
     cb(log2);
-    cb(log);
+    cb(neg);
     cb(rsqrt);
     cb(sin);
     cb(sqrt);
     cb(tanh);
+
 #undef cb
 
+#define cb(name)                                     \
+    mlir::Value name(mlir::ValueRange operands) {    \
+        return name(operands[0], operands[1]);       \
+    }                                                \
+    mlir::Value name(mlir::Value lhs, mlir::Value rhs)
+
+    // binary functions
+    cb(add);
+    cb(bit_and);
+    cb(bit_or);
+    cb(div);
+    cb(divI);
+    cb(eq);
+    cb(ge);
+    cb(gt);
+    cb(le);
+    cb(lt);
+    cb(max);
+    cb(min);
+    cb(mod);
+    cb(modI);
+    cb(mul);
+    cb(sub);
+
+#undef cb
+
+    // constant functions
+    mlir::Value const_f32(float val);
+    mlir::Value const_i32(int32_t val);
+
+    // select function
     mlir::Value select(mlir::Value cond, mlir::Value true_val,
                        mlir::Value false_val);
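`ValueBuilderHelper` wraps an `OpBuilder` plus a `Location` so lowering code can compose std-dialect ops without repeating `create` calls; `max`, for instance, expands into a `CmpFOp` followed by a `SelectOp`. A sketch of how a client builds `1 / (exp(-x) + 1)` through the helper, mirroring the SIGMOID lowering that appears later in this patch (`builder`, `loc` and `x` are assumed to come from the surrounding conversion pattern, and the include path assumes the in-tree layout):

```cpp
#include "./common.h"  // ValueBuilderHelper, as placed by this patch

mlir::Value emit_sigmoid(mlir::OpBuilder& builder, mlir::Location loc,
                         mlir::Value x) {
    mgb::jit::ValueBuilderHelper helper(builder, loc);
    // 1.f / (exp(-x) + 1.f), composed from the helper's wrappers
    return helper.div(helper.const_f32(1.f),
                      helper.add(helper.exp(helper.neg(x)),
                                 helper.const_f32(1.f)));
}
```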
"megbrain/jit/mlir/ir/mgb_dialect.cpp.inc" >(); } #define GET_OP_CLASSES -#include "megbrain/jit/mlir/ir/ops.cpp.inc" - -#include "megbrain/jit/mlir/ir/interfaces.cpp.inc" +#include "megbrain/jit/mlir/ir/mgb_dialect.cpp.inc" #endif // MGB_JIT && MGB_JIT_MLIR diff --git a/src/jit/impl/mlir/ir/each_mode.cpp b/src/jit/impl/mlir/ir/each_mode.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f48d1546192deddc0d87c8fd3a3c3e418c11ec9e --- /dev/null +++ b/src/jit/impl/mlir/ir/each_mode.cpp @@ -0,0 +1,480 @@ +/** + * \file src/jit/impl/mlir/ir/each_mode.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include "./common.h" +#include "./each_mode.h" +#include "./numerical.h" +#include "./types.h" + +#include "megbrain/common.h" +#include "megbrain/exception.h" +#include "megbrain/jit/mlir/ir/dialect.h" + +#include + +namespace mgb { +namespace jit { + +using Mode = megdnn::param::Elemwise::Mode; + +template +mlir::Value lower_mode(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands); + +/* ===================== trivial implementations ===================== */ + +#define cb(mode, fun) \ + template <> \ + mlir::Value lower_mode(mlir::OpBuilder & builder, \ + mlir::Location loc, \ + ValueRange operands) { \ + ValueBuilderHelper helper(builder, loc); \ + return helper.fun(operands); \ + } + +//! unary +cb(ABS, abs); +cb(CEIL, ceil); +cb(COS, cos); +cb(EXP, exp); +cb(FLOOR, floor); +cb(LOG, log); +cb(NEGATE, neg); +cb(SIN, sin); +cb(TANH, tanh); + +//! binary +cb(ADD, add); +cb(MAX, max); +cb(MIN, min); +cb(MOD, mod); +cb(MUL, mul); +cb(SUB, sub); +cb(TRUE_DIV, div); + +#undef cb + +/* ===================== unary op ===================== */ + +//! ACOS: pi / 2 - arctan2(x, sqrt(1 - x * x)) +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + auto x = operands[0]; + auto one_minus_x_2 = helper.sub(helper.const_f32(1.f), helper.mul(x, x)); + auto asin = atan2_approx(helper, x, helper.sqrt(one_minus_x_2)); + auto pi_over_2 = helper.const_f32(1.57079637f); + return helper.sub(pi_over_2, asin); +} + +//! ASIN: arctan2(x, sqrt(1 - x * x)) +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + auto x = operands[0]; + auto one_minus_x_2 = helper.sub(helper.const_f32(1.f), helper.mul(x, x)); + return atan2_approx(helper, x, helper.sqrt(one_minus_x_2)); +} + +//! ERFCINV: inverse of complementary gauss error function +//! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + auto minus_sqrt2 = helper.const_f32(-1.4142135623f); + auto x = helper.mul(helper.const_f32(0.5f), operands[0]); + return helper.div(ndtri_approx(helper, x), minus_sqrt2); +} + +//! 
ERFC: complementary error function +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.sub(helper.const_f32(1.f), erf_approx(helper, operands[0])); +} + +//! ERFINV: inverse of gauss error function +//! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + auto sqrt2 = helper.const_f32(1.4142135623f); + auto x = helper.mul(helper.const_f32(0.5f), + helper.add(operands[0], helper.const_f32(1.f))); + return helper.div(ndtri_approx(helper, x), sqrt2); +} + +//! ERF: gauss error function +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return erf_approx(helper, operands[0]); +} + +//! EXPM1: exp(x) - 1 +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.sub(helper.exp(operands[0]), helper.const_f32(1.f)); +} + +//! FAST_TANH: x * (27.f + x * x) / (27.f + 9.f * x * x); +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + auto square = helper.mul(operands[0], operands[0]); + return helper.div( + helper.mul(operands[0], helper.add(helper.const_f32(27.f), square)), + helper.add(helper.const_f32(27.f), + helper.mul(helper.const_f32(9.f), square))); +} + +//! H_SWISH: x * clip(x + 3, 0, 6) / 6 +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + + auto const_3 = helper.const_f32(3.f); + auto const_0 = helper.const_f32(0.f); + auto const_6 = helper.const_f32(6.f); + auto tmp = helper.add(operands[0], const_3); + return helper.div(helper.mul(operands[0], + helper.min(helper.max(tmp, const_0), const_6)), + const_6); +} + +//! LOG1P: log(1 + p) +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.log(helper.add(operands[0], helper.const_f32(1.f))); +} + +//! RELU: max(x, 0) +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.max(operands[0], helper.const_f32(0.f)); +} + +//! ROUND +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.select( + helper.gt(operands[0], helper.const_f32(0.f)), + helper.floor(helper.add(operands[0], helper.const_f32(0.5f))), + helper.ceil(helper.sub(operands[0], helper.const_f32(0.5f)))); +} + +//! SIGMOID: 1.f / (expf(-y) + 1.f)) +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.div(helper.const_f32(1.f), + helper.add(helper.exp(helper.neg(operands[0])), + helper.const_f32(1.f))); +} + +/* ===================== binary op ===================== */ + +//! ABS_GRAD: x > 0 ? 
y : -y +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.select(helper.gt(operands[0], helper.const_f32(0.f)), + operands[1], helper.neg(operands[1])); +} + +//! ATAN2 +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return atan2_approx(helper, operands[0], operands[1]); +} + +//! EQ: x == y ? 1 : 0 +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.select(helper.eq(operands[0], operands[1]), + helper.const_f32(1.f), helper.const_f32(0.f)); +} + +//! FAST_TANH_GRAD: ((-48.f * x * x) / (3.f + x * x) + 27.f + x * x) / (3.f + x +//! * x) * y +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + auto x_pow2 = helper.mul(operands[0], operands[0]); + auto deno = helper.add(helper.const_f32(3.f), x_pow2); + return helper.mul( + helper.div( + helper.add( + helper.add(helper.div(helper.mul(helper.const_f32( + -48.f), + x_pow2), + deno), + helper.const_f32(27.f)), + x_pow2), + helper.mul(deno, helper.const_f32(9.f))), + operands[1]); +} + +//! FLOOR_DIV: floor(x/y) +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.floor(helper.div(operands[0], operands[1])); +} + +//! FUSE_ADD_H_SWISH: (x+y) * min(max(x + y + 3, 0), 6) * (1/6) +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + auto sum = helper.add(operands[0], operands[1]); + + auto const_3 = helper.const_f32(3.f); + auto const_0 = helper.const_f32(0.f); + auto const_6 = helper.const_f32(6.f); + auto tmp = helper.add(sum, const_3); + return helper.div( + helper.mul(sum, helper.min(helper.max(tmp, const_0), const_6)), + const_6); +} + +//! FUSE_ADD_RELU: (x + y) <= ctype(0) ? ctype(0) : (x + y) +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + auto sum = helper.add(operands[0], operands[1]); + return helper.max(sum, helper.const_f32(0.f)); +} + +//! FUSE_ADD_SIGMOID: 1.f / (expf(-(x+y)) + 1.f)) +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.div(helper.const_f32(1.f), + helper.add(helper.exp(helper.neg( + helper.add(operands[0], operands[1]))), + helper.const_f32(1.f))); +} + +//! FUSE_ADD_TANH: tanh(x + y) +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.tanh(helper.add(operands[0], operands[1])); +} + +//! H_SWISH_GRAD: x < -3.f ? 0.f : (x > 3.f ? 
y : (2.f * x + 3.f) / 6.f * y) +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.select( + helper.lt(operands[0], helper.const_f32(-3.f)), + helper.const_f32(0.f), + helper.select( + helper.gt(operands[0], helper.const_f32(3.f)), operands[1], + helper.mul( + helper.div( + helper.add(helper.mul(helper.const_f32(2.f), + operands[0]), + helper.const_f32(3.f)), + helper.const_f32(6.f)), + operands[1]))); +} + +//! LEQ: x <= y ? 1 : 0 +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.select(helper.le(operands[0], operands[1]), + helper.const_f32(1.f), helper.const_f32(0.f)); +} + +//! LOG_SUM_EXP: log(exp(x) + exp(y)) +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.log( + helper.add(helper.exp(operands[0]), helper.exp(operands[1]))); +} + +//! LT: x < y ? 1 : 0 +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.select(helper.lt(operands[0], operands[1]), + helper.const_f32(1.f), helper.const_f32(0.f)); +} + +//! POW: x^y = exp(y * log(x)) +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.exp(helper.mul(operands[1], helper.log(operands[0]))); +} + +//! SIGMOID_GRAD: x * (1 - x) * y +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.mul(helper.mul(operands[0], helper.sub(helper.const_f32(1.f), + operands[0])), + operands[1]); +} + +//! SWITCH_GT0: (x > 0) * y +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.select(helper.gt(operands[0], helper.const_f32(0.f)), + operands[1], helper.const_f32(0.f)); +} + +//! TANH_GRAD: (1 - x * x) * y +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.mul(helper.sub(helper.const_f32(1.0f), + helper.mul(operands[0], operands[0])), + operands[1]); +} + +/* ===================== ternary op ===================== */ + +//! COND_LEQ_MOV: x <= y ? z : ctype(0) +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.select(helper.le(operands[0], operands[1]), operands[2], + helper.const_f32(0.f)); +} + +//! 
FUSE_MUL_ADD3: x * y + z +template <> +mlir::Value lower_mode(mlir::OpBuilder& builder, + mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.add(helper.mul(operands[0], operands[1]), operands[2]); +} + +/* ===================== elemwise ===================== */ + +mlir::Value lower_elemwise_to_std(mlir::Operation* op, mlir::OpBuilder& builder, + mlir::Location loc, ValueRange operands) { + auto mode = llvm::dyn_cast(op).mode(); + switch (mode) { +#define cb(_, _mode) \ + case Mode::_mode: \ + return lower_mode(builder, loc, operands); + MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb); + MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb); + MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb); + default: + return nullptr; + } +#undef cb +} + +/* ===================== typecvt ===================== */ + +mlir::Value lower_typecvt_to_std(mlir::Operation* op, mlir::OpBuilder& builder, + mlir::Location loc, mlir::Value input) { + auto&& typecvt = llvm::dyn_cast(op); + megdnn::DType idtype = typecvt.idtype(); + megdnn::DType odtype = typecvt.odtype(); + + mlir::Type itype = input.getType(); + mlir::Type otype = megdnn_dtype_to_mlir_type(odtype, builder.getContext()); + + if (mlir::FPExtOp::areCastCompatible(itype, otype)) { + return builder.create(loc, otype, input); + } else if (mlir::FPTruncOp::areCastCompatible(itype, otype)) { + return builder.create(loc, otype, input); + } else if (mlir::FPToSIOp::areCastCompatible(itype, otype) and + is_signed_int_dtype(odtype)) { + return builder.create(loc, otype, input); + } else if (mlir::FPToUIOp::areCastCompatible(itype, otype) and + is_unsigned_int_dtype(odtype)) { + return builder.create(loc, otype, input); + } else if (mlir::SIToFPOp::areCastCompatible(itype, otype) and + is_signed_int_dtype(idtype)) { + return builder.create(loc, otype, input); + } else if (mlir::UIToFPOp::areCastCompatible(itype, otype) and + is_unsigned_int_dtype(idtype)) { + return builder.create(loc, otype, input); + } else { + mgb_throw(InternalError, "cannot convert from %s to %s", idtype.name(), + odtype.name()); + } + + return nullptr; +} + +} // namespace jit +} // namespace mgb + +#endif // MGB_JIT && MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen diff --git a/src/jit/impl/mlir/ir/each_mode.h b/src/jit/impl/mlir/ir/each_mode.h index b7cab3ef4b9e2b95a5714d1016087d7fb8e94fdc..9ddc0da916078a93bbd77be29c5de3d2785dc0e9 100644 --- a/src/jit/impl/mlir/ir/each_mode.h +++ b/src/jit/impl/mlir/ir/each_mode.h @@ -15,65 +15,60 @@ #include "megbrain_build_config.h" #if MGB_JIT && MGB_JIT_MLIR -#include "megbrain/jit/mlir/ir/dialect.h" +#include "megdnn/opr_param_defs.h" -#include "./common.h" -#include "./numerical.h" - -#include #include -#include // clang-format off #define MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb) \ - cb(ReluOp, RELU) \ cb(AbsOp, ABS) \ - cb(NegOp, NEGATE) \ cb(AcosOp, ACOS) \ cb(AsinOp, ASIN) \ cb(CeilOp, CEIL) \ cb(CosOp, COS) \ + cb(ErfCInvOp, ERFCINV) \ + cb(ErfCOp, ERFC) \ + cb(ErfInvOp, ERFINV) \ + cb(ErfOp, ERF) \ + cb(ExpM1Op, EXPM1) \ cb(ExpOp, EXP) \ + cb(FastTanhOp, FAST_TANH) \ cb(FloorOp, FLOOR) \ - cb(LogOp, LOG) \ + cb(HswishOp, H_SWISH) \ cb(Log1POp, LOG1P) \ + cb(LogOp, LOG) \ + cb(NegOp, NEGATE) \ + cb(ReluOp, RELU) \ + cb(RoundOp, ROUND) \ cb(SigmoidOp, SIGMOID) \ cb(SinOp, SIN) \ - cb(TanhOp, TANH) \ - cb(FastTanhOp, FAST_TANH) \ - cb(HswishOp, H_SWISH) \ - cb(ExpM1Op, EXPM1) \ - cb(RoundOp, ROUND) \ - cb(ErfOp, ERF) \ - cb(ErfInvOp, ERFINV) \ - cb(ErfCOp, ERFC) \ - cb(ErfCInvOp, ERFCINV) + cb(TanhOp, TANH) #define 
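`lower_elemwise_to_std` relies on the X-macro mode lists declared in `each_mode.h` (next hunk) to expand one `case` per elemwise mode. The trick in isolation, reduced to a runnable toy with a made-up three-entry mode list:

```cpp
#include <cstdio>

// The mode list is written once and expanded into both an enum and a
// dispatch switch, mirroring MLIR_MGB_FOREACH_ELEMWISE_MODE_*.
#define FOREACH_MODE(cb) cb(RELU) cb(ABS) cb(NEGATE)

enum class Mode {
#define cb(_mode) _mode,
    FOREACH_MODE(cb)
#undef cb
};

const char* mode_name(Mode mode) {
    switch (mode) {
#define cb(_mode)     \
    case Mode::_mode: \
        return #_mode;
        FOREACH_MODE(cb)
#undef cb
    }
    return "unknown";
}

int main() {
    std::printf("%s\n", mode_name(Mode::ABS));  // prints ABS
    return 0;
}
```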
diff --git a/src/jit/impl/mlir/ir/each_mode.h b/src/jit/impl/mlir/ir/each_mode.h
index b7cab3ef4b9e2b95a5714d1016087d7fb8e94fdc..9ddc0da916078a93bbd77be29c5de3d2785dc0e9 100644
--- a/src/jit/impl/mlir/ir/each_mode.h
+++ b/src/jit/impl/mlir/ir/each_mode.h
@@ -15,65 +15,60 @@
 #include "megbrain_build_config.h"
 #if MGB_JIT && MGB_JIT_MLIR
 
-#include "megbrain/jit/mlir/ir/dialect.h"
+#include "megdnn/opr_param_defs.h"
 
-#include "./common.h"
-#include "./numerical.h"
-
-#include 
 #include 
-#include 
 
 // clang-format off
 #define MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb) \
-    cb(ReluOp, RELU) \
     cb(AbsOp, ABS) \
-    cb(NegOp, NEGATE) \
     cb(AcosOp, ACOS) \
     cb(AsinOp, ASIN) \
     cb(CeilOp, CEIL) \
     cb(CosOp, COS) \
+    cb(ErfCInvOp, ERFCINV) \
+    cb(ErfCOp, ERFC) \
+    cb(ErfInvOp, ERFINV) \
+    cb(ErfOp, ERF) \
+    cb(ExpM1Op, EXPM1) \
     cb(ExpOp, EXP) \
+    cb(FastTanhOp, FAST_TANH) \
    cb(FloorOp, FLOOR) \
-    cb(LogOp, LOG) \
+    cb(HswishOp, H_SWISH) \
     cb(Log1POp, LOG1P) \
+    cb(LogOp, LOG) \
+    cb(NegOp, NEGATE) \
+    cb(ReluOp, RELU) \
+    cb(RoundOp, ROUND) \
     cb(SigmoidOp, SIGMOID) \
     cb(SinOp, SIN) \
-    cb(TanhOp, TANH) \
-    cb(FastTanhOp, FAST_TANH) \
-    cb(HswishOp, H_SWISH) \
-    cb(ExpM1Op, EXPM1) \
-    cb(RoundOp, ROUND) \
-    cb(ErfOp, ERF) \
-    cb(ErfInvOp, ERFINV) \
-    cb(ErfCOp, ERFC) \
-    cb(ErfCInvOp, ERFCINV)
+    cb(TanhOp, TANH)
 
 #define MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) \
     cb(AbsGradOp, ABS_GRAD) \
     cb(AddOp, ADD) \
+    cb(Atan2Op, ATAN2) \
+    cb(EqOp, EQ) \
+    cb(FastTanhGradOp, FAST_TANH_GRAD) \
     cb(FloorDivOp, FLOOR_DIV) \
+    cb(FuseAddHswishOp, FUSE_ADD_H_SWISH) \
+    cb(FuseAddReluOp, FUSE_ADD_RELU) \
+    cb(FuseAddSigmoidOp, FUSE_ADD_SIGMOID) \
+    cb(FuseAddTanhOp, FUSE_ADD_TANH) \
+    cb(HswishGradOp, H_SWISH_GRAD) \
+    cb(LeqOp, LEQ) \
+    cb(LogSumExpOp, LOG_SUM_EXP) \
+    cb(LtOp, LT) \
     cb(MaxOp, MAX) \
     cb(MinOp, MIN) \
     cb(ModOp, MOD) \
-    cb(SubOp, SUB) \
     cb(MulOp, MUL) \
-    cb(TrueDivOp, TRUE_DIV) \
     cb(PowOp, POW) \
     cb(SigmoidGradOp, SIGMOID_GRAD) \
+    cb(SubOp, SUB) \
     cb(SwishGt0Op, SWITCH_GT0) \
     cb(TanhGradOp, TANH_GRAD) \
-    cb(LtOp, LT) \
-    cb(LeqOp, LEQ) \
-    cb(EqOp, EQ) \
-    cb(FuseAddReluOp, FUSE_ADD_RELU) \
-    cb(LogSumExpOp, LOG_SUM_EXP) \
-    cb(FuseAddTanhOp, FUSE_ADD_TANH) \
-    cb(FastTanhGradOp, FAST_TANH_GRAD) \
-    cb(FuseAddSigmoidOp, FUSE_ADD_SIGMOID) \
-    cb(HswishGradOp, H_SWISH_GRAD) \
-    cb(FuseAddHswishOp, FUSE_ADD_H_SWISH) \
-    cb(Atan2Op, ATAN2)
+    cb(TrueDivOp, TRUE_DIV)
 
 #define MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) \
     cb(CondLeqMovOp, COND_LEQ_MOV) \
@@ -83,432 +78,19 @@
 namespace mgb {
 namespace jit {
 
-template <typename Op>
-struct StandardOp;
-
-#define cb(mgb_op, fun)                                                       \
-    template <>                                                               \
-    struct StandardOp<mgb_op> {                                               \
-        mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,  \
-                               ValueRange operands) {                         \
-            ValueBuilderHelper helper(builder, loc);                          \
-            return helper.fun(operands);                                      \
-        }                                                                     \
-    }
-
-//! unary
-cb(AbsOp, abs);
-cb(NegOp, neg);
-cb(ExpOp, exp);
-cb(CosOp, cos);
-cb(CeilOp, ceil);
-cb(FloorOp, floor);
-cb(LogOp, log);
-cb(SinOp, sin);
-cb(TanhOp, tanh);
-
-//! binary
-cb(AddOp, add);
-cb(MaxOp, max);
-cb(MinOp, min);
-cb(SubOp, sub);
-cb(MulOp, mul);
-cb(ModOp, mod);
-cb(TrueDivOp, div);
-
-#undef cb
-
-/////////////////////////// unary op ///////////////////////////
-//! max(x, 0)
-template <>
-struct StandardOp<ReluOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return helper.max(operands[0], helper.const_val(0.f));
-    }
-};
-
-//! x * (27.f + x * x) / (27.f + 9.f * x * x);
-template <>
-struct StandardOp<FastTanhOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        auto square = helper.mul(operands[0], operands[0]);
-        return helper.div(
-                helper.mul(operands[0],
-                           helper.add(helper.const_val(27.f), square)),
-                helper.add(helper.const_val(27.f),
-                           helper.mul(helper.const_val(9.f), square)));
-    }
-};
-
-//! x * clip(x + 3, 0, 6) / 6
-template <>
-struct StandardOp<HswishOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-
-        auto const_3 = helper.const_val(3.f);
-        auto const_0 = helper.const_val(0.f);
-        auto const_6 = helper.const_val(6.f);
-        auto tmp = helper.add(operands[0], const_3);
-        return helper.div(
-                helper.mul(operands[0],
-                           helper.min(helper.max(tmp, const_0), const_6)),
-                const_6);
-    }
-};
-
-//! log(1 + p)
-template <>
-struct StandardOp<Log1POp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return helper.log(helper.add(operands[0], helper.const_val(1.f)));
-    }
-};
-
-//! 1.f / (expf(-y) + 1.f))
-template <>
-struct StandardOp<SigmoidOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return helper.div(helper.const_val(1.f),
-                          helper.add(helper.exp(helper.neg(operands[0])),
-                                     helper.const_val(1.f)));
-    }
-};
-
-//! exp(x) - 1
-template <>
-struct StandardOp<ExpM1Op> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return helper.sub(helper.exp(operands[0]), helper.const_val(1.f));
-    }
-};
-
-template <>
-struct StandardOp<RoundOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return helper.select(
-                helper.gt(operands[0], helper.const_val(0.f)),
-                helper.floor(helper.add(operands[0], helper.const_val(0.5f))),
-                helper.ceil(helper.sub(operands[0], helper.const_val(0.5f))));
-    }
-};
-
-//! pi / 2 - arctan2(x, sqrt(1 - x * x))
-template <>
-struct StandardOp<AcosOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        auto x = operands[0];
-        auto one_minus_x_2 = helper.sub(helper.const_val(1.f), helper.mul(x, x));
-        auto asin = atan2_approx(helper, x, helper.sqrt(one_minus_x_2));
-        auto pi_over_2 = helper.const_val(1.57079637f);
-        return helper.sub(pi_over_2, asin);
-    }
-};
-
-//! arctan2(x, sqrt(1 - x * x))
-template <>
-struct StandardOp<AsinOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        auto x = operands[0];
-        auto one_minus_x_2 = helper.sub(helper.const_val(1.f), helper.mul(x, x));
-        return atan2_approx(helper, x, helper.sqrt(one_minus_x_2));
-    }
-};
-
-//! gauss error function
-template <>
-struct StandardOp<ErfOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return erf_approx(helper, operands[0]);
-    }
-};
-
-//! inverse of gauss error function
-//! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c
-template <>
-struct StandardOp<ErfInvOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        auto sqrt2 = helper.const_val(1.4142135623f);
-        auto x = helper.mul(helper.const_val(0.5f),
-                            helper.add(operands[0], helper.const_val(1.f)));
-        return helper.div(ndtri_approx(helper, x), sqrt2);
-    }
-};
-
-//! complementary error function
-template <>
-struct StandardOp<ErfCOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return helper.sub(helper.const_val(1.f), erf_approx(helper, operands[0]));
-    }
-};
-
-//! inverse of complementary gauss error function
-//! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c
-template <>
-struct StandardOp<ErfCInvOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        auto minus_sqrt2 = helper.const_val(-1.4142135623f);
-        auto x = helper.mul(helper.const_val(0.5f), operands[0]);
-        return helper.div(ndtri_approx(helper, x), minus_sqrt2);
-    }
-};
-
-/////////////////////////// binary op ///////////////////////////
-
-//! binary: x > 0 ? y : -y
-template <>
-struct StandardOp<AbsGradOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return helper.select(helper.gt(operands[0], helper.const_val(0.f)),
-                             operands[1], helper.neg(operands[1]));
-    }
-};
-
-//! x^y = exp(y * log(x))
-template <>
-struct StandardOp<PowOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return helper.exp(helper.mul(operands[1], helper.log(operands[0])));
-    }
-};
-
-//! x * (1 - x) * y
-template <>
-struct StandardOp<SigmoidGradOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return helper.mul(
-                helper.mul(operands[0],
-                           helper.sub(helper.const_val(1.f), operands[0])),
-                operands[1]);
-    }
-};
-
-//! (x > 0) * y
-template <>
-struct StandardOp<SwishGt0Op> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return helper.select(helper.gt(operands[0], helper.const_val(0.f)),
-                             operands[1], helper.const_val(0.f));
-    }
-};
-
-//! (1 - x * x) * y
-template <>
-struct StandardOp<TanhGradOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return helper.mul(helper.sub(helper.const_val(1.0f),
-                                     helper.mul(operands[0], operands[0])),
-                          operands[1]);
-    }
-};
-
-#define cb(op, fun)                                                           \
-    template <>                                                               \
-    struct StandardOp<op> {                                                   \
-        mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,  \
-                               ValueRange operands) {                         \
-            ValueBuilderHelper helper(builder, loc);                          \
-            return helper.select(helper.fun(operands[0], operands[1]),        \
-                                 helper.const_val(1.f),                       \
-                                 helper.const_val(0.f));                      \
-        }                                                                     \
-    }
-
-cb(LtOp, lt);
-cb(LeqOp, le);
-cb(EqOp, eq);
-#undef cb
-
-//! (x + y) <= ctype(0) ? ctype(0) : (x + y)
-template <>
-struct StandardOp<FuseAddReluOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        auto sum = helper.add(operands[0], operands[1]);
-        return helper.max(sum, helper.const_val(0.f));
-    }
-};
-
-//! log(exp(x) + exp(y))
-template <>
-struct StandardOp<LogSumExpOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return helper.log(
-                helper.add(helper.exp(operands[0]), helper.exp(operands[1])));
-    }
-};
-
-//! floor(x/y)
-template <>
-struct StandardOp<FloorDivOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return helper.floor(helper.div(operands[0], operands[1]));
-    }
-};
-
-//! tanh(x + y)
-template <>
-struct StandardOp<FuseAddTanhOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return helper.tanh(helper.add(operands[0], operands[1]));
-    }
-};
-
-//! ((-48.f * x * x) / (3.f + x * x) + 27.f + x * x) / (3.f + x * x) * y
-template <>
-struct StandardOp<FastTanhGradOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        auto x_pow2 = helper.mul(operands[0], operands[0]);
-        auto deno = helper.add(helper.const_val(3.f), x_pow2);
-        return helper.mul(
-                helper.div(
-                        helper.add(
-                                helper.add(
-                                        helper.div(helper.mul(helper.const_val(
-                                                                      -48.f),
-                                                              x_pow2),
-                                                   deno),
-                                        helper.const_val(27.f)),
-                                x_pow2),
-                        helper.mul(deno, helper.const_val(9.f))),
-                operands[1]);
-    }
-};
-
-//! 1.f / (expf(-(x+y)) + 1.f))
-template <>
-struct StandardOp<FuseAddSigmoidOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return helper.div(helper.const_val(1.f),
-                          helper.add(helper.exp(helper.neg(helper.add(
-                                             operands[0], operands[1]))),
-                                     helper.const_val(1.f)));
-    }
-};
-
-//! x < -3.f ? 0.f : (x > 3.f ? y : (2.f * x + 3.f) / 6.f * y)
-template <>
-struct StandardOp<HswishGradOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return helper.select(
-                helper.lt(operands[0], helper.const_val(-3.f)),
-                helper.const_val(0.f),
-                helper.select(
-                        helper.gt(operands[0], helper.const_val(3.f)),
-                        operands[1],
-                        helper.mul(
-                                helper.div(
-                                        helper.add(helper.mul(helper.const_val(
-                                                                      2.f),
-                                                              operands[0]),
-                                                   helper.const_val(3.f)),
-                                        helper.const_val(6.f)),
-                                operands[1])));
-    }
-};
-
-//! (x+y) * min(max(x + y + 3, 0), 6) * (1/6)
-template <>
-struct StandardOp<FuseAddHswishOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        auto sum = helper.add(operands[0], operands[1]);
-
-        auto const_3 = helper.const_val(3.f);
-        auto const_0 = helper.const_val(0.f);
-        auto const_6 = helper.const_val(6.f);
-        auto tmp = helper.add(sum, const_3);
-        return helper.div(
-                helper.mul(sum, helper.min(helper.max(tmp, const_0), const_6)),
-                const_6);
-    }
-};
-
-//! arctan
-template <>
-struct StandardOp<Atan2Op> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return atan2_approx(helper, operands[0], operands[1]);
-    }
-};
-
-/////////////////////////// ternary op ///////////////////////////
-//! x <= y ? z : ctype(0)
-template <>
-struct StandardOp<CondLeqMovOp> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return helper.select(helper.le(operands[0], operands[1]), operands[2],
-                             helper.const_val(0.f));
-    }
-};
+mlir::Value lower_elemwise_to_std(mlir::Operation* op,
+                                  mlir::OpBuilder& builder,
+                                  mlir::Location loc,
+                                  mlir::ValueRange operands);
 
-//! x * y + z
-template <>
-struct StandardOp<FuseMulAdd3Op> {
-    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
-                           ValueRange operands) {
-        ValueBuilderHelper helper(builder, loc);
-        return helper.add(helper.mul(operands[0], operands[1]), operands[2]);
-    }
-};
+mlir::Value lower_typecvt_to_std(mlir::Operation* op,
+                                 mlir::OpBuilder& builder,
+                                 mlir::Location loc,
+                                 mlir::Value input);
 
 }  // namespace jit
 }  // namespace mgb
 
-#endif  // MGB_JIT_MLIR
+#endif  // MGB_JIT && MGB_JIT_MLIR
 
 // vim: syntax=cpp.doxygen
diff --git a/src/jit/impl/mlir/ir/interfaces.td b/src/jit/impl/mlir/ir/interfaces.td
deleted file mode 100644
index e5ca677830aedc4df8469f43b814c08c42f5660c..0000000000000000000000000000000000000000
--- a/src/jit/impl/mlir/ir/interfaces.td
+++ /dev/null
@@ -1,33 +0,0 @@
-/**
- * \file src/jit/impl/mlir/ir/interfaces.td
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
- * implied.
- */
-
-#ifndef MGB_MLIR_INTERFACES
-#define MGB_MLIR_INTERFACES
-
-#ifndef OP_BASE
-include "mlir/IR/OpBase.td"
-#endif
-
-def GenericBuilderInterface : OpInterface<"GenericBuilder"> {
-  let methods = [
-    StaticInterfaceMethod<"TODO", "Type", "getResultType", (ins "ArrayRef":$operands)>,
-    StaticInterfaceMethod<"TODO", "Operation*", "create", (ins
-      "OpBuilder*":$builder,
-      "Location":$loc,
-      "ArrayRef":$operands
-    )>,
-  ];
-}
-
-def ElemwiseOpInterface : OpInterface<"ElemwiseOp">;
-
-#endif
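As a quick host-side sanity check on the FAST_TANH lowering defined in `each_mode.cpp` above, the rational approximation x * (27 + x^2) / (27 + 9 * x^2) can be compared against `tanhf` directly; a small standalone check (sample points chosen arbitrarily):

```cpp
#include <cmath>
#include <cstdio>

// Same rational form the FAST_TANH mode emits as std-dialect ops.
float fast_tanh(float x) {
    float x2 = x * x;
    return x * (27.f + x2) / (27.f + 9.f * x2);
}

int main() {
    for (float x : {-1.f, -0.5f, 0.f, 0.5f, 1.f})
        std::printf("x=% .2f fast=% .5f tanh=% .5f\n", x, fast_tanh(x),
                    std::tanh(x));
    return 0;
}
```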
diff --git a/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp b/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp
index dee486a8ad6e3c7e4da6759141c988012018e6bc..d3c16a3846854b2df4a2c182e39daa6943c0d0d2 100644
--- a/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp
+++ b/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp
@@ -13,18 +13,19 @@
 #include "megbrain_build_config.h"
 #if MGB_JIT && MGB_JIT_MLIR
 
+#include "./common.h"
+#include "./each_mode.h"
+
 #include "megbrain/common.h"
 #include "megbrain/jit/mlir/ir/dialect.h"
 #include "megbrain/jit/mlir/ir/passes.h"
 #include "megbrain/jit/mlir/ir/utils.h"
 
-#include "./each_mode.h"
-
 #include 
 #include 
+#include 
 #include 
 #include 
-#include "mlir/IR/StandardTypes.h"
 
 using namespace mgb;
 using namespace jit;
@@ -57,41 +58,10 @@ void lower_op_to_loops(Operation* op, ValueRange operands,
     rewriter.replaceOp(op, alloc);
 }
 
-template <typename Op, typename LoweredOp>
-struct UnaryOpLowering : public ConversionPattern {
-    UnaryOpLowering(MLIRContext* ctx)
-            : ConversionPattern(Op::getOperationName(), 1, ctx) {}
-
-    LogicalResult matchAndRewrite(
-            Operation* op, ArrayRef<Value> operands,
-            ConversionPatternRewriter& rewriter) const final {
-        auto loc = op->getLoc();
-        lower_op_to_loops(
-                op, operands, rewriter,
-                [loc](OpBuilder& builder, ValueRange memref_operands,
-                      ValueRange loop_ivs) {
-                    typename Op::Adaptor binary_adaptor(memref_operands);
-                    LoweredOp lower_op;
-
-                    auto loaded_lhs = get_operand(
-                            builder, loc, binary_adaptor.lhs(), loop_ivs);
-
-                    return lower_op(builder, loc, {loaded_lhs});
-                });
-        return success();
-    }
-};
-
-#define cb(_op, _) \
-    using _op##Lowering = UnaryOpLowering<jit::_op, StandardOp<jit::_op>>;
-MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)
-#undef cb
-
-template <typename Op, typename LoweredOp>
-struct BinaryOpLowering : public ConversionPattern {
-    BinaryOpLowering(MLIRContext* ctx)
-            : ConversionPattern(Op::getOperationName(), 1, ctx) {}
-
+struct ElemwiseLowering : public ConversionPattern {
+    ElemwiseLowering(MLIRContext* ctx)
+            : ConversionPattern(mgb::dialect::Elemwise::getOperationName(), 1,
+                                ctx) {}
     LogicalResult matchAndRewrite(
             Operation* op, ArrayRef<Value> operands,
             ConversionPatternRewriter& rewriter) const final {
@@ -101,83 +71,51 @@ struct BinaryOpLowering : public ConversionPattern {
         dst_layout.init_contiguous_stride();
         lower_op_to_loops(
                 op, operands, rewriter,
-                [dst_layout, loc, this](OpBuilder& builder,
-                                        ValueRange memref_operands,
-                                        ValueRange loop_ivs) {
-                    typename Op::Adaptor binary_adaptor(memref_operands);
-                    LoweredOp lower_op;
-
-                    auto loaded_lhs = get_affine_load_op(builder, loc,
-                                                         binary_adaptor.lhs(),
-                                                         loop_ivs, dst_layout);
-                    auto loaded_rhs = get_affine_load_op(builder, loc,
-                                                         binary_adaptor.rhs(),
-                                                         loop_ivs, dst_layout);
-
-                    return lower_op(builder, loc, {loaded_lhs, loaded_rhs});
+                [dst_layout, loc, op](OpBuilder& builder,
+                                      ValueRange memref_operands,
+                                      ValueRange loop_ivs) {
+                    auto inputs = llvm::to_vector<4>(llvm::map_range(
+                            memref_operands, [&](mlir::Value val) {
+                                return get_affine_load_op(builder, loc, val,
+                                                          loop_ivs, dst_layout);
+                            }));
+                    return lower_elemwise_to_std(op, builder, loc, inputs);
                 });
         return success();
     }
 };
 
-#define cb(_op, _) \
-    using _op##Lowering = BinaryOpLowering<jit::_op, StandardOp<jit::_op>>;
-MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
-#undef cb
-
-template <typename Op, typename LoweredOp>
-struct TernaryOpLowering : public ConversionPattern {
-    TernaryOpLowering(MLIRContext* ctx)
-            : ConversionPattern(Op::getOperationName(), 1, ctx) {}
-
+struct TypeCvtLowering : public ConversionPattern {
+    TypeCvtLowering(MLIRContext* ctx)
+            : ConversionPattern(mgb::dialect::TypeCvt::getOperationName(), 1,
+                                ctx) {}
     LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands,
             ConversionPatternRewriter& rewriter) const final {
         auto loc = op->getLoc();
-        auto dst_memref_type = (*op->result_type_begin()).cast<MemRefType>();
-        megdnn::TensorLayout dst_layout = mlir_type_to_layout(dst_memref_type);
-        dst_layout.init_contiguous_stride();
         lower_op_to_loops(
                 op, operands, rewriter,
-                [dst_layout, loc](OpBuilder& builder,
-                                  ValueRange memref_operands,
-                                  ValueRange loop_ivs) {
-                    typename Op::Adaptor ternary_adaptor(memref_operands);
-                    LoweredOp lower_op;
-
-                    auto loaded_x = get_affine_load_op(builder, loc,
-                                                       ternary_adaptor.x(),
-                                                       loop_ivs, dst_layout);
-                    auto loaded_y = get_affine_load_op(builder, loc,
-                                                       ternary_adaptor.y(),
-                                                       loop_ivs, dst_layout);
-                    auto loaded_z = get_affine_load_op(builder, loc,
-                                                       ternary_adaptor.z(),
-                                                       loop_ivs, dst_layout);
-
-                    return lower_op(builder, loc,
-                                    {loaded_x, loaded_y, loaded_z});
+                [loc, op](OpBuilder& builder, ValueRange memref_operands,
+                          ValueRange loop_ivs) {
+                    mlir::Value input = get_operand(
+                            builder, loc, memref_operands[0], loop_ivs);
+                    return lower_typecvt_to_std(op, builder, loc, input);
                 });
         return success();
     }
 };
 
-#define cb(_op, _)                 \
-    using _op##Lowering =          \
-            TernaryOpLowering<jit::_op, StandardOp<jit::_op>>;
-MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
-#undef cb
-
 struct AssignOpLowering : public ConversionPattern {
     AssignOpLowering(MLIRContext* ctx)
-            : ConversionPattern(jit::AssignOp::getOperationName(), 1, ctx) {}
+            : ConversionPattern(dialect::AssignOp::getOperationName(), 1, ctx) {
+    }
 
     LogicalResult matchAndRewrite(
             Operation* op, ArrayRef<Value> operands,
             ConversionPatternRewriter& rewriter) const final {
         auto loc = op->getLoc();
         auto memref_type = operands[0].getType().cast<MemRefType>();
-        AssignOpAdaptor assign_adaptor(operands);
+        dialect::AssignOpAdaptor assign_adaptor(operands);
 
         llvm::SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
         llvm::SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
@@ -195,10 +133,10 @@ struct AssignOpLowering : public ConversionPattern {
     }
 };
 
-struct ReturnOpLowering : public OpRewritePattern<jit::ReturnOp> {
-    using OpRewritePattern<jit::ReturnOp>::OpRewritePattern;
+struct ReturnOpLowering : public OpRewritePattern<dialect::ReturnOp> {
+    using OpRewritePattern<dialect::ReturnOp>::OpRewritePattern;
 
-    LogicalResult matchAndRewrite(jit::ReturnOp op,
+    LogicalResult matchAndRewrite(dialect::ReturnOp op,
                                   PatternRewriter& rewriter) const final {
         // We lower "mgb.return" directly to "std.return".
        rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op);
@@ -207,12 +145,12 @@ struct ReturnOpLowering : public OpRewritePattern<dialect::ReturnOp> {
 };
 
 struct ConstantScalarOpLowering
-        : public OpRewritePattern<jit::ConstantScalarOp> {
-    using OpRewritePattern<jit::ConstantScalarOp>::OpRewritePattern;
+        : public OpRewritePattern<dialect::ConstantScalarOp> {
+    using OpRewritePattern<dialect::ConstantScalarOp>::OpRewritePattern;
 
-    LogicalResult matchAndRewrite(jit::ConstantScalarOp op,
+    LogicalResult matchAndRewrite(dialect::ConstantScalarOp op,
                                   PatternRewriter& rewriter) const final {
-        ConstantScalarOpAdaptor constant_scalar_adaptor(op);
+        dialect::ConstantScalarOpAdaptor constant_scalar_adaptor(op);
         rewriter.replaceOpWithNewOp<mlir::ConstantOp>(
                 op, constant_scalar_adaptor.value());
         return success();
@@ -234,14 +172,9 @@ public:
         target.addIllegalDialect<MgbDialect>();
 
         OwningRewritePatternList patterns;
-#define cb(_op, _) _op##Lowering,
-        patterns.insert<MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)
-                                MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
-                                        MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
-                        ReturnOpLowering, ConstantScalarOpLowering,
-                        AssignOpLowering>(&getContext());
-#undef cb
+        patterns.insert<ElemwiseLowering, TypeCvtLowering, ReturnOpLowering,
+                        ConstantScalarOpLowering, AssignOpLowering>(
+                &getContext());
 
         if (failed(applyPartialConversion(getFunction(), target, patterns))) {
             signalPassFailure();
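The GPU pass below maps each output element to one linear thread id and rebuilds the multi-dimensional coordinates with the `modI`/`divI` chain in `get_multidim_tid`, peeling the innermost dimension first. The same arithmetic on the host, for a hypothetical 2x3x4 layout:

```cpp
#include <cstdio>

int main() {
    int shape[3] = {2, 3, 4};
    int linear = 17;  // hypothetical thread id
    int idx[3];
    int rest = linear;
    // Innermost dimension first, exactly as the emitted modI/divI chain does.
    for (int i = 2; i >= 0; i--) {
        idx[i] = rest % shape[i];
        rest /= shape[i];
    }
    std::printf("(%d, %d, %d)\n", idx[0], idx[1], idx[2]);  // prints (1, 1, 1)
    return 0;
}
```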
diff --git a/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp b/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp
index 036838726a2eeb8beff06a45dd91356d1423c7e8..e90e1be16f177c4705e9fa74f9b891252149b350 100644
--- a/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp
+++ b/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp
@@ -13,12 +13,19 @@
 #include "megbrain_build_config.h"
 #if MGB_JIT && MGB_JIT_MLIR
 
+#include "./common.h"
 #include "./each_mode.h"
+
 #include "megbrain/common.h"
 #include "megbrain/jit/mlir/ir/dialect.h"
 #include "megbrain/jit/mlir/ir/passes.h"
 #include "megbrain/jit/mlir/ir/utils.h"
 
+#include 
+#include 
+#include 
+#include 
+#include 
 #include 
 #include 
 #include 
@@ -27,12 +34,6 @@
 #include 
 #include 
 
-#include 
-#include 
-#include 
-#include 
-#include 
-
 using namespace mgb;
 using namespace jit;
 
@@ -59,7 +60,7 @@ megdnn::TensorLayout output_layout(gpu::LaunchOp& launch_op) {
          block_iter++) {
         for (auto op_iter = block_iter->rbegin();
              op_iter != block_iter->rend(); op_iter++) {
-            auto op = llvm::dyn_cast_or_null<jit::AssignOp>(&(*op_iter));
+            auto op = llvm::dyn_cast_or_null<dialect::AssignOp>(&(*op_iter));
             if (op && op.getNumOperands() > 0) {
                 return mlir_type_to_layout(*(op.operand_type_begin()));
             }
@@ -81,64 +82,27 @@ std::vector<mlir::Value> get_multidim_tid(ConversionPatternRewriter& rewriter,
         idxs.resize(dst.ndim);
         mlir::Value dim_index = index;
         for (int i = dst.ndim - 1; i >= 0; i--) {
-            auto cur_index = helper.modI(dim_index, helper.constI(dst[i]));
+            auto cur_index = helper.modI(dim_index, helper.const_i32(dst[i]));
             idxs[i] = cur_index;
-            dim_index = helper.divI(dim_index, helper.constI(dst[i]));
+            dim_index = helper.divI(dim_index, helper.const_i32(dst[i]));
         }
 
         megdnn::TensorLayout src_layout = mlir_type_to_layout(type);
         src_layout.init_contiguous_stride();
         for (int i = 0; i < type.getRank(); ++i) {
             if (src_layout[i] == 1) {
-                idxs[i] = helper.constI(0);
+                idxs[i] = helper.const_i32(0);
             }
         }
         return idxs;
     } else {
         return {index};
     }
-
 }
 
-template <typename Op, typename LoweredOp>
-struct UnaryOpLowering : public ConversionPattern {
-    UnaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
-            : ConversionPattern(Op::getOperationName(), 1, ctx),
-              m_launch_op{launch_op} {}
-
-    LogicalResult matchAndRewrite(
-            Operation* op, ArrayRef<Value> operands,
-            ConversionPatternRewriter& rewriter) const final {
-        auto loc = op->getLoc();
-
-        typename Op::Adaptor binary_adaptor(operands);
-        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
-
-        auto dst_layout = output_layout(m_launch_op);
-        auto index = get_multidim_tid(rewriter, loc, binary_adaptor.lhs(),
-                                      dst_layout);
-        auto loaded_lhs =
-                get_operand(rewriter, loc, binary_adaptor.lhs(), index);
-
-        LoweredOp lower_op;
-
-        rewriter.replaceOp(op, lower_op(rewriter, loc, {loaded_lhs}));
-        return success();
-    }
-
-private:
-    gpu::LaunchOp& m_launch_op;
-};
-
-#define cb(_op, _) \
-    using _op##Lowering = UnaryOpLowering<jit::_op, StandardOp<jit::_op>>;
-MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)
-#undef cb
-
-template <typename Op, typename LoweredOp>
-struct BinaryOpLowering : public ConversionPattern {
-    BinaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
-            : ConversionPattern(Op::getOperationName(), 1, ctx),
+struct ElemwiseLowering : public ConversionPattern {
+    ElemwiseLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
+            : ConversionPattern(dialect::Elemwise::getOperationName(), 1, ctx),
               m_launch_op{launch_op} {}
 
     LogicalResult matchAndRewrite(
@@ -146,23 +110,18 @@ struct BinaryOpLowering : public ConversionPattern {
             ConversionPatternRewriter& rewriter) const final {
         auto loc = op->getLoc();
 
-        typename Op::Adaptor binary_adaptor(operands);
         rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
 
         auto dst_layout = output_layout(m_launch_op);
-        auto lhs_index = get_multidim_tid(rewriter, loc, binary_adaptor.lhs(),
-                                          dst_layout);
-        auto rhs_index = get_multidim_tid(rewriter, loc, binary_adaptor.rhs(),
-                                          dst_layout);
-        auto loaded_lhs = get_operand(rewriter, loc,
-                                      binary_adaptor.lhs(), lhs_index);
-        auto loaded_rhs = get_operand(rewriter, loc,
-                                      binary_adaptor.rhs(), rhs_index);
-
-        LoweredOp lower_op;
+        auto inputs = llvm::to_vector<4>(
+                llvm::map_range(operands, [&](mlir::Value val) {
+                    auto index =
+                            get_multidim_tid(rewriter, loc, val, dst_layout);
+                    return get_operand(rewriter, loc, val, index);
+                }));
 
         rewriter.replaceOp(op,
-                           lower_op(rewriter, loc, {loaded_lhs, loaded_rhs}));
+                           lower_elemwise_to_std(op, rewriter, loc, inputs));
         return success();
     }
 
@@ -170,43 +129,22 @@ private:
     gpu::LaunchOp& m_launch_op;
 };
 
-#define cb(_op, _) \
-    using _op##Lowering = BinaryOpLowering<jit::_op, StandardOp<jit::_op>>;
-MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
-#undef cb
-
-template <typename Op, typename LoweredOp>
-struct TernaryOpLowering : public ConversionPattern {
-    TernaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
-            : ConversionPattern(Op::getOperationName(), 1, ctx),
+struct TypeCvtLowering : public ConversionPattern {
+    TypeCvtLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
+            : ConversionPattern(dialect::TypeCvt::getOperationName(), 1, ctx),
               m_launch_op{launch_op} {}
-
     LogicalResult matchAndRewrite(
             Operation* op, ArrayRef<Value> operands,
             ConversionPatternRewriter& rewriter) const final {
         auto loc = op->getLoc();
 
-        typename Op::Adaptor ternary_adaptor(operands);
         rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
 
         auto dst_layout = output_layout(m_launch_op);
-        auto index_x = get_multidim_tid(rewriter, loc, ternary_adaptor.x(),
-                                        dst_layout);
-        auto index_y = get_multidim_tid(rewriter, loc, ternary_adaptor.y(),
-                                        dst_layout);
-        auto index_z = get_multidim_tid(rewriter, loc, ternary_adaptor.z(),
-                                        dst_layout);
-        auto loaded_x = get_operand(rewriter, loc, ternary_adaptor.x(),
-                                    index_x);
-        auto loaded_y = get_operand(rewriter, loc, ternary_adaptor.y(),
-                                    index_y);
-        auto loaded_z = get_operand(rewriter, loc, ternary_adaptor.z(),
-                                    index_z);
-
-        LoweredOp lower_op;
-
-        rewriter.replaceOp(
-                op, lower_op(rewriter, loc, {loaded_x, loaded_y, loaded_z}));
+        auto index = get_multidim_tid(rewriter, loc, operands[0], dst_layout);
+        auto input = get_operand(rewriter, loc, operands[0], index);
+
+        rewriter.replaceOp(op, lower_typecvt_to_std(op, rewriter, loc, input));
         return success();
     }
 
@@ -214,15 +152,9 @@ private:
     gpu::LaunchOp& m_launch_op;
 };
 
-#define cb(_op, _)             \
-    using _op##Lowering =      \
-            TernaryOpLowering<jit::_op, StandardOp<jit::_op>>;
-MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
-#undef cb
-
 struct ReturnOpLowering : public ConversionPattern {
     ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
-            : ConversionPattern(jit::ReturnOp::getOperationName(), 1, ctx),
+            : ConversionPattern(dialect::ReturnOp::getOperationName(), 1, ctx),
              m_launch_op{launch_op} {}
 
     LogicalResult matchAndRewrite(
@@ -270,14 +202,14 @@ private:
 };
 
 struct ConstantScalarOpLowering
-        : public OpRewritePattern<jit::ConstantScalarOp> {
+        : public OpRewritePattern<dialect::ConstantScalarOp> {
     ConstantScalarOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
-            : OpRewritePattern<jit::ConstantScalarOp>(ctx),
+            : OpRewritePattern<dialect::ConstantScalarOp>(ctx),
               m_launch_op{launch_op} {}
 
-    LogicalResult matchAndRewrite(jit::ConstantScalarOp op,
+    LogicalResult matchAndRewrite(dialect::ConstantScalarOp op,
                                   PatternRewriter& rewriter) const final {
-        ConstantScalarOpAdaptor constant_scalar_adaptor(op);
+        dialect::ConstantScalarOpAdaptor constant_scalar_adaptor(op);
         rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
 
         rewriter.replaceOpWithNewOp<mlir::ConstantOp>(
@@ -291,7 +223,7 @@ private:
 
 struct AssignOpLowering : public ConversionPattern {
     AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
-            : ConversionPattern(jit::AssignOp::getOperationName(), 2, ctx),
+            : ConversionPattern(dialect::AssignOp::getOperationName(), 2, ctx),
               m_launch_op{launch_op} {}
 
     LogicalResult matchAndRewrite(
@@ -299,7 +231,7 @@ struct AssignOpLowering : public ConversionPattern {
             ConversionPatternRewriter& rewriter) const final {
         auto loc = op->getLoc();
 
-        AssignOpAdaptor assign_adaptor(operands);
+        dialect::AssignOpAdaptor assign_adaptor(operands);
         rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
 
         auto dst_layout = output_layout(m_launch_op);
@@ -343,14 +275,9 @@ public:
         target.addLegalDialect();
         target.addIllegalDialect<MgbDialect>();
 
-#define cb(_op, _) _op##Lowering,
-        patterns.insert<MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)
-                                MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
-                                        MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
-                        ReturnOpLowering, ConstantScalarOpLowering,
-                        AssignOpLowering>(&getContext(), launch_op);
-#undef cb
+        patterns.insert<ElemwiseLowering, TypeCvtLowering, ReturnOpLowering,
+                        ConstantScalarOpLowering, AssignOpLowering>(
+                &getContext(), launch_op);
 
         if (failed(applyPartialConversion(func_op, target, patterns))) {
             signalPassFailure();
diff --git a/src/jit/impl/mlir/ir/numerical.cpp b/src/jit/impl/mlir/ir/numerical.cpp
index ee66cd9b9830c4c195886a6dac3df34dcee8e8ef..4d06ae6c2ac3053c1e0cbb76a9a5a60f95018aff 100644
--- a/src/jit/impl/mlir/ir/numerical.cpp
+++ b/src/jit/impl/mlir/ir/numerical.cpp
@@ -22,7 +22,7 @@ mlir::Value polynomial(ValueBuilderHelper& helper, mlir::Value x,
                        std::vector<mlir::Value>& coeff) {
     size_t n = coeff.size();
     if (n == 0) {
-        return helper.const_val(0);
+        return helper.const_f32(0);
     }
 
     mlir::Value r = coeff[0];
@@ -40,23 +40,23 @@ mlir::Value atan2_approx(ValueBuilderHelper& helper, mlir::Value y,
                          mlir::Value x) {
     auto atan_poly = [&](mlir::Value t) {
         std::vector<mlir::Value> coeff = {
-                helper.const_val(2.90188402868807315826416015625E-3),
-                helper.const_val(-1.62907354533672332763671875E-2),
-                helper.const_val(4.3082617223262786865234375E-2),
-                helper.const_val(-7.5408883392810821533203125E-2),
-                helper.const_val(0.1066047251224517822265625),
-                helper.const_val(-0.14209578931331634521484375),
-                helper.const_val(0.19993579387664794921875),
-                helper.const_val(-0.3333314359188079833984375)};
+
helper.const_f32(2.90188402868807315826416015625E-3), + helper.const_f32(-1.62907354533672332763671875E-2), + helper.const_f32(4.3082617223262786865234375E-2), + helper.const_f32(-7.5408883392810821533203125E-2), + helper.const_f32(0.1066047251224517822265625), + helper.const_f32(-0.14209578931331634521484375), + helper.const_f32(0.19993579387664794921875), + helper.const_f32(-0.3333314359188079833984375)}; auto t2 = helper.mul(t, t); auto p = polynomial(helper, t2, coeff); return helper.add(helper.mul(helper.mul(p, t2), t), t); }; // constants - auto zero = helper.const_val(0); - auto pi = helper.const_val(3.141592653589793); - auto pi_over_2 = helper.const_val(1.570796326794897); + auto zero = helper.const_f32(0); + auto pi = helper.const_f32(3.141592653589793); + auto pi_over_2 = helper.const_f32(1.570796326794897); // transform the angle into interval [0, pi/4] auto ax = helper.abs(x); @@ -83,23 +83,23 @@ mlir::Value atan2_approx(ValueBuilderHelper& helper, mlir::Value y, // original book: // Numerical Recipes in Fortran 77: The Art of Scientific Computing mlir::Value erf_approx(ValueBuilderHelper& helper, mlir::Value x) { - auto zero = helper.const_val(0); - auto one = helper.const_val(1); - auto half = helper.const_val(0.5); + auto zero = helper.const_f32(0); + auto one = helper.const_f32(1); + auto half = helper.const_f32(0.5); auto t = helper.div(one, helper.add(one, helper.mul(half, helper.abs(x)))); std::vector coeff = { - helper.const_val(0.17087277), - helper.const_val(-0.82215223), - helper.const_val(1.48851587), - helper.const_val(-1.13520398), - helper.const_val(0.27886807), - helper.const_val(-0.18628806), - helper.const_val(0.09678418), - helper.const_val(0.37409196), - helper.const_val(1.00002368), - helper.const_val(-1.26551223)}; + helper.const_f32(0.17087277), + helper.const_f32(-0.82215223), + helper.const_f32(1.48851587), + helper.const_f32(-1.13520398), + helper.const_f32(0.27886807), + helper.const_f32(-0.18628806), + helper.const_f32(0.09678418), + helper.const_f32(0.37409196), + helper.const_f32(1.00002368), + helper.const_f32(-1.26551223)}; auto p = polynomial(helper, t, coeff); auto r = helper.mul(t, helper.exp(helper.sub(p, helper.mul(x, x)))); @@ -130,25 +130,25 @@ mlir::Value ndtri_approx(ValueBuilderHelper& helper, mlir::Value x) { // polynomial P auto P = [&](mlir::Value i, mlir::Value cond) { std::vector coeff0 = { - helper.const_val(4.05544892305962419923E0), - helper.const_val(3.15251094599893866154E1), - helper.const_val(5.71628192246421288162E1), - helper.const_val(4.40805073893200834700E1), - helper.const_val(1.46849561928858024014E1), - helper.const_val(2.18663306850790267539E0), - helper.const_val(-1.40256079171354495875E-1), - helper.const_val(-3.50424626827848203418E-2), - helper.const_val(-8.57456785154685413611E-4)}; + helper.const_f32(4.05544892305962419923E0), + helper.const_f32(3.15251094599893866154E1), + helper.const_f32(5.71628192246421288162E1), + helper.const_f32(4.40805073893200834700E1), + helper.const_f32(1.46849561928858024014E1), + helper.const_f32(2.18663306850790267539E0), + helper.const_f32(-1.40256079171354495875E-1), + helper.const_f32(-3.50424626827848203418E-2), + helper.const_f32(-8.57456785154685413611E-4)}; std::vector coeff1 = { - helper.const_val(3.23774891776946035970E0), - helper.const_val(6.91522889068984211695E0), - helper.const_val(3.93881025292474443415E0), - helper.const_val(1.33303460815807542389E0), - helper.const_val(2.01485389549179081538E-1), - helper.const_val(1.23716634817820021358E-2), - 
helper.const_val(3.01581553508235416007E-4), - helper.const_val(2.65806974686737550832E-6), - helper.const_val(6.23974539184983293730E-9)}; + helper.const_f32(3.23774891776946035970E0), + helper.const_f32(6.91522889068984211695E0), + helper.const_f32(3.93881025292474443415E0), + helper.const_f32(1.33303460815807542389E0), + helper.const_f32(2.01485389549179081538E-1), + helper.const_f32(1.23716634817820021358E-2), + helper.const_f32(3.01581553508235416007E-4), + helper.const_f32(2.65806974686737550832E-6), + helper.const_f32(6.23974539184983293730E-9)}; return helper.select(cond, polynomial(helper, i, coeff0), polynomial(helper, i, coeff1)); @@ -157,25 +157,25 @@ mlir::Value ndtri_approx(ValueBuilderHelper& helper, mlir::Value x) { // polynomial Q auto Q = [&](mlir::Value i, mlir::Value cond) { std::vector coeff0 = { - helper.const_val(1.f), - helper.const_val(1.57799883256466749731E1), - helper.const_val(4.53907635128879210584E1), - helper.const_val(4.13172038254672030440E1), - helper.const_val(1.50425385692907503408E1), - helper.const_val(2.50464946208309415979E0), - helper.const_val(-1.42182922854787788574E-1), - helper.const_val(-3.80806407691578277194E-2), - helper.const_val(-9.33259480895457427372E-4)}; + helper.const_f32(1.f), + helper.const_f32(1.57799883256466749731E1), + helper.const_f32(4.53907635128879210584E1), + helper.const_f32(4.13172038254672030440E1), + helper.const_f32(1.50425385692907503408E1), + helper.const_f32(2.50464946208309415979E0), + helper.const_f32(-1.42182922854787788574E-1), + helper.const_f32(-3.80806407691578277194E-2), + helper.const_f32(-9.33259480895457427372E-4)}; std::vector coeff1 = { - helper.const_val(1.f), - helper.const_val(6.02427039364742014255E0), - helper.const_val(3.67983563856160859403E0), - helper.const_val(1.37702099489081330271E0), - helper.const_val(2.16236993594496635890E-1), - helper.const_val(1.34204006088543189037E-2), - helper.const_val(3.28014464682127739104E-4), - helper.const_val(2.89247864745380683936E-6), - helper.const_val(6.79019408009981274425E-9)}; + helper.const_f32(1.f), + helper.const_f32(6.02427039364742014255E0), + helper.const_f32(3.67983563856160859403E0), + helper.const_f32(1.37702099489081330271E0), + helper.const_f32(2.16236993594496635890E-1), + helper.const_f32(1.34204006088543189037E-2), + helper.const_f32(3.28014464682127739104E-4), + helper.const_f32(2.89247864745380683936E-6), + helper.const_f32(6.79019408009981274425E-9)}; return helper.select(cond, polynomial(helper, i, coeff0), polynomial(helper, i, coeff1)); @@ -184,37 +184,37 @@ mlir::Value ndtri_approx(ValueBuilderHelper& helper, mlir::Value x) { // polynomial R auto R = [&](mlir::Value i) { std::vector coeff = { - helper.const_val(-5.99633501014107895267E1), - helper.const_val(9.80010754185999661536E1), - helper.const_val(-5.66762857469070293439E1), - helper.const_val(1.39312609387279679503E1), - helper.const_val(-1.23916583867381258016E0)}; + helper.const_f32(-5.99633501014107895267E1), + helper.const_f32(9.80010754185999661536E1), + helper.const_f32(-5.66762857469070293439E1), + helper.const_f32(1.39312609387279679503E1), + helper.const_f32(-1.23916583867381258016E0)}; return polynomial(helper, i, coeff); }; // polynomial S auto S = [&](mlir::Value i) { std::vector coeff = { - helper.const_val(1.f), - helper.const_val(1.95448858338141759834E0), - helper.const_val(4.67627912898881538453E0), - helper.const_val(8.63602421390890590575E1), - helper.const_val(-2.25462687854119370527E2), - helper.const_val(2.00260212380060660359E2), - 
helper.const_val(-8.20372256168333339912E1), - helper.const_val(1.59056225126211695515E1), - helper.const_val(-1.18331621121330003142E0)}; + helper.const_f32(1.f), + helper.const_f32(1.95448858338141759834E0), + helper.const_f32(4.67627912898881538453E0), + helper.const_f32(8.63602421390890590575E1), + helper.const_f32(-2.25462687854119370527E2), + helper.const_f32(2.00260212380060660359E2), + helper.const_f32(-8.20372256168333339912E1), + helper.const_f32(1.59056225126211695515E1), + helper.const_f32(-1.18331621121330003142E0)}; return polynomial(helper, i, coeff); }; // constants - auto zero = helper.const_val(0); - auto one = helper.const_val(1); - auto half = helper.const_val(0.5); - auto eight = helper.const_val(8); - auto minus_2 = helper.const_val(-2); - auto exp_minus_2 = helper.const_val(0.135335283236); // exp(-2) - auto sqrt_2pi = helper.const_val(2.506628274631); // sqrt(2pi) + auto zero = helper.const_f32(0); + auto one = helper.const_f32(1); + auto half = helper.const_f32(0.5); + auto eight = helper.const_f32(8); + auto minus_2 = helper.const_f32(-2); + auto exp_minus_2 = helper.const_f32(0.135335283236); // exp(-2) + auto sqrt_2pi = helper.const_f32(2.506628274631); // sqrt(2pi) // conditions auto case1 = helper.lt(x, exp_minus_2); // x < exp(-2) diff --git a/src/jit/impl/mlir/ir/ops.td b/src/jit/impl/mlir/ir/ops.td deleted file mode 100644 index 960897604e167aab67abaab637041a65cf56c170..0000000000000000000000000000000000000000 --- a/src/jit/impl/mlir/ir/ops.td +++ /dev/null @@ -1,216 +0,0 @@ -/** - * \file src/jit/impl/mlir/ir/ops.td - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. 
- */ - -#ifndef MGB_MLIR_OPS -#define MGB_MLIR_OPS - -include "mlir/IR/OpBase.td" -include "mlir/Interfaces/SideEffectInterfaces.td" - -include "./interfaces.td" -include "./predicates.td" - -def Mgb_Dialect : Dialect { - let name = "mgb"; - let cppNamespace = "mgb::jit"; -} - -class ElemwiseBuilderImpl { - code ElemwiseBuilderImpl_create = [{ - static Operation* create(OpBuilder* builder, Location loc, ValueRange operands) { - OperationState state(loc, getOperationName()); - state.addOperands(operands); - state.addTypes(getResultType(operands)); - return builder->createOperation(state); - } - }]; -} - -class ElemwiseOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> : - Op<Mgb_Dialect, mnemonic, traits>, ElemwiseBuilderImpl; - -class GenericOp<string mnemonic, list<OpTrait> traits = []> : - Op<Mgb_Dialect, mnemonic, traits>; - -class ElemwiseUnaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> : - ElemwiseOp<mnemonic, traits> { - let arguments = (ins F32MemRef:$lhs); - let results = (outs F32MemRef); - - let builders = [OpBuilder< - "Builder* builder, OperationState& result, ValueRange operands", [{ - result.addOperands(operands); - result.addTypes(getResultType(operands)); - }]>, OpBuilder < - "OpBuilder& builder, OperationState& result, Value lhs", [{ - result.addOperands(lhs); - result.addTypes(getResultType({lhs})); - }] - >]; - - let extraClassDeclaration = [{ - static Type getResultType(ValueRange operands) { - return deduce_result_type(operands); - } - }] # ElemwiseBuilderImpl_create; -} - -def ReluOp : ElemwiseUnaryOp<"relu", [NoSideEffect]>; -def AbsOp : ElemwiseUnaryOp<"abs", [NoSideEffect]>; -def NegOp : ElemwiseUnaryOp<"negate", [NoSideEffect]>; -def AcosOp : ElemwiseUnaryOp<"acos", [NoSideEffect]>; -def AsinOp : ElemwiseUnaryOp<"asin", [NoSideEffect]>; -def CeilOp : ElemwiseUnaryOp<"ceil", [NoSideEffect]>; -def CosOp : ElemwiseUnaryOp<"cos", [NoSideEffect]>; -def ExpOp : ElemwiseUnaryOp<"exp", [NoSideEffect]>; -def ExpM1Op : ElemwiseUnaryOp<"expm1", [NoSideEffect]>; -def FloorOp : ElemwiseUnaryOp<"floor", [NoSideEffect]>; -def LogOp : ElemwiseUnaryOp<"log", [NoSideEffect]>; -def Log1POp : ElemwiseUnaryOp<"log1p", [NoSideEffect]>; -def SigmoidOp: ElemwiseUnaryOp<"sigmoid", [NoSideEffect]>; -def SinOp : ElemwiseUnaryOp<"sin", [NoSideEffect]>; -def TanhOp : ElemwiseUnaryOp<"tanh", [NoSideEffect]>; -def FastTanhOp : ElemwiseUnaryOp<"fast_tanh", [NoSideEffect]>; -def HswishOp : ElemwiseUnaryOp<"hswish", [NoSideEffect]>; -def RoundOp : ElemwiseUnaryOp<"round", [NoSideEffect]>; -def ErfOp : ElemwiseUnaryOp<"erf", [NoSideEffect]>; -def ErfInvOp : ElemwiseUnaryOp<"erfinv", [NoSideEffect]>; -def ErfCOp : ElemwiseUnaryOp<"erfc", [NoSideEffect]>; -def ErfCInvOp : ElemwiseUnaryOp<"erfcinv", [NoSideEffect]>; - -class ElemwiseBinaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> : - ElemwiseOp<mnemonic, traits> { - - let arguments = (ins ElemwiseFloatAny:$lhs, ElemwiseFloatAny:$rhs); - let results = (outs F32MemRef); - - let builders = [OpBuilder< - "Builder* builder, OperationState& result, ValueRange operands", [{ - result.addOperands(operands); - result.addTypes(getResultType(operands)); - }] - >, OpBuilder < - "OpBuilder& builder, OperationState& result, Value lhs, Value rhs", [{ - result.addOperands(lhs); - result.addOperands(rhs); - result.addTypes(getResultType({lhs, rhs})); - }] - >]; - - let extraClassDeclaration = [{ - static Type getResultType(ValueRange operands) { - return deduce_result_type(operands); - } - }] # ElemwiseBuilderImpl_create; -} - -def AbsGradOp : ElemwiseBinaryOp<"abs_grad", [NoSideEffect]>; -def AddOp : ElemwiseBinaryOp<"add", [Commutative, NoSideEffect]>; -def FloorDivOp : ElemwiseBinaryOp<"floor_div", [NoSideEffect]>; -def MaxOp : 
ElemwiseBinaryOp<"max", [Commutative, NoSideEffect]>; -def MinOp : ElemwiseBinaryOp<"min", [Commutative, NoSideEffect]>; -def ModOp : ElemwiseBinaryOp<"mod", [NoSideEffect]>; -def MulOp : ElemwiseBinaryOp<"mul", [Commutative, NoSideEffect]>; -def SubOp : ElemwiseBinaryOp<"sub", [NoSideEffect]>; -def SigmoidGradOp : ElemwiseBinaryOp<"sigmoid_grad", [NoSideEffect]>; -def SwishGt0Op : ElemwiseBinaryOp<"switch_gt0", [NoSideEffect]>; -def TanhGradOp : ElemwiseBinaryOp<"tanh_grad", [NoSideEffect]>; -def LtOp : ElemwiseBinaryOp<"lt", [NoSideEffect]>; -def LeqOp : ElemwiseBinaryOp<"leq", [NoSideEffect]>; -def EqOp : ElemwiseBinaryOp<"eq", [Commutative, NoSideEffect]>; -def FuseAddReluOp : ElemwiseBinaryOp<"fuse_add_relu", [NoSideEffect]>; -def TrueDivOp : ElemwiseBinaryOp<"true_div", [NoSideEffect]>; -def PowOp : ElemwiseBinaryOp<"pow", [NoSideEffect]>; -def LogSumExpOp : ElemwiseBinaryOp<"log_sum_exp", [Commutative, NoSideEffect]>; -def FuseAddTanhOp : ElemwiseBinaryOp<"fuse_add_tanh", [NoSideEffect]>; -def FastTanhGradOp : ElemwiseBinaryOp<"fast_tanh_grad", [NoSideEffect]>; -def FuseAddSigmoidOp : ElemwiseBinaryOp<"fuse_add_sigmoid", [NoSideEffect]>; -def HswishGradOp : ElemwiseBinaryOp<"hswish_grad", [NoSideEffect]>; -def FuseAddHswishOp : ElemwiseBinaryOp<"fuse_add_hswish", [NoSideEffect]>; -def Atan2Op : ElemwiseBinaryOp<"atan2", [NoSideEffect]>; - -class ElemwiseTernaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> : - ElemwiseOp<mnemonic, traits> { - - let arguments = (ins ElemwiseFloatAny:$x, ElemwiseFloatAny:$y, ElemwiseFloatAny:$z); - let results = (outs F32MemRef); - - let builders = [OpBuilder< - "Builder* builder, OperationState& result, ValueRange operands", [{ - result.addOperands(operands); - result.addTypes(getResultType(operands)); - }] - >, OpBuilder < - "OpBuilder& builder, OperationState& result, Value x, Value y, Value z", [{ - result.addOperands(x); - result.addOperands(y); - result.addOperands(z); - result.addTypes(getResultType({x, y, z})); - }] - >]; - - let extraClassDeclaration = [{ - static Type getResultType(ValueRange operands) { - return deduce_result_type(operands); - } - }] # ElemwiseBuilderImpl_create; -} - -def CondLeqMovOp: ElemwiseTernaryOp<"cond_leq_mov", [NoSideEffect]>; -def FuseMulAdd3Op: ElemwiseTernaryOp<"fuse_mul_add3", [NoSideEffect]>; - -def ReturnOp : GenericOp<"return", - [NoSideEffect, HasParent<"FuncOp">, Terminator]> { - let summary = "return operation"; - let description = [{ - The "return" operation represents a return operation within a function. - The operation takes no tensor operand and produces no results. - }]; - - // The return operation takes an optional input operand to return. This - // value must match the return type of the enclosing function. - let arguments = (ins); - - // The return operation only emits the input in the format if it is present. 
- let assemblyFormat = "attr-dict"; -} - -def ConstantScalarOp: GenericOp<"sconst", [NoSideEffect]> { - let summary = "scalar constant"; - let arguments = (ins AnyAttr:$value); - let results = (outs F32:$result); - - let builders = [OpBuilder< - "Builder* builder, OperationState& result, float value", [{ - result.addAttribute("value", builder->getF32FloatAttr(value)); - result.addTypes(builder->getF32Type()); - }] - >]; - - let extraClassDeclaration = [{ - Attribute getValue() { return getAttr("value"); } - FloatAttr getFloatAttr() { return getAttrOfType<FloatAttr>("value"); } - }]; - -} - -def AssignOp : GenericOp<"assign", []> { - let summary = "assign op"; - let description = [{ - assign rhs to lhs without results - }]; - - let arguments = (ins F32MemRef:$lhs, F32MemRef:$rhs); -} - -#endif diff --git a/src/jit/impl/mlir/ir/predicates.td b/src/jit/impl/mlir/ir/predicates.td deleted file mode 100644 index 3f5633072ffa90a1fe2d5df965e0bcd69df905b9..0000000000000000000000000000000000000000 --- a/src/jit/impl/mlir/ir/predicates.td +++ /dev/null @@ -1,24 +0,0 @@ -/** - * \file src/jit/impl/mlir/ir/predicates.td - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ - -#ifndef MGB_MLIR_PREDICATES -#define MGB_MLIR_PREDICATES - -#ifndef OP_BASE -include "mlir/IR/OpBase.td" -#endif - -def ElemwiseFloatAny : TypeConstraint< -CPred<"is_elemwise_float($_self)">, "elemwise-float">; - -#endif - diff --git a/src/jit/impl/mlir/ir/types.cpp b/src/jit/impl/mlir/ir/types.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c353e6d85f8f01d2fe76d3b8bbfe15971dd0a2a2 --- /dev/null +++ b/src/jit/impl/mlir/ir/types.cpp @@ -0,0 +1,115 @@ +/** + * \file src/jit/impl/mlir/ir/types.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include "./types.h" + +#include "megbrain/common.h" +#include "megbrain/exception.h" +#include "megbrain/jit/mlir/ir/utils.h" + +namespace mgb { +namespace jit { + +mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type, + mlir::MLIRContext* ctx) { + switch (type.enumv()) { + case megdnn::DTypeEnum::Float32: + return mlir::FloatType::getF32(ctx); + case megdnn::DTypeEnum::Uint8: + return mlir::IntegerType::get(8, ctx); + case megdnn::DTypeEnum::Int8: + return mlir::IntegerType::get(8, ctx); + case megdnn::DTypeEnum::Int16: + return mlir::IntegerType::get(16, ctx); + case megdnn::DTypeEnum::Int32: + return mlir::IntegerType::get(32, ctx); + case megdnn::DTypeEnum::IntB1: + return mlir::IntegerType::get(1, ctx); + case megdnn::DTypeEnum::IntB2: + return mlir::IntegerType::get(2, ctx); + case megdnn::DTypeEnum::IntB4: + return mlir::IntegerType::get(4, ctx); + case megdnn::DTypeEnum::Byte: + return mlir::IntegerType::get(8, ctx); + case megdnn::DTypeEnum::Float16: + return mlir::FloatType::getF16(ctx); + case megdnn::DTypeEnum::UintB4: + return mlir::IntegerType::get(4, ctx); + case megdnn::DTypeEnum::BFloat16: + return mlir::FloatType::getBF16(ctx); + case megdnn::DTypeEnum::Bool: + return mlir::IntegerType::get(1, ctx); + default: + mgb_throw(InternalError, "Unsupported MegDNN dtype: %s", + type.name()); + } +} + +megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type) { + mlir::Type element_type = type; + if (auto cast = type.dyn_cast_or_null<mlir::MemRefType>()) { + element_type = cast.getElementType(); + } + + megdnn::DTypeEnum enumv; + // note: several megdnn dtypes share one MLIR type (Bool and IntB1 both + // lower to i1, Int8/Uint8/Byte to i8), so the reverse mapping picks one + if (element_type.isF32()) { + enumv = megdnn::DTypeEnum::Float32; + } else if (element_type.isSignlessInteger(1)) { + enumv = megdnn::DTypeEnum::IntB1; + } else if (element_type.isSignlessInteger(2)) { + enumv = megdnn::DTypeEnum::IntB2; + } else if (element_type.isSignlessInteger(4)) { + enumv = megdnn::DTypeEnum::IntB4; + } else if (element_type.isSignlessInteger(8)) { + enumv = megdnn::DTypeEnum::Int8; + } else if (element_type.isSignlessInteger(16)) { + enumv = megdnn::DTypeEnum::Int16; + } else if (element_type.isSignlessInteger(32)) { + enumv = megdnn::DTypeEnum::Int32; + } else if (element_type.isF16()) { + enumv = megdnn::DTypeEnum::Float16; + } else if (element_type.isBF16()) { + enumv = megdnn::DTypeEnum::BFloat16; + } else { + mgb_throw(InternalError, "Unsupported MLIR Type: %s", + mlir_type_to_string(element_type).c_str()); + } + return megdnn::DType::from_enum(enumv); +} + +bool is_signed_int_dtype(megdnn::DType type) { + auto enumv = type.enumv(); + return enumv == megdnn::DTypeEnum::Int8 or + enumv == megdnn::DTypeEnum::Int16 or + enumv == megdnn::DTypeEnum::Int32 or + enumv == megdnn::DTypeEnum::IntB1 or + enumv == megdnn::DTypeEnum::IntB2 or + enumv == megdnn::DTypeEnum::IntB4; +} + +bool is_unsigned_int_dtype(megdnn::DType type) { + auto enumv = type.enumv(); + return enumv == megdnn::DTypeEnum::Uint8 or + enumv == megdnn::DTypeEnum::UintB4; +} + +} // namespace jit +} // namespace mgb + +#endif // MGB_JIT && MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen diff --git a/src/jit/impl/mlir/ir/types.h b/src/jit/impl/mlir/ir/types.h index ee8a1f6152847a68fb213b1e2d4535f0a920e6af..2dff735e7a5a632f01e0c16f501dd495c5506d55 100644 --- a/src/jit/impl/mlir/ir/types.h +++ b/src/jit/impl/mlir/ir/types.h @@ -14,22 +14,33 @@ #include "megbrain_build_config.h" #if MGB_JIT && MGB_JIT_MLIR + +#include "megdnn/dtype.h" 
+ #include namespace mgb { namespace jit { -inline bool is_elemwise_float(const mlir::Type& dt) { - if (auto cast = dt.dyn_cast_or_null()) { - if (cast.getElementType().isF32()) { - return true; - } - } - if (dt.isa()) { - return true; - } - return false; -} +#define FOR_EACH_DNN_DTYPE(cb) \ + cb(Float32, dt_float32); \ + cb(Uint8, dt_uint8); \ + cb(Int8, dt_int8); \ + cb(Int16, dt_int16); \ + cb(Int32, dt_int32); \ + cb(Byte, dt_byte); \ + MEGDNN_INC_FLOAT16(cb(Float16, dt_float16)); \ + MEGDNN_INC_FLOAT16(cb(BFloat16, dt_bfloat16)); \ + cb(Bool, dt_bool); + +mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type, + mlir::MLIRContext* ctx); + +megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type); + +bool is_signed_int_dtype(megdnn::DType type); + +bool is_unsigned_int_dtype(megdnn::DType type); } // namespace jit } // namespace mgb diff --git a/src/jit/impl/mlir/ir/utils.cpp b/src/jit/impl/mlir/ir/utils.cpp index 794d7b6ff628349fa7f436361920c7541254faf7..24e8d070cedbd517f9ab385ef80c916c53abf2b1 100644 --- a/src/jit/impl/mlir/ir/utils.cpp +++ b/src/jit/impl/mlir/ir/utils.cpp @@ -13,11 +13,14 @@ #include "megbrain_build_config.h" #if MGB_JIT && MGB_JIT_MLIR +#include "megbrain/jit/mlir/ir/utils.h" + +#include "./types.h" + #include "megbrain/common.h" #include "megbrain/exception.h" -#include "megbrain/jit/mlir/ir/utils.h" -#include "megdnn/oprs/general.h" #include "megdnn/basic_types.h" +#include "megdnn/oprs/general.h" #include #include @@ -44,7 +47,7 @@ mlir::Value jit::insert_alloc_and_dealloc(mlir::MemRefType type, return alloc; } -mlir::Type jit::deduce_result_type(mlir::ValueRange operands) { +mlir::Type jit::deduce_elemwise_res_type(mlir::ValueRange operands) { megdnn::TensorShapeArray srcs; megdnn::TensorShape dst; megdnn::DType dst_type; @@ -59,8 +62,8 @@ mlir::Type jit::deduce_result_type(mlir::ValueRange operands) { } megdnn::Elemwise::deduce_shape(srcs, dst); mlir::Builder builder(operands[0].getContext()); - return layout_to_mlir_type({dst, mlir_type_to_dtype(operands[0].getType())}, - builder); + return layout_to_mlir_type( + {dst, mlir_type_to_megdnn_dtype(operands[0].getType())}, builder); } megdnn::TensorLayout jit::mlir_type_to_layout(mlir::Type type) { @@ -72,41 +75,21 @@ megdnn::TensorLayout jit::mlir_type_to_layout(mlir::Type type) { for (size_t i = 0; i < ret.ndim; i++) { ret.shape[i] = real_type.getDimSize(i); } - ret.dtype = mlir_type_to_dtype(real_type.getElementType()); + ret.dtype = mlir_type_to_megdnn_dtype(real_type.getElementType()); } return ret; } -megdnn::DType jit::mlir_type_to_dtype(mlir::Type type) { - mlir::Type element_type = type; - if (auto cast = type.dyn_cast_or_null()) { - element_type = cast.getElementType(); - } - if (element_type.isF32()) { - return megdnn::dtype::Float32{}; - } else { - mgb_throw(InternalError, - "Unsupport mlir type for MemRefType, got: %s\n", - mlir_type_to_string(type).c_str()); - } - return {}; -} - mlir::MemRefType jit::layout_to_mlir_type(const megdnn::TensorLayout& layout, mlir::Builder& builder) { std::vector shape; for (size_t i = 0; i < layout.ndim; i++) { shape.push_back(layout[i]); } - switch (layout.dtype.enumv()) { - case megdnn::DTypeEnum::Float32: - return mlir::MemRefType::get(shape, builder.getF32Type()); - default: - mgb_throw(InternalError, "No supported dtype: %s", - layout.dtype.name()); - } + mlir::Type type = megdnn_dtype_to_mlir_type(layout.dtype, builder.getContext()); + return mlir::MemRefType::get(shape, type); } -#endif // MGB_JIT_MLIR +#endif // MGB_JIT && MGB_JIT_MLIR // vim: 
syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/jit/impl/mlir/mlir_gen.cpp b/src/jit/impl/mlir/mlir_gen.cpp index b4d7f86db7e5e49f70b23896df1ec1d790c0652e..399cabe2be017c7517b66e5f1eb2a8e915c28449 100644 --- a/src/jit/impl/mlir/mlir_gen.cpp +++ b/src/jit/impl/mlir/mlir_gen.cpp @@ -15,6 +15,7 @@ #include "./mlir_gen.h" #include "./ir/each_mode.h" +#include "./ir/types.h" #include "megbrain/jit/mlir/ir/dialect.h" #include "megbrain/jit/mlir/ir/utils.h" @@ -116,9 +117,9 @@ private: return nullptr; } - jit::ReturnOp return_op; + dialect::ReturnOp return_op; if (!return_op) { - m_builder.create<jit::ReturnOp>(m_builder.getUnknownLoc()); + m_builder.create<dialect::ReturnOp>(m_builder.getUnknownLoc()); } std::string op_content = mlir_type_to_string(func_op); func_op.setName( @@ -135,9 +136,7 @@ private: cg::DepOprIter{[&](cg::OperatorNodeBase* opr) { if (opr->same_type<JITPlaceholder>()) { return; - } - - if (opr->same_type<opr::ImmutableTensor>()) { + } else if (opr->same_type<opr::ImmutableTensor>()) { auto imm = SymbolVar{opr->output(0)}.as_immutable_scalar(); if (imm.valid()) { auto dtype = imm->dtype(); @@ -150,59 +149,53 @@ private: "dtype, but got %s", dtype.name()); } - auto&& out = m_builder.create<jit::ConstantScalarOp>( + auto&& out = m_builder.create<dialect::ConstantScalarOp>( m_builder.getUnknownLoc(), m_builder.getF32Type(), m_builder.getF32FloatAttr(scalar_value)); mgb_assert(mlir::succeeded( declare(opr->output(0)->name(), out))); } - } - - if (opr->same_type<opr::Elemwise>()) { - auto&& out = gen_op(opr->cast_final<opr::Elemwise>()); + } else if (opr->same_type<opr::Elemwise>()) { + auto&& out = gen_elemwise(opr->cast_final<opr::Elemwise>()); + mgb_assert( + mlir::succeeded(declare(opr->output(0)->name(), out))); + return; + } else if (opr->same_type<opr::TypeCvt>()) { + auto&& out = gen_typecvt(opr->cast_final<opr::TypeCvt>()); mgb_assert( mlir::succeeded(declare(opr->output(0)->name(), out))); } }} .add(internal_graph.output()); - m_builder.create<jit::AssignOp>(m_builder.getUnknownLoc(), - get(internal_graph.output()), - get(args.outputs[0].from)); + m_builder.create<dialect::AssignOp>(m_builder.getUnknownLoc(), + get(internal_graph.output()), + get(args.outputs[0].from)); return mlir::success(); } - mlir::Value gen_op(const opr::Elemwise& opr) { - switch (opr.param().mode) { -#define cb(mlir_op, mgb_mode) \ - case opr::Elemwise::Mode::mgb_mode: \ - return m_builder.create<mlir_op>(m_builder.getUnknownLoc(), \ - get(opr.input(0)), \ - get(opr.input(1))); \ - break; - MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) -#undef cb - -#define cb(mlir_op, mgb_mode) \ - case opr::Elemwise::Mode::mgb_mode: \ - return m_builder.create<mlir_op>(m_builder.getUnknownLoc(), \ - get(opr.input(0))); \ - break; - MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb) -#undef cb -#define cb(mlir_op, mgb_mode) \ - case opr::Elemwise::Mode::mgb_mode: \ - return m_builder.create<mlir_op>( \ - m_builder.getUnknownLoc(), get(opr.input(0)), \ - get(opr.input(1)), get(opr.input(2))); \ - break; - MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) -#undef cb - - default: - return nullptr; + mlir::Value gen_elemwise(const opr::Elemwise& opr) { + llvm::SmallVector<mlir::Value, 4> operands; + for (size_t i = 0; i < opr.input().size(); i++) { + operands.push_back(get(opr.input(i))); } - return nullptr; + mlir::Type res_type = deduce_elemwise_res_type(operands); + return m_builder.create<dialect::Elemwise>( + m_builder.getUnknownLoc(), res_type, mlir::ValueRange(operands), + opr.param().mode); + } + + mlir::Value gen_typecvt(const opr::TypeCvt& opr) { + auto shape = get(opr.input(0)) + .getType() + .dyn_cast_or_null<mlir::MemRefType>() + .getShape(); + auto res_type = mlir::MemRefType::get( + shape, + megdnn_dtype_to_mlir_type(opr.param(), m_builder.getContext())); + return m_builder.create<dialect::TypeCvt>( + m_builder.getUnknownLoc(), res_type, 
get(opr.input(0)), + opr.input(0)->dtype(), opr.param()); } mlir::Type get_type(const TensorLayout& layout) { diff --git a/src/jit/include/megbrain/jit/mlir/ir/CMakeLists.txt b/src/jit/include/megbrain/jit/mlir/ir/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..103aeef6491892e7640e700a97b72fd78f5ca420 --- /dev/null +++ b/src/jit/include/megbrain/jit/mlir/ir/CMakeLists.txt @@ -0,0 +1,6 @@ +# mgb_dialect +set(LLVM_TARGET_DEFINITIONS mgb_dialect.td) +tablegen(MLIR mgb_dialect.h.inc ${MGE_IR_INCLUDE_DIRS} "--gen-op-decls") +tablegen(MLIR mgb_dialect.cpp.inc ${MGE_IR_INCLUDE_DIRS} "--gen-op-defs") +add_custom_target(mgb_dialect DEPENDS mgb_dialect.h.inc mgb_dialect.cpp.inc) +add_dependencies(mgb_dialect param_defs_tblgen) diff --git a/src/jit/include/megbrain/jit/mlir/ir/dialect.h b/src/jit/include/megbrain/jit/mlir/ir/dialect.h index d1c93e9b0bec61fbb92e826f6c5478d7beb5f2ed..4ce39f895b94722f88ae5aeded2f63b246168d95 100644 --- a/src/jit/include/megbrain/jit/mlir/ir/dialect.h +++ b/src/jit/include/megbrain/jit/mlir/ir/dialect.h @@ -1,5 +1,5 @@ /** - * \file src/jit/impl/mlir/ir/dialect.h + * \file src/jit/include/megbrain/jit/mlir/ir/dialect.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -15,8 +15,7 @@ #include "megbrain_build_config.h" #if MGB_JIT && MGB_JIT_MLIR -#include "megbrain/jit/mlir/ir/interfaces.h" -#include "megbrain/jit/mlir/ir/utils.h" +#include "megdnn/opr_param_defs.h" #include #include @@ -39,7 +38,7 @@ public: #define GET_OP_CLASSES using namespace mlir; -#include "megbrain/jit/mlir/ir/ops.h.inc" +#include "megbrain/jit/mlir/ir/mgb_dialect.h.inc" #endif // MGB_JIT && MGB_JIT_MLIR diff --git a/src/jit/include/megbrain/jit/mlir/ir/interfaces.h b/src/jit/include/megbrain/jit/mlir/ir/interfaces.h deleted file mode 100644 index 4803bc0d1f5e79a4b6a308de0d7fd1472a9c6d54..0000000000000000000000000000000000000000 --- a/src/jit/include/megbrain/jit/mlir/ir/interfaces.h +++ /dev/null @@ -1,28 +0,0 @@ -/** - * \file src/jit/include/mlir/ir/interfaces.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ - -#pragma once - -#include "megbrain_build_config.h" -#if MGB_JIT_MLIR - -#include -#include - -namespace mlir { -/// Include the auto-generated declarations. -#include "megbrain/jit/mlir/ir/interfaces.h.inc" -} - -#endif // MGB_JIT_MLIR - -// vim: syntax=cpp.doxygen diff --git a/src/jit/include/megbrain/jit/mlir/ir/passes.h b/src/jit/include/megbrain/jit/mlir/ir/passes.h index 630554bda018db8a8d635c6c074901d9ce2c3b6b..4acd8fb13f8bac6af17b8e1340e76d1635cdb180 100644 --- a/src/jit/include/megbrain/jit/mlir/ir/passes.h +++ b/src/jit/include/megbrain/jit/mlir/ir/passes.h @@ -1,5 +1,5 @@ /** - * \file src/jit/impl/mlir/ir/passes.h + * \file src/jit/include/megbrain/jit/mlir/ir/passes.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
@@ -11,8 +11,8 @@ */ #pragma once -#include "megbrain_build_config.h" +#include "megbrain_build_config.h" #if MGB_JIT && MGB_JIT_MLIR #include diff --git a/src/jit/include/megbrain/jit/mlir/ir/utils.h b/src/jit/include/megbrain/jit/mlir/ir/utils.h index 710dd57e63bf0b91bfe4f5f224d47722e186b5ca..a65f102962de46997ecbda05aa8596a1c5a6d723 100644 --- a/src/jit/include/megbrain/jit/mlir/ir/utils.h +++ b/src/jit/include/megbrain/jit/mlir/ir/utils.h @@ -1,5 +1,5 @@ /** - * \file src/jit/include/megbrain/mlir/ir/utils.h + * \file src/jit/include/megbrain/jit/mlir/ir/utils.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -35,15 +35,19 @@ std::string mlir_type_to_string(T&& t) { mlir::Value insert_alloc_and_dealloc(mlir::MemRefType type, mlir::Location loc, mlir::PatternRewriter& rewriter); -mlir::Type deduce_result_type(mlir::ValueRange operands); +mlir::Type deduce_elemwise_res_type(mlir::ValueRange operands); /** - * \brief convert mlir type to TensorShape + * \brief convert MLIR Type to TensorLayout */ megdnn::TensorLayout mlir_type_to_layout(mlir::Type type); -megdnn::DType mlir_type_to_dtype(mlir::Type type); + +/** + * \brief convert TensorLayout to MLIR Type + */ mlir::MemRefType layout_to_mlir_type(const megdnn::TensorLayout& layout, mlir::Builder& builder); + } // namespace jit } // namespace mgb diff --git a/src/jit/test/codegen.cpp b/src/jit/test/codegen.cpp index 0d3aaeb1137d489e84459997157031cf48167f4e..cffe0dd622c1a3e6754f28131752799b42641d5e 100644 --- a/src/jit/test/codegen.cpp +++ b/src/jit/test/codegen.cpp @@ -267,6 +267,8 @@ void run_mlir_mode(CompNode cn) { } // anonymous namespace +/* ===================== TestJITHalideCodeGenCuda ===================== */ + #if MGB_JIT_HALIDE template class TestJITHalideCodeGenCuda : public ::testing::Test {}; @@ -277,6 +279,8 @@ TYPED_TEST(TestJITHalideCodeGenCuda, run) { } #endif +/* ===================== TestJITNvrtcCodeGen ===================== */ + template class TestJITNvrtcCodeGen : public ::testing::Test {}; TYPED_TEST_CASE(TestJITNvrtcCodeGen, test_types); @@ -285,6 +289,8 @@ TYPED_TEST(TestJITNvrtcCodeGen, run) { run(Backend::NVRTC, CompNode::load("gpu0")); } +/* ===================== TestJITMlirCodeGen ===================== */ + #if MGB_JIT_MLIR TEST(TestJITMlirCodeGen, Basic) { auto cn = CompNode::load("cpu0"); @@ -299,7 +305,8 @@ TEST(TestJITMlirCodeGen, BasicGPU) { run_mlir_broadcast(cn); } -///////////////////////// unary /////////////////////////////// +/* ===================== TestJITMlirUnaryElemwise ===================== */ + // clang-format off #define FOREACH_UNARY_MODE(cb) \ cb(RELU) \ @@ -365,7 +372,8 @@ TYPED_TEST(TestJITMlirUnaryElemwise, runGpu) { run_mlir_mode(cn); } -///////////////////////// binary /////////////////////////////// +/* ===================== TestJITMlirBinaryElemwise ===================== */ + // clang-format off #define FOREACH_BINARY_MODE(cb) \ cb(ADD) \ @@ -422,7 +430,8 @@ TYPED_TEST(TestJITMlirBinaryElemwise, runGpu) { run_mlir_mode(cn); } -///////////////////////// ternary /////////////////////////////// +/* ===================== TestJITMlirTernaryElemwise ===================== */ + // clang-format off #define FOREACH_TERNARY_MODE(cb) \ cb(COND_LEQ_MOV) \ @@ -456,6 +465,81 @@ TYPED_TEST(TestJITMlirTernaryElemwise, runGpu) { #undef SKIP_MODE + +/* ===================== TestJITMlirTypeCvt ===================== */ + +template <typename itype, typename otype> +void run_typecvt(CompNode cn) { + set_backend(Backend::MLIR); + auto graph = 
ComputingGraph::make(); + HostTensorGenerator gen(-10, 10); + + auto host_x = gen({23, 42}, cn); + auto x = opr::Host2DeviceCopy::make(*graph, host_x); + auto y = opr::TypeCvt::make(x, otype()); + + auto ig_gen = std::make_unique<InternalGraphGenerator>(y.node()->owner_opr()); + + for (auto i : get_rev_topo_order(y)) { + if (!i->template same_type<opr::Host2DeviceCopy>()) { + ig_gen->add_opr(i); + } + } + + auto igraph = ig_gen->generate(); + auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps()); + + HostTensorND host_y, host_y_jit; + auto func = graph->compile({make_callback_copy(y, host_y), + make_callback_copy(y_jit, host_y_jit)}); + func->execute(); + + MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit); +}; + +#define add_typecvt_gtest(itype, otype) \ + TEST(TestJITMlirTypeCvt, itype##_to_##otype) { \ + run_typecvt<dtype::itype, dtype::otype>(CompNode::load("cpu0")); \ + } \ + TEST(TestJITMlirTypeCvt, itype##_to_##otype##_GPU) { \ + REQUIRE_GPU(1); \ + run_typecvt<dtype::itype, dtype::otype>(CompNode::load("gpu0")); \ + } + +#if !MEGDNN_DISABLE_FLOAT16 + +// TODO: the support for f16 and bf16 is currently not complete in mlir + +// FPExtOp +// add_typecvt_gtest(Float16, Float32); +// add_typecvt_gtest(BFloat16, Float32); +// add_typecvt_gtest(Float16, BFloat16); + +// FPTruncOp +// add_typecvt_gtest(Float32, Float16); +// add_typecvt_gtest(Float32, BFloat16); +// add_typecvt_gtest(Float16, BFloat16); + +#endif + +// FPToSIOp +add_typecvt_gtest(Float32, Int8); +add_typecvt_gtest(Float32, Int16); +add_typecvt_gtest(Float32, Int32); + +// FPToUIOp +add_typecvt_gtest(Float32, Uint8); + +// SIToFPOp +add_typecvt_gtest(Int8, Float32); +add_typecvt_gtest(Int16, Float32); +add_typecvt_gtest(Int32, Float32); + +// UIToFPOp +add_typecvt_gtest(Uint8, Float32); + +#undef add_typecvt_gtest + #endif // MGB_JIT_MLIR #endif // MGB_JIT diff --git a/src/jit/test/mlir/ir/add.mlir b/src/jit/test/mlir/ir/add.mlir index 6966b083c181fa06a119bf651232b8ec60087a83..d0edb68dd4a490cfce7bcbdf31e8cf77e1e64c23 100644 --- a/src/jit/test/mlir/ir/add.mlir +++ b/src/jit/test/mlir/ir/add.mlir @@ -2,7 +2,7 @@ // RUN: mgb-opt --mgb-convert-to-affine --mgb-codegen-convert-affine-to-llvm --split-input-file -canonicalize -cse %s func @add_dim1(%lhs: memref<2xf32>, %rhs: memref<2xf32>, %res: memref<2xf32>) -> () { - %0 = "mgb.add"(%lhs, %rhs) {name = "add.f"} : + %0 = "mgb.Elemwise"(%lhs, %rhs) {name = "add.f", mode = 16 : i32} : (memref<2xf32>, memref<2xf32>) -> memref<2xf32> "mgb.assign"(%0, %res) : (memref<2xf32>, memref<2xf32>) -> () mgb.return @@ -24,7 +24,7 @@ func @add_dim1(%lhs: memref<2xf32>, %rhs: memref<2xf32>, %res: memref<2xf32>) -> // CHECK: } func @add_dim4(%lhs: memref<4x3x64x64xf32>, %rhs: memref<4x3x64x64xf32>, %res: memref<4x3x64x64xf32>) -> () { - %0 = "mgb.add"(%lhs, %rhs) {name = "add.f"} : + %0 = "mgb.Elemwise"(%lhs, %rhs) {name = "add.f", mode = 16 : i32} : (memref<4x3x64x64xf32>, memref<4x3x64x64xf32>) -> memref<4x3x64x64xf32> "mgb.assign"(%0, %res) : (memref<4x3x64x64xf32>, memref<4x3x64x64xf32>) -> () mgb.return @@ -55,4 +55,4 @@ func @add_dim4(%lhs: memref<4x3x64x64xf32>, %rhs: memref<4x3x64x64xf32>, %res: m // CHECK: } // CHECK: dealloc %0 : memref<4x3x64x64xf32> // CHECK: return -// CHECK: } \ No newline at end of file +// CHECK: }
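
The add.mlir update above shows the heart of this refactor: the per-mode ops of the deleted ops.td ("mgb.add", "mgb.relu", ...) collapse into one generic "mgb.Elemwise" op from the tablegen-generated mgb_dialect, with the megdnn mode carried as an i32 attribute (ADD is mode 16 in the hunks above) and dispatched by the single lower_elemwise_to_std / lower_typecvt_to_std helpers during lowering instead of one pattern class per op. For illustration only, a unary mode in the same test style would look like the sketch below; the RELU mode value of 0 is an assumption based on the ordering of megdnn's Elemwise mode enum and is not confirmed by this diff:

func @relu_dim1(%inp: memref<2xf32>, %res: memref<2xf32>) -> () {
    // single generic op; mode = 0 assumed to be RELU in megdnn's enum
    %0 = "mgb.Elemwise"(%inp) {name = "relu.f", mode = 0 : i32} :
        (memref<2xf32>) -> memref<2xf32>
    "mgb.assign"(%0, %res) : (memref<2xf32>, memref<2xf32>) -> ()
    mgb.return
}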