提交 404ef808 编写于 作者: M Megvii Engine Team

feat(mgb/jit): adapt jit mlir backend to new mgb dialect and add typecvt

GitOrigin-RevId: bd1b80c84f5629a4dde5302d775cd8a333ea3cf2
上级 cc85047b
...@@ -753,12 +753,14 @@ install(TARGETS mgb_opr_param_defs EXPORT ${MGE_EXPORT_TARGETS}) ...@@ -753,12 +753,14 @@ install(TARGETS mgb_opr_param_defs EXPORT ${MGE_EXPORT_TARGETS})
if(MGE_WITH_JIT_MLIR) if(MGE_WITH_JIT_MLIR)
# generate param_defs.td # generate param_defs.td
set(MGE_GENFILE_DIR ${PROJECT_BINARY_DIR}/src/genfiles) set(MGE_GENFILE_DIR ${PROJECT_BINARY_DIR}/src/genfiles)
set(MGE_GEN_IR_DIR ${PROJECT_BINARY_DIR}/src/core/include/megbrain/ir)
set(OPR_PARAM_DEFS_SRCS ${MGE_GENFILE_DIR}/opr_param_defs.py) set(OPR_PARAM_DEFS_SRCS ${MGE_GENFILE_DIR}/opr_param_defs.py)
set(OPR_PARAM_DEFS_SCRIPT ${PROJECT_SOURCE_DIR}/dnn/scripts/gen_tablegen.py) set(OPR_PARAM_DEFS_SCRIPT ${PROJECT_SOURCE_DIR}/dnn/scripts/gen_tablegen.py)
set(OPR_PARAM_DEFS_OUT ${MGE_GENFILE_DIR}/param_defs.td) set(OPR_PARAM_DEFS_OUT ${MGE_GEN_IR_DIR}/param_defs.td)
file(COPY ${PROJECT_SOURCE_DIR}/dnn/scripts/opr_param_defs.py DESTINATION ${MGE_GENFILE_DIR}) file(COPY ${PROJECT_SOURCE_DIR}/dnn/scripts/opr_param_defs.py DESTINATION ${MGE_GENFILE_DIR})
file(READ ${PROJECT_SOURCE_DIR}/tools/param_defs/mgb_opr_param_defs.py CONTENTS) file(READ ${PROJECT_SOURCE_DIR}/tools/param_defs/mgb_opr_param_defs.py CONTENTS)
file(APPEND ${OPR_PARAM_DEFS_SRCS} ${CONTENTS}) file(APPEND ${OPR_PARAM_DEFS_SRCS} ${CONTENTS})
file(MAKE_DIRECTORY ${MGE_GEN_IR_DIR})
add_custom_target(param_defs_tblgen add_custom_target(param_defs_tblgen
COMMAND ${PYTHON_EXECUTABLE} ${OPR_PARAM_DEFS_SCRIPT} ${OPR_PARAM_DEFS_SRCS} ${OPR_PARAM_DEFS_OUT} COMMAND ${PYTHON_EXECUTABLE} ${OPR_PARAM_DEFS_SCRIPT} ${OPR_PARAM_DEFS_SRCS} ${OPR_PARAM_DEFS_OUT}
DEPENDS ${OPR_PARAM_DEFS_SRCS} ${OPR_PARAM_DEFS_SCRIPT} DEPENDS ${OPR_PARAM_DEFS_SRCS} ${OPR_PARAM_DEFS_SCRIPT}
...@@ -766,7 +768,7 @@ if(MGE_WITH_JIT_MLIR) ...@@ -766,7 +768,7 @@ if(MGE_WITH_JIT_MLIR)
) )
# mlir tblgen sources # mlir tblgen sources
set(MGE_IR_DIR ${PROJECT_SOURCE_DIR}/src/core/include/megbrain/ir) set(MGE_IR_DIR ${PROJECT_SOURCE_DIR}/src/core/include/megbrain/ir)
set(MGE_IR_INCLUDE_DIRS ${MLIR_LLVM_INCLUDE_DIR} ${MGE_GENFILE_DIR} ${MGE_IR_DIR}) set(MGE_IR_INCLUDE_DIRS ${MLIR_LLVM_INCLUDE_DIR} ${MGE_IR_DIR} ${MGE_GEN_IR_DIR})
list(TRANSFORM MGE_IR_INCLUDE_DIRS PREPEND "-I") list(TRANSFORM MGE_IR_INCLUDE_DIRS PREPEND "-I")
file(GLOB_RECURSE MGE_IR_TDS ${MGE_IR_DIR}/*.td) file(GLOB_RECURSE MGE_IR_TDS ${MGE_IR_DIR}/*.td)
endif() endif()
......
if(MGE_WITH_JIT_MLIR) if(MGE_WITH_JIT_MLIR)
add_subdirectory(jit/impl/mlir/ir) add_subdirectory(jit/include/megbrain/jit/mlir/ir)
endif() endif()
file(GLOB_RECURSE SOURCES core/impl/*.cpp gopt/impl/*.cpp opr/impl/*.cpp opr/impl/nvof/*.cpp plugin/impl/*.cpp serialization/impl/*.cpp core/impl/*.inl gopt/impl/*.inl opr/impl/*.inl plugin/impl/*.inl serialization/impl/*.inl) file(GLOB_RECURSE SOURCES core/impl/*.cpp gopt/impl/*.cpp opr/impl/*.cpp opr/impl/nvof/*.cpp plugin/impl/*.cpp serialization/impl/*.cpp core/impl/*.inl gopt/impl/*.inl opr/impl/*.inl plugin/impl/*.inl serialization/impl/*.inl)
...@@ -100,9 +100,10 @@ if(MGE_WITH_JIT AND MGE_WITH_HALIDE) ...@@ -100,9 +100,10 @@ if(MGE_WITH_JIT AND MGE_WITH_HALIDE)
target_link_libraries(megbrain PRIVATE ${HALIDE_LLVM_LIBS}) target_link_libraries(megbrain PRIVATE ${HALIDE_LLVM_LIBS})
endif() endif()
if(MGE_WITH_JIT_MLIR) if(MGE_WITH_JIT_MLIR)
target_link_libraries(megbrain PRIVATE mlir_op_def) target_include_directories(megbrain PRIVATE ${MLIR_LLVM_INCLUDE_DIR})
target_link_libraries(megbrain PRIVATE mlir_shape_inference)
target_link_libraries(megbrain PRIVATE ${MLIR_LLVM_LIBS}) target_link_libraries(megbrain PRIVATE ${MLIR_LLVM_LIBS})
add_dependencies(megbrain mgb_dialect)
target_include_directories(megbrain PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/include)
endif() endif()
if (MGB_WITH_FLATBUFFERS) if (MGB_WITH_FLATBUFFERS)
set (GEN_FLATBUFFERS_SCHEMA_PY ${PROJECT_SOURCE_DIR}/dnn/scripts/gen_flatbuffers_schema.py) set (GEN_FLATBUFFERS_SCHEMA_PY ${PROJECT_SOURCE_DIR}/dnn/scripts/gen_flatbuffers_schema.py)
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "./executable_cpu.h" #include "./executable_cpu.h"
#include "./executable_cuda.h" #include "./executable_cuda.h"
#include "./mlir_gen.h" #include "./mlir_gen.h"
#include "megbrain/common.h" #include "megbrain/common.h"
#include "megbrain/comp_node_env.h" #include "megbrain/comp_node_env.h"
#include "megbrain/jit/mlir/ir/dialect.h" #include "megbrain/jit/mlir/ir/dialect.h"
......
...@@ -14,37 +14,44 @@ ...@@ -14,37 +14,44 @@
#if MGB_JIT && MGB_JIT_MLIR #if MGB_JIT && MGB_JIT_MLIR
#include "./executable_cpu.h" #include "./executable_cpu.h"
#include "./ir/types.h"
#include "megbrain/jit/mlir/ir/utils.h" #include "megbrain/jit/mlir/ir/utils.h"
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/ExecutionEngine/CRunnerUtils.h> #include <mlir/ExecutionEngine/CRunnerUtils.h>
#include <mlir/ExecutionEngine/OptUtils.h>
using namespace mgb; using namespace mgb;
using namespace jit; using namespace jit;
namespace { namespace {
template <typename T, int N>
StridedMemRefType<T, N>* get_strided_memref_type(
const megdnn::TensorND& tensor) {
using DescType = StridedMemRefType<T, N>;
DescType* desc = static_cast<DescType*>(malloc(sizeof(DescType)));
desc->basePtr = tensor.ptr<T>();
desc->data = tensor.ptr<T>();
desc->offset = 0;
for (size_t i = 0; i < tensor.layout.ndim; i++) {
desc->sizes[i] = tensor.layout.shape[i];
desc->strides[i] = tensor.layout.stride[i];
}
return desc;
}
template <int N> template <int N>
void* tensor2memref_dim(const megdnn::TensorND& tensor) { void* tensor2memref_dim(const megdnn::TensorND& tensor) {
switch (tensor.layout.dtype.enumv()) { switch (tensor.layout.dtype.enumv()) {
case megdnn::DTypeEnum::Float32: { #define cb(_dtype, _type) \
StridedMemRefType<float, N>* desc = case megdnn::DTypeEnum::_dtype: \
static_cast<StridedMemRefType<float, N>*>( return get_strided_memref_type<_type, N>(tensor);
malloc(sizeof(StridedMemRefType<float, N>))); FOR_EACH_DNN_DTYPE(cb)
desc->basePtr = tensor.ptr<float>(); #undef cb
desc->data = tensor.ptr<float>();
desc->offset = 0;
for (size_t i = 0; i < tensor.layout.ndim; i++) {
desc->sizes[i] = tensor.layout.shape[i];
desc->strides[i] = tensor.layout.stride[i];
}
return desc;
break;
}
default: default:
mgb_throw(InternalError, "Unsupport dtype, got %s", mgb_throw(InternalError, "Unsupported dtype: %s",
tensor.layout.dtype.name()); tensor.layout.dtype.name());
break;
} }
return nullptr; return nullptr;
} }
......
...@@ -10,18 +10,18 @@ ...@@ -10,18 +10,18 @@
* implied. * implied.
*/ */
#include <vector>
#include "megbrain_build_config.h" #include "megbrain_build_config.h"
#include "megdnn/dtype.h"
#if MGB_JIT && MGB_JIT_MLIR #if MGB_JIT && MGB_JIT_MLIR
#if MGB_CUDA #if MGB_CUDA
#include "./executable_cuda.h" #include "./executable_cuda.h"
#include "./ir/types.h"
#include "megbrain/comp_node_env.h" #include "megbrain/comp_node_env.h"
#include "megbrain/jit/mlir/ir/utils.h" #include "megbrain/jit/mlir/ir/utils.h"
#include "megbrain/utils/persistent_cache.h" #include "megbrain/utils/persistent_cache.h"
#include "megbrain/utils/timer.h" #include "megbrain/utils/timer.h"
#include "megdnn/dtype.h"
#include <mlir/Dialect/GPU/GPUDialect.h> #include <mlir/Dialect/GPU/GPUDialect.h>
#include <mlir/ExecutionEngine/CRunnerUtils.h> #include <mlir/ExecutionEngine/CRunnerUtils.h>
...@@ -83,6 +83,24 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func, ...@@ -83,6 +83,24 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func,
MGB_CUDA_CU_CHECK(cuLaunchKernel(func, num_block, 1, 1, block_size, 1, 1, 0, MGB_CUDA_CU_CHECK(cuLaunchKernel(func, num_block, 1, 1, block_size, 1, 1, 0,
env.cuda_env().stream, params.data(), 0)); env.cuda_env().stream, params.data(), 0));
} }
template <int out_dim>
void setup_and_launch_dim(const megdnn::DType dtype,
const JITExecutor* fusion_opr, CUfunction func,
int block_size) {
switch (dtype.enumv()) {
#define cb(_dtype, _type) \
case megdnn::DTypeEnum::_dtype: \
setup_and_launch<out_dim, _type>(fusion_opr, func, block_size); \
return;
FOR_EACH_DNN_DTYPE(cb)
#undef cb
default:
mgb_throw(InternalError, "Unsupported dtype: %s", dtype.name());
}
return;
}
} // namespace } // namespace
const std::string MLIRCUDAExecutable::sm_blob_annotation = "nvvm.cubin"; const std::string MLIRCUDAExecutable::sm_blob_annotation = "nvvm.cubin";
...@@ -136,30 +154,19 @@ void MLIRCUDAExecutable::FuncCache::exec(const JITExecutor* fusion_opr, ...@@ -136,30 +154,19 @@ void MLIRCUDAExecutable::FuncCache::exec(const JITExecutor* fusion_opr,
fusion_opr->args().outputs.size()); fusion_opr->args().outputs.size());
int out_dim = fusion_opr->args().outputs[0].from->layout().ndim; int out_dim = fusion_opr->args().outputs[0].from->layout().ndim;
DType dtype = fusion_opr->args().outputs[0].from->layout().dtype; DType dtype = fusion_opr->args().outputs[0].from->layout().dtype;
#define cb_outdim(_ndim, _dtype) \
if (_ndim == out_dim) { \
setup_and_launch<_ndim, _dtype>(fusion_opr, func->func, \
func->block_size); \
return; \
}
#define cb(_dtype) \
cb_outdim(1, float); \
cb_outdim(2, float); \
cb_outdim(3, float); \
cb_outdim(4, float); \
mgb_throw(InternalError, "unsupported out_dim=%zu", \
static_cast<size_t>(out_dim)); \
return;
switch (dtype.enumv()) { switch (out_dim) {
case DTypeEnum::Float32: #define cb(_ndim) \
cb(float); case _ndim: \
default: setup_and_launch_dim<_ndim>(dtype, fusion_opr, func->func, \
mgb_throw(InternalError, "unsupport dtype: %s", dtype.name()); func->block_size); \
} break;
cb(1);
cb(2);
cb(3);
cb(4);
#undef cb #undef cb
#undef cb_outdim }
} }
#endif // MGB_CUDA #endif // MGB_CUDA
......
set(MGB_MLIR_TABLEGEN_INC_BASE ${CMAKE_CURRENT_BINARY_DIR}/include/)
file(MAKE_DIRECTORY ${MGB_MLIR_TABLEGEN_INC_BASE}/megbrain/jit/mlir/ir/)
list(APPEND MGB_MLIR_TABLEGEN_INC ${MGB_MLIR_TABLEGEN_INC_BASE})
external_tablegen_library(
NAME
mlir_shape_inference
TBLGEN
MLIR
SRCS
"interfaces.td"
INCLUDES
${MGB_MLIR_TABLEGEN_INC} ${MLIR_LLVM_INCLUDE_DIR}
OUTS
-gen-op-interface-decls include/megbrain/jit/mlir/ir/interfaces.h.inc
-gen-op-interface-defs include/megbrain/jit/mlir/ir/interfaces.cpp.inc
)
external_tablegen_library(
NAME
mlir_op_def
TBLGEN
MLIR
SRCS
"ops.td"
INCLUDES
${MGB_MLIR_TABLEGEN_INC} ${MLIR_LLVM_INCLUDE_DIR}
OUTS
-gen-op-decls include/megbrain/jit/mlir/ir/ops.h.inc
-gen-op-defs include/megbrain/jit/mlir/ir/ops.cpp.inc
)
# mgb_dialect
set(MGB_DIALECT_TD ${PROJECT_SOURCE_DIR}/src/jit/include/megbrain/jit/mlir/ir/mgb_dialect.td)
set(LLVM_TARGET_DEFINITIONS ${MGB_DIALECT_TD})
tablegen(MLIR mgb_dialect.h.inc ${MGE_IR_INCLUDE_DIRS} "--gen-op-decls")
tablegen(MLIR mgb_dialect.cpp.inc ${MGE_IR_INCLUDE_DIRS} "--gen-op-defs")
add_custom_target(mgb_dialect DEPENDS mgb_dialect.h.inc mgb_dialect.cpp.inc ${MGB_DIALECT_TD} ${MGE_IR_TDS})
add_dependencies(mgb_dialect param_defs_tblgen)
...@@ -14,91 +14,99 @@ ...@@ -14,91 +14,99 @@
#if MGB_JIT && MGB_JIT_MLIR #if MGB_JIT && MGB_JIT_MLIR
#include "./common.h" #include "./common.h"
#include "megbrain/jit/mlir/ir/utils.h" #include "megbrain/jit/mlir/ir/utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include <mlir/Dialect/Affine/IR/AffineOps.h> #include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
using namespace mgb; using namespace mgb;
using namespace jit; using namespace jit;
/* ===================== trivial unary functions ===================== */
#define cb(name, op) \
mlir::Value ValueBuilderHelper::name(mlir::Value lhs) { \
return m_builder.create<mlir::op>(m_location, lhs); \
}
cb(abs, AbsFOp);
cb(ceil, CeilFOp);
cb(cos, CosOp);
cb(exp2, Exp2Op);
cb(exp, ExpOp);
cb(floor, FloorFOp);
cb(log10, Log10Op);
cb(log2, Log2Op);
cb(log, LogOp);
cb(neg, NegFOp);
cb(rsqrt, RsqrtOp);
cb(sin, SinOp);
cb(sqrt, SqrtOp);
cb(tanh, TanhOp);
#undef cb
/* ===================== trivial binary functions ===================== */
#define cb(name, op) \ #define cb(name, op) \
mlir::Value ValueBuilderHelper::name(mlir::Value lhs, mlir::Value rhs) { \ mlir::Value ValueBuilderHelper::name(mlir::Value lhs, mlir::Value rhs) { \
return m_builder.create<mlir::op>(m_location, lhs, rhs); \ return m_builder.create<mlir::op>(m_location, lhs, rhs); \
} }
cb(add, AddFOp); cb(add, AddFOp);
cb(sub, SubFOp);
cb(mul, MulFOp);
cb(div, DivFOp);
cb(divI, SignedDivIOp);
cb(mod, RemFOp);
cb(bit_and, AndOp); cb(bit_and, AndOp);
cb(bit_or, OrOp); cb(bit_or, OrOp);
cb(div, DivFOp);
cb(divI, SignedDivIOp);
cb(modI, SignedRemIOp); cb(modI, SignedRemIOp);
cb(mod, RemFOp);
cb(mul, MulFOp);
cb(sub, SubFOp);
#undef cb #undef cb
/* ===================== compare functions ===================== */
#define cb(name, mode) \ #define cb(name, mode) \
mlir::Value ValueBuilderHelper::name(mlir::Value lhs, mlir::Value rhs) { \ mlir::Value ValueBuilderHelper::name(mlir::Value lhs, mlir::Value rhs) { \
return m_builder.create<mlir::CmpFOp>( \ return m_builder.create<mlir::CmpFOp>( \
m_location, mlir::CmpFPredicate::mode, lhs, rhs); \ m_location, mlir::CmpFPredicate::mode, lhs, rhs); \
} }
cb(gt, OGT);
cb(eq, OEQ);
cb(ge, OGE); cb(ge, OGE);
cb(lt, OLT); cb(gt, OGT);
cb(le, OLE); cb(le, OLE);
cb(eq, OEQ); cb(lt, OLT);
#undef cb #undef cb
mlir::Value ValueBuilderHelper::min(mlir::Value lhs, mlir::Value rhs) { mlir::Value ValueBuilderHelper::max(mlir::Value lhs, mlir::Value rhs) {
mlir::Value cmp = m_builder.create<mlir::CmpFOp>( mlir::Value cmp = m_builder.create<mlir::CmpFOp>(
m_location, mlir::CmpFPredicate::OLT, lhs, rhs); m_location, mlir::CmpFPredicate::OGT, lhs, rhs);
return m_builder.create<mlir::SelectOp>(m_location, cmp, lhs, rhs); return m_builder.create<mlir::SelectOp>(m_location, cmp, lhs, rhs);
} }
mlir::Value ValueBuilderHelper::max(mlir::Value lhs, mlir::Value rhs) { mlir::Value ValueBuilderHelper::min(mlir::Value lhs, mlir::Value rhs) {
mlir::Value cmp = m_builder.create<mlir::CmpFOp>( mlir::Value cmp = m_builder.create<mlir::CmpFOp>(
m_location, mlir::CmpFPredicate::OGT, lhs, rhs); m_location, mlir::CmpFPredicate::OLT, lhs, rhs);
return m_builder.create<mlir::SelectOp>(m_location, cmp, lhs, rhs); return m_builder.create<mlir::SelectOp>(m_location, cmp, lhs, rhs);
} }
mlir::Value ValueBuilderHelper::const_val(float val) { /* ===================== constant functions ===================== */
mlir::Value ValueBuilderHelper::const_f32(float val) {
return m_builder.create<mlir::ConstantOp>(m_location, return m_builder.create<mlir::ConstantOp>(m_location,
m_builder.getF32FloatAttr(val)); m_builder.getF32FloatAttr(val));
} }
mlir::Value ValueBuilderHelper::constI(int32_t val) { mlir::Value ValueBuilderHelper::const_i32(int32_t val) {
return m_builder.create<mlir::ConstantOp>(m_location, return m_builder.create<mlir::ConstantOp>(m_location,
m_builder.getIndexAttr(val)); m_builder.getIndexAttr(val));
} }
#define cb(name, op) \ /* ===================== select function ===================== */
mlir::Value ValueBuilderHelper::name(mlir::Value lhs) { \
return m_builder.create<mlir::op>(m_location, lhs); \
}
cb(neg, NegFOp);
cb(ceil, CeilFOp);
cb(cos, CosOp);
cb(exp, ExpOp);
cb(exp2, Exp2Op);
cb(log10, Log10Op);
cb(log2, Log2Op);
cb(log, LogOp);
cb(rsqrt, RsqrtOp);
cb(sin, SinOp);
cb(sqrt, SqrtOp);
cb(tanh, TanhOp);
#undef cb
mlir::Value ValueBuilderHelper::abs(mlir::Value lhs) {
auto zero = const_val(0.f);
return select(ge(lhs, zero), lhs, sub(zero, lhs));
}
mlir::Value ValueBuilderHelper::floor(mlir::Value lhs) {
//! FIXME use standard floor when upgrade llvm
return neg(ceil(neg(lhs)));
}
mlir::Value ValueBuilderHelper::select(mlir::Value cond, mlir::Value true_val, mlir::Value ValueBuilderHelper::select(mlir::Value cond, mlir::Value true_val,
mlir::Value false_val) { mlir::Value false_val) {
...@@ -106,6 +114,8 @@ mlir::Value ValueBuilderHelper::select(mlir::Value cond, mlir::Value true_val, ...@@ -106,6 +114,8 @@ mlir::Value ValueBuilderHelper::select(mlir::Value cond, mlir::Value true_val,
false_val); false_val);
} }
/* ===================== helper functions ===================== */
mlir::AffineMap jit::get_affinemap(mlir::OpBuilder& builder, mlir::AffineMap jit::get_affinemap(mlir::OpBuilder& builder,
const mlir::Value& val, const mlir::Value& val,
const megdnn::TensorLayout& layout) { const megdnn::TensorLayout& layout) {
...@@ -125,10 +135,10 @@ mlir::AffineMap jit::get_affinemap(mlir::OpBuilder& builder, ...@@ -125,10 +135,10 @@ mlir::AffineMap jit::get_affinemap(mlir::OpBuilder& builder,
} }
mlir::Value jit::get_affine_load_op(mlir::OpBuilder& builder, mlir::Value jit::get_affine_load_op(mlir::OpBuilder& builder,
const mlir::Location& loc, const mlir::Location& loc,
const mlir::Value& val, const mlir::Value& val,
const mlir::ValueRange& index, const mlir::ValueRange& index,
const megdnn::TensorLayout& dst) { const megdnn::TensorLayout& dst) {
if (val.getType().isa<mlir::MemRefType>()) { if (val.getType().isa<mlir::MemRefType>()) {
auto type = val.getType().cast<mlir::MemRefType>(); auto type = val.getType().cast<mlir::MemRefType>();
megdnn::TensorLayout src_layout = mlir_type_to_layout(type); megdnn::TensorLayout src_layout = mlir_type_to_layout(type);
......
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
#include "megbrain_build_config.h" #include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR #if MGB_JIT && MGB_JIT_MLIR
#include "megbrain/tensor.h" #include "megbrain/tensor.h"
#include <mlir/Dialect/StandardOps/IR/Ops.h> #include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/OperationSupport.h> #include <mlir/IR/OperationSupport.h>
#include <mlir/IR/Value.h> #include <mlir/IR/Value.h>
...@@ -30,50 +32,59 @@ public: ...@@ -30,50 +32,59 @@ public:
ValueBuilderHelper(mlir::OpBuilder& b, mlir::Location location) ValueBuilderHelper(mlir::OpBuilder& b, mlir::Location location)
: m_builder{b}, m_location{location} {}; : m_builder{b}, m_location{location} {};
#define cb(name) \
mlir::Value name(mlir::ValueRange operands) { \
return name(operands[0], operands[1]); \
} \
mlir::Value name(mlir::Value lhs, mlir::Value rhs)
cb(add);
cb(sub);
cb(mul);
cb(div);
cb(divI);
cb(max);
cb(min);
cb(mod);
cb(modI);
cb(gt);
cb(ge);
cb(lt);
cb(le);
cb(eq);
cb(bit_and);
cb(bit_or);
#undef cb
mlir::Value const_val(float val);
mlir::Value constI(int32_t val);
#define cb(name) \ #define cb(name) \
mlir::Value name(mlir::ValueRange operands) { return name(operands[0]); } \ mlir::Value name(mlir::ValueRange operands) { return name(operands[0]); } \
mlir::Value name(mlir::Value lhs) mlir::Value name(mlir::Value lhs)
cb(neg);
// unary functions
cb(abs); cb(abs);
cb(ceil); cb(ceil);
cb(floor);
cb(cos); cb(cos);
cb(exp); cb(exp);
cb(exp2); cb(exp2);
cb(floor);
cb(log);
cb(log10); cb(log10);
cb(log2); cb(log2);
cb(log); cb(neg);
cb(rsqrt); cb(rsqrt);
cb(sin); cb(sin);
cb(sqrt); cb(sqrt);
cb(tanh); cb(tanh);
#undef cb #undef cb
#define cb(name) \
mlir::Value name(mlir::ValueRange operands) { \
return name(operands[0], operands[1]); \
} \
mlir::Value name(mlir::Value lhs, mlir::Value rhs)
// binary functions
cb(add);
cb(bit_and);
cb(bit_or);
cb(div);
cb(divI);
cb(eq);
cb(ge);
cb(gt);
cb(le);
cb(lt);
cb(max);
cb(min);
cb(mod);
cb(modI);
cb(mul);
cb(sub);
#undef cb
// constant functions
mlir::Value const_f32(float val);
mlir::Value const_i32(int32_t val);
// select function
mlir::Value select(mlir::Value cond, mlir::Value true_val, mlir::Value select(mlir::Value cond, mlir::Value true_val,
mlir::Value false_val); mlir::Value false_val);
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#if MGB_JIT && MGB_JIT_MLIR #if MGB_JIT && MGB_JIT_MLIR
#include "megbrain/jit/mlir/ir/dialect.h" #include "megbrain/jit/mlir/ir/dialect.h"
#include "./types.h" #include "./types.h"
#include <mlir/IR/Builders.h> #include <mlir/IR/Builders.h>
...@@ -28,14 +29,12 @@ MgbDialect::MgbDialect(mlir::MLIRContext* ctx) ...@@ -28,14 +29,12 @@ MgbDialect::MgbDialect(mlir::MLIRContext* ctx)
: mlir::Dialect("mgb", ctx, mlir::TypeID::get<MgbDialect>()) { : mlir::Dialect("mgb", ctx, mlir::TypeID::get<MgbDialect>()) {
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
#include "megbrain/jit/mlir/ir/ops.cpp.inc" #include "megbrain/jit/mlir/ir/mgb_dialect.cpp.inc"
>(); >();
} }
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "megbrain/jit/mlir/ir/ops.cpp.inc" #include "megbrain/jit/mlir/ir/mgb_dialect.cpp.inc"
#include "megbrain/jit/mlir/ir/interfaces.cpp.inc"
#endif // MGB_JIT && MGB_JIT_MLIR #endif // MGB_JIT && MGB_JIT_MLIR
......
/**
* \file src/jit/impl/mlir/ir/each_mode.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR
#include "./common.h"
#include "./each_mode.h"
#include "./numerical.h"
#include "./types.h"
#include "megbrain/common.h"
#include "megbrain/exception.h"
#include "megbrain/jit/mlir/ir/dialect.h"
#include <mlir/Dialect/StandardOps/IR/Ops.h>
namespace mgb {
namespace jit {
using Mode = megdnn::param::Elemwise::Mode;
template <Mode mode>
mlir::Value lower_mode(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands);
/* ===================== trivial implementations ===================== */
#define cb(mode, fun) \
template <> \
mlir::Value lower_mode<Mode::mode>(mlir::OpBuilder & builder, \
mlir::Location loc, \
ValueRange operands) { \
ValueBuilderHelper helper(builder, loc); \
return helper.fun(operands); \
}
//! unary
cb(ABS, abs);
cb(CEIL, ceil);
cb(COS, cos);
cb(EXP, exp);
cb(FLOOR, floor);
cb(LOG, log);
cb(NEGATE, neg);
cb(SIN, sin);
cb(TANH, tanh);
//! binary
cb(ADD, add);
cb(MAX, max);
cb(MIN, min);
cb(MOD, mod);
cb(MUL, mul);
cb(SUB, sub);
cb(TRUE_DIV, div);
#undef cb
/* ===================== unary op ===================== */
//! ACOS: pi / 2 - arctan2(x, sqrt(1 - x * x))
template <>
mlir::Value lower_mode<Mode::ACOS>(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto x = operands[0];
auto one_minus_x_2 = helper.sub(helper.const_f32(1.f), helper.mul(x, x));
auto asin = atan2_approx(helper, x, helper.sqrt(one_minus_x_2));
auto pi_over_2 = helper.const_f32(1.57079637f);
return helper.sub(pi_over_2, asin);
}
//! ASIN: arctan2(x, sqrt(1 - x * x))
template <>
mlir::Value lower_mode<Mode::ASIN>(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto x = operands[0];
auto one_minus_x_2 = helper.sub(helper.const_f32(1.f), helper.mul(x, x));
return atan2_approx(helper, x, helper.sqrt(one_minus_x_2));
}
//! ERFCINV: inverse of complementary gauss error function
//! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c
template <>
mlir::Value lower_mode<Mode::ERFCINV>(mlir::OpBuilder& builder,
mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto minus_sqrt2 = helper.const_f32(-1.4142135623f);
auto x = helper.mul(helper.const_f32(0.5f), operands[0]);
return helper.div(ndtri_approx(helper, x), minus_sqrt2);
}
//! ERFC: complementary error function
template <>
mlir::Value lower_mode<Mode::ERFC>(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.sub(helper.const_f32(1.f), erf_approx(helper, operands[0]));
}
//! ERFINV: inverse of gauss error function
//! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c
template <>
mlir::Value lower_mode<Mode::ERFINV>(mlir::OpBuilder& builder,
mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto sqrt2 = helper.const_f32(1.4142135623f);
auto x = helper.mul(helper.const_f32(0.5f),
helper.add(operands[0], helper.const_f32(1.f)));
return helper.div(ndtri_approx(helper, x), sqrt2);
}
//! ERF: gauss error function
template <>
mlir::Value lower_mode<Mode::ERF>(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return erf_approx(helper, operands[0]);
}
//! EXPM1: exp(x) - 1
template <>
mlir::Value lower_mode<Mode::EXPM1>(mlir::OpBuilder& builder,
mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.sub(helper.exp(operands[0]), helper.const_f32(1.f));
}
//! FAST_TANH: x * (27.f + x * x) / (27.f + 9.f * x * x);
template <>
mlir::Value lower_mode<Mode::FAST_TANH>(mlir::OpBuilder& builder,
mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto square = helper.mul(operands[0], operands[0]);
return helper.div(
helper.mul(operands[0], helper.add(helper.const_f32(27.f), square)),
helper.add(helper.const_f32(27.f),
helper.mul(helper.const_f32(9.f), square)));
}
//! H_SWISH: x * clip(x + 3, 0, 6) / 6
template <>
mlir::Value lower_mode<Mode::H_SWISH>(mlir::OpBuilder& builder,
mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto const_3 = helper.const_f32(3.f);
auto const_0 = helper.const_f32(0.f);
auto const_6 = helper.const_f32(6.f);
auto tmp = helper.add(operands[0], const_3);
return helper.div(helper.mul(operands[0],
helper.min(helper.max(tmp, const_0), const_6)),
const_6);
}
//! LOG1P: log(1 + p)
template <>
mlir::Value lower_mode<Mode::LOG1P>(mlir::OpBuilder& builder,
mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.log(helper.add(operands[0], helper.const_f32(1.f)));
}
//! RELU: max(x, 0)
template <>
mlir::Value lower_mode<Mode::RELU>(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.max(operands[0], helper.const_f32(0.f));
}
//! ROUND
template <>
mlir::Value lower_mode<Mode::ROUND>(mlir::OpBuilder& builder,
mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(
helper.gt(operands[0], helper.const_f32(0.f)),
helper.floor(helper.add(operands[0], helper.const_f32(0.5f))),
helper.ceil(helper.sub(operands[0], helper.const_f32(0.5f))));
}
//! SIGMOID: 1.f / (expf(-y) + 1.f))
template <>
mlir::Value lower_mode<Mode::SIGMOID>(mlir::OpBuilder& builder,
mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.div(helper.const_f32(1.f),
helper.add(helper.exp(helper.neg(operands[0])),
helper.const_f32(1.f)));
}
/* ===================== binary op ===================== */
//! ABS_GRAD: x > 0 ? y : -y
template <>
mlir::Value lower_mode<Mode::ABS_GRAD>(mlir::OpBuilder& builder,
mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(helper.gt(operands[0], helper.const_f32(0.f)),
operands[1], helper.neg(operands[1]));
}
//! ATAN2
template <>
mlir::Value lower_mode<Mode::ATAN2>(mlir::OpBuilder& builder,
mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return atan2_approx(helper, operands[0], operands[1]);
}
//! EQ: x == y ? 1 : 0
template <>
mlir::Value lower_mode<Mode::EQ>(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(helper.eq(operands[0], operands[1]),
helper.const_f32(1.f), helper.const_f32(0.f));
}
//! FAST_TANH_GRAD: ((-48.f * x * x) / (3.f + x * x) + 27.f + x * x) / (3.f + x
//! * x) * y
template <>
mlir::Value lower_mode<Mode::FAST_TANH_GRAD>(mlir::OpBuilder& builder,
mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto x_pow2 = helper.mul(operands[0], operands[0]);
auto deno = helper.add(helper.const_f32(3.f), x_pow2);
return helper.mul(
helper.div(
helper.add(
helper.add(helper.div(helper.mul(helper.const_f32(
-48.f),
x_pow2),
deno),
helper.const_f32(27.f)),
x_pow2),
helper.mul(deno, helper.const_f32(9.f))),
operands[1]);
}
//! FLOOR_DIV: floor(x/y)
template <>
mlir::Value lower_mode<Mode::FLOOR_DIV>(mlir::OpBuilder& builder,
mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.floor(helper.div(operands[0], operands[1]));
}
//! FUSE_ADD_H_SWISH: (x+y) * min(max(x + y + 3, 0), 6) * (1/6)
template <>
mlir::Value lower_mode<Mode::FUSE_ADD_H_SWISH>(mlir::OpBuilder& builder,
mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto sum = helper.add(operands[0], operands[1]);
auto const_3 = helper.const_f32(3.f);
auto const_0 = helper.const_f32(0.f);
auto const_6 = helper.const_f32(6.f);
auto tmp = helper.add(sum, const_3);
return helper.div(
helper.mul(sum, helper.min(helper.max(tmp, const_0), const_6)),
const_6);
}
//! FUSE_ADD_RELU: (x + y) <= ctype(0) ? ctype(0) : (x + y)
template <>
mlir::Value lower_mode<Mode::FUSE_ADD_RELU>(mlir::OpBuilder& builder,
mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto sum = helper.add(operands[0], operands[1]);
return helper.max(sum, helper.const_f32(0.f));
}
//! FUSE_ADD_SIGMOID: 1.f / (expf(-(x+y)) + 1.f))
template <>
mlir::Value lower_mode<Mode::FUSE_ADD_SIGMOID>(mlir::OpBuilder& builder,
mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.div(helper.const_f32(1.f),
helper.add(helper.exp(helper.neg(
helper.add(operands[0], operands[1]))),
helper.const_f32(1.f)));
}
//! FUSE_ADD_TANH: tanh(x + y)
template <>
mlir::Value lower_mode<Mode::FUSE_ADD_TANH>(mlir::OpBuilder& builder,
mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.tanh(helper.add(operands[0], operands[1]));
}
//! H_SWISH_GRAD: x < -3.f ? 0.f : (x > 3.f ? y : (2.f * x + 3.f) / 6.f * y)
template <>
mlir::Value lower_mode<Mode::H_SWISH_GRAD>(mlir::OpBuilder& builder,
mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(
helper.lt(operands[0], helper.const_f32(-3.f)),
helper.const_f32(0.f),
helper.select(
helper.gt(operands[0], helper.const_f32(3.f)), operands[1],
helper.mul(
helper.div(
helper.add(helper.mul(helper.const_f32(2.f),
operands[0]),
helper.const_f32(3.f)),
helper.const_f32(6.f)),
operands[1])));
}
//! LEQ: x <= y ? 1 : 0
template <>
mlir::Value lower_mode<Mode::LEQ>(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(helper.le(operands[0], operands[1]),
helper.const_f32(1.f), helper.const_f32(0.f));
}
//! LOG_SUM_EXP: log(exp(x) + exp(y))
template <>
mlir::Value lower_mode<Mode::LOG_SUM_EXP>(mlir::OpBuilder& builder,
mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.log(
helper.add(helper.exp(operands[0]), helper.exp(operands[1])));
}
//! LT: x < y ? 1 : 0
template <>
mlir::Value lower_mode<Mode::LT>(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(helper.lt(operands[0], operands[1]),
helper.const_f32(1.f), helper.const_f32(0.f));
}
//! POW: x^y = exp(y * log(x))
template <>
mlir::Value lower_mode<Mode::POW>(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.exp(helper.mul(operands[1], helper.log(operands[0])));
}
//! SIGMOID_GRAD: x * (1 - x) * y
template <>
mlir::Value lower_mode<Mode::SIGMOID_GRAD>(mlir::OpBuilder& builder,
mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.mul(helper.mul(operands[0], helper.sub(helper.const_f32(1.f),
operands[0])),
operands[1]);
}
//! SWITCH_GT0: (x > 0) * y
template <>
mlir::Value lower_mode<Mode::SWITCH_GT0>(mlir::OpBuilder& builder,
mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(helper.gt(operands[0], helper.const_f32(0.f)),
operands[1], helper.const_f32(0.f));
}
//! TANH_GRAD: (1 - x * x) * y
template <>
mlir::Value lower_mode<Mode::TANH_GRAD>(mlir::OpBuilder& builder,
mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.mul(helper.sub(helper.const_f32(1.0f),
helper.mul(operands[0], operands[0])),
operands[1]);
}
/* ===================== ternary op ===================== */
//! COND_LEQ_MOV: x <= y ? z : ctype(0)
template <>
mlir::Value lower_mode<Mode::COND_LEQ_MOV>(mlir::OpBuilder& builder,
mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(helper.le(operands[0], operands[1]), operands[2],
helper.const_f32(0.f));
}
//! FUSE_MUL_ADD3: x * y + z
template <>
mlir::Value lower_mode<Mode::FUSE_MUL_ADD3>(mlir::OpBuilder& builder,
mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.add(helper.mul(operands[0], operands[1]), operands[2]);
}
/* ===================== elemwise ===================== */
mlir::Value lower_elemwise_to_std(mlir::Operation* op, mlir::OpBuilder& builder,
mlir::Location loc, ValueRange operands) {
auto mode = llvm::dyn_cast<dialect::Elemwise>(op).mode();
switch (mode) {
#define cb(_, _mode) \
case Mode::_mode: \
return lower_mode<Mode::_mode>(builder, loc, operands);
MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb);
MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb);
MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb);
default:
return nullptr;
}
#undef cb
}
/* ===================== typecvt ===================== */
mlir::Value lower_typecvt_to_std(mlir::Operation* op, mlir::OpBuilder& builder,
mlir::Location loc, mlir::Value input) {
auto&& typecvt = llvm::dyn_cast<dialect::TypeCvt>(op);
megdnn::DType idtype = typecvt.idtype();
megdnn::DType odtype = typecvt.odtype();
mlir::Type itype = input.getType();
mlir::Type otype = megdnn_dtype_to_mlir_type(odtype, builder.getContext());
if (mlir::FPExtOp::areCastCompatible(itype, otype)) {
return builder.create<mlir::FPExtOp>(loc, otype, input);
} else if (mlir::FPTruncOp::areCastCompatible(itype, otype)) {
return builder.create<mlir::FPTruncOp>(loc, otype, input);
} else if (mlir::FPToSIOp::areCastCompatible(itype, otype) and
is_signed_int_dtype(odtype)) {
return builder.create<mlir::FPToSIOp>(loc, otype, input);
} else if (mlir::FPToUIOp::areCastCompatible(itype, otype) and
is_unsigned_int_dtype(odtype)) {
return builder.create<mlir::FPToUIOp>(loc, otype, input);
} else if (mlir::SIToFPOp::areCastCompatible(itype, otype) and
is_signed_int_dtype(idtype)) {
return builder.create<mlir::SIToFPOp>(loc, otype, input);
} else if (mlir::UIToFPOp::areCastCompatible(itype, otype) and
is_unsigned_int_dtype(idtype)) {
return builder.create<mlir::UIToFPOp>(loc, otype, input);
} else {
mgb_throw(InternalError, "cannot convert from %s to %s", idtype.name(),
odtype.name());
}
return nullptr;
}
} // namespace jit
} // namespace mgb
#endif // MGB_JIT && MGB_JIT_MLIR
// vim: syntax=cpp.doxygen
...@@ -15,65 +15,60 @@ ...@@ -15,65 +15,60 @@
#include "megbrain_build_config.h" #include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR #if MGB_JIT && MGB_JIT_MLIR
#include "megbrain/jit/mlir/ir/dialect.h" #include "megdnn/opr_param_defs.h"
#include "./common.h"
#include "./numerical.h"
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/Builders.h> #include <mlir/IR/Builders.h>
#include <mlir/IR/Value.h>
// clang-format off // clang-format off
#define MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb) \ #define MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb) \
cb(ReluOp, RELU) \
cb(AbsOp, ABS) \ cb(AbsOp, ABS) \
cb(NegOp, NEGATE) \
cb(AcosOp, ACOS) \ cb(AcosOp, ACOS) \
cb(AsinOp, ASIN) \ cb(AsinOp, ASIN) \
cb(CeilOp, CEIL) \ cb(CeilOp, CEIL) \
cb(CosOp, COS) \ cb(CosOp, COS) \
cb(ErfCInvOp, ERFCINV) \
cb(ErfCOp, ERFC) \
cb(ErfInvOp, ERFINV) \
cb(ErfOp, ERF) \
cb(ExpM1Op, EXPM1) \
cb(ExpOp, EXP) \ cb(ExpOp, EXP) \
cb(FastTanhOp, FAST_TANH) \
cb(FloorOp, FLOOR) \ cb(FloorOp, FLOOR) \
cb(LogOp, LOG) \ cb(HswishOp, H_SWISH) \
cb(Log1POp, LOG1P) \ cb(Log1POp, LOG1P) \
cb(LogOp, LOG) \
cb(NegOp, NEGATE) \
cb(ReluOp, RELU) \
cb(RoundOp, ROUND) \
cb(SigmoidOp, SIGMOID) \ cb(SigmoidOp, SIGMOID) \
cb(SinOp, SIN) \ cb(SinOp, SIN) \
cb(TanhOp, TANH) \ cb(TanhOp, TANH)
cb(FastTanhOp, FAST_TANH) \
cb(HswishOp, H_SWISH) \
cb(ExpM1Op, EXPM1) \
cb(RoundOp, ROUND) \
cb(ErfOp, ERF) \
cb(ErfInvOp, ERFINV) \
cb(ErfCOp, ERFC) \
cb(ErfCInvOp, ERFCINV)
#define MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) \ #define MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) \
cb(AbsGradOp, ABS_GRAD) \ cb(AbsGradOp, ABS_GRAD) \
cb(AddOp, ADD) \ cb(AddOp, ADD) \
cb(Atan2Op, ATAN2) \
cb(EqOp, EQ) \
cb(FastTanhGradOp, FAST_TANH_GRAD) \
cb(FloorDivOp, FLOOR_DIV) \ cb(FloorDivOp, FLOOR_DIV) \
cb(FuseAddHswishOp, FUSE_ADD_H_SWISH) \
cb(FuseAddReluOp, FUSE_ADD_RELU) \
cb(FuseAddSigmoidOp, FUSE_ADD_SIGMOID) \
cb(FuseAddTanhOp, FUSE_ADD_TANH) \
cb(HswishGradOp, H_SWISH_GRAD) \
cb(LeqOp, LEQ) \
cb(LogSumExpOp, LOG_SUM_EXP) \
cb(LtOp, LT) \
cb(MaxOp, MAX) \ cb(MaxOp, MAX) \
cb(MinOp, MIN) \ cb(MinOp, MIN) \
cb(ModOp, MOD) \ cb(ModOp, MOD) \
cb(SubOp, SUB) \
cb(MulOp, MUL) \ cb(MulOp, MUL) \
cb(TrueDivOp, TRUE_DIV) \
cb(PowOp, POW) \ cb(PowOp, POW) \
cb(SigmoidGradOp, SIGMOID_GRAD) \ cb(SigmoidGradOp, SIGMOID_GRAD) \
cb(SubOp, SUB) \
cb(SwishGt0Op, SWITCH_GT0) \ cb(SwishGt0Op, SWITCH_GT0) \
cb(TanhGradOp, TANH_GRAD) \ cb(TanhGradOp, TANH_GRAD) \
cb(LtOp, LT) \ cb(TrueDivOp, TRUE_DIV)
cb(LeqOp, LEQ) \
cb(EqOp, EQ) \
cb(FuseAddReluOp, FUSE_ADD_RELU) \
cb(LogSumExpOp, LOG_SUM_EXP) \
cb(FuseAddTanhOp, FUSE_ADD_TANH) \
cb(FastTanhGradOp, FAST_TANH_GRAD) \
cb(FuseAddSigmoidOp, FUSE_ADD_SIGMOID) \
cb(HswishGradOp, H_SWISH_GRAD) \
cb(FuseAddHswishOp, FUSE_ADD_H_SWISH) \
cb(Atan2Op, ATAN2)
#define MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) \ #define MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) \
cb(CondLeqMovOp, COND_LEQ_MOV) \ cb(CondLeqMovOp, COND_LEQ_MOV) \
...@@ -83,432 +78,19 @@ ...@@ -83,432 +78,19 @@
namespace mgb { namespace mgb {
namespace jit { namespace jit {
template <typename mgb_op> mlir::Value lower_elemwise_to_std(mlir::Operation* op,
struct StandardOp; mlir::OpBuilder& builder,
mlir::Location loc,
#define cb(mgb_op, fun) \ mlir::ValueRange operands);
template <> \
struct StandardOp<jit::mgb_op> { \
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, \
ValueRange operands) { \
ValueBuilderHelper helper(builder, loc); \
return helper.fun(operands); \
} \
}
//! unary
cb(AbsOp, abs);
cb(NegOp, neg);
cb(ExpOp, exp);
cb(CosOp, cos);
cb(CeilOp, ceil);
cb(FloorOp, floor);
cb(LogOp, log);
cb(SinOp, sin);
cb(TanhOp, tanh);
//! binary
cb(AddOp, add);
cb(MaxOp, max);
cb(MinOp, min);
cb(SubOp, sub);
cb(MulOp, mul);
cb(ModOp, mod);
cb(TrueDivOp, div);
#undef cb
/////////////////////////// unary op ///////////////////////////
//! max(x, 0)
template <>
struct StandardOp<jit::ReluOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.max(operands[0], helper.const_val(0.f));
}
};
//! x * (27.f + x * x) / (27.f + 9.f * x * x);
template <>
struct StandardOp<jit::FastTanhOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto square = helper.mul(operands[0], operands[0]);
return helper.div(
helper.mul(operands[0],
helper.add(helper.const_val(27.f), square)),
helper.add(helper.const_val(27.f),
helper.mul(helper.const_val(9.f), square)));
}
};
//! x * clip(x + 3, 0, 6) / 6
template <>
struct StandardOp<jit::HswishOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto const_3 = helper.const_val(3.f);
auto const_0 = helper.const_val(0.f);
auto const_6 = helper.const_val(6.f);
auto tmp = helper.add(operands[0], const_3);
return helper.div(
helper.mul(operands[0],
helper.min(helper.max(tmp, const_0), const_6)),
const_6);
}
};
//! log(1 + p)
template <>
struct StandardOp<jit::Log1POp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.log(helper.add(operands[0], helper.const_val(1.f)));
}
};
//! 1.f / (expf(-y) + 1.f))
template <>
struct StandardOp<jit::SigmoidOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.div(helper.const_val(1.f),
helper.add(helper.exp(helper.neg(operands[0])),
helper.const_val(1.f)));
}
};
//! exp(x) - 1
template <>
struct StandardOp<jit::ExpM1Op> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.sub(helper.exp(operands[0]), helper.const_val(1.f));
}
};
template <>
struct StandardOp<jit::RoundOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(
helper.gt(operands[0], helper.const_val(0.f)),
helper.floor(helper.add(operands[0], helper.const_val(0.5f))),
helper.ceil(helper.sub(operands[0], helper.const_val(0.5f))));
}
};
//! pi / 2 - arctan2(x, sqrt(1 - x * x))
template <>
struct StandardOp<jit::AcosOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto x = operands[0];
auto one_minus_x_2 = helper.sub(helper.const_val(1.f), helper.mul(x, x));
auto asin = atan2_approx(helper, x, helper.sqrt(one_minus_x_2));
auto pi_over_2 = helper.const_val(1.57079637f);
return helper.sub(pi_over_2, asin);
}
};
//! arctan2(x, sqrt(1 - x * x))
template <>
struct StandardOp<jit::AsinOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto x = operands[0];
auto one_minus_x_2 = helper.sub(helper.const_val(1.f), helper.mul(x, x));
return atan2_approx(helper, x, helper.sqrt(one_minus_x_2));
}
};
//! gauss error function
template <>
struct StandardOp<jit::ErfOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return erf_approx(helper, operands[0]);
}
};
//! inverse of gauss error function
//! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c
template <>
struct StandardOp<jit::ErfInvOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto sqrt2 = helper.const_val(1.4142135623f);
auto x = helper.mul(helper.const_val(0.5f),
helper.add(operands[0], helper.const_val(1.f)));
return helper.div(ndtri_approx(helper, x), sqrt2);
}
};
//! complementary error function
template <>
struct StandardOp<jit::ErfCOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.sub(helper.const_val(1.f), erf_approx(helper, operands[0]));
}
};
//! inverse of complementary gauss error function
//! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c
template <>
struct StandardOp<jit::ErfCInvOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto minus_sqrt2 = helper.const_val(-1.4142135623f);
auto x = helper.mul(helper.const_val(0.5f), operands[0]);
return helper.div(ndtri_approx(helper, x), minus_sqrt2);
}
};
/////////////////////////// binary op ///////////////////////////
//! binary: x > 0 ? y : -y
template <>
struct StandardOp<jit::AbsGradOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(helper.gt(operands[0], helper.const_val(0.f)),
operands[1], helper.neg(operands[1]));
}
};
//! x^y = exp(y * log(x))
template <>
struct StandardOp<jit::PowOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.exp(helper.mul(operands[1], helper.log(operands[0])));
}
};
//! x * (1 - x) * y
template <>
struct StandardOp<jit::SigmoidGradOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.mul(
helper.mul(operands[0],
helper.sub(helper.const_val(1.f), operands[0])),
operands[1]);
}
};
//! (x > 0) * y
template <>
struct StandardOp<jit::SwishGt0Op> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(helper.gt(operands[0], helper.const_val(0.f)),
operands[1], helper.const_val(0.f));
}
};
//! (1 - x * x) * y
template <>
struct StandardOp<jit::TanhGradOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.mul(helper.sub(helper.const_val(1.0f),
helper.mul(operands[0], operands[0])),
operands[1]);
}
};
#define cb(op, fun) \
template <> \
struct StandardOp<jit::op> { \
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, \
ValueRange operands) { \
ValueBuilderHelper helper(builder, loc); \
return helper.select(helper.fun(operands[0], operands[1]), \
helper.const_val(1.f), \
helper.const_val(0.f)); \
} \
}
cb(LtOp, lt);
cb(LeqOp, le);
cb(EqOp, eq);
#undef cb
//! (x + y) <= ctype(0) ? ctype(0) : (x + y)
template <>
struct StandardOp<jit::FuseAddReluOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto sum = helper.add(operands[0], operands[1]);
return helper.max(sum, helper.const_val(0.f));
}
};
//! log(exp(x) + exp(y))
template <>
struct StandardOp<jit::LogSumExpOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.log(
helper.add(helper.exp(operands[0]), helper.exp(operands[1])));
}
};
//! floor(x/y)
template <>
struct StandardOp<jit::FloorDivOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.floor(helper.div(operands[0], operands[1]));
}
};
//! tanh(x + y)
template <>
struct StandardOp<jit::FuseAddTanhOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.tanh(helper.add(operands[0], operands[1]));
}
};
//! ((-48.f * x * x) / (3.f + x * x) + 27.f + x * x) / (3.f + x * x) * y
template <>
struct StandardOp<jit::FastTanhGradOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto x_pow2 = helper.mul(operands[0], operands[0]);
auto deno = helper.add(helper.const_val(3.f), x_pow2);
return helper.mul(
helper.div(
helper.add(
helper.add(
helper.div(helper.mul(helper.const_val(
-48.f),
x_pow2),
deno),
helper.const_val(27.f)),
x_pow2),
helper.mul(deno, helper.const_val(9.f))),
operands[1]);
}
};
//! 1.f / (expf(-(x+y)) + 1.f))
template <>
struct StandardOp<jit::FuseAddSigmoidOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.div(helper.const_val(1.f),
helper.add(helper.exp(helper.neg(helper.add(
operands[0], operands[1]))),
helper.const_val(1.f)));
}
};
//! x < -3.f ? 0.f : (x > 3.f ? y : (2.f * x + 3.f) / 6.f * y)
template <>
struct StandardOp<jit::HswishGradOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(
helper.lt(operands[0], helper.const_val(-3.f)),
helper.const_val(0.f),
helper.select(
helper.gt(operands[0], helper.const_val(3.f)),
operands[1],
helper.mul(
helper.div(
helper.add(helper.mul(helper.const_val(
2.f),
operands[0]),
helper.const_val(3.f)),
helper.const_val(6.f)),
operands[1])));
}
};
//! (x+y) * min(max(x + y + 3, 0), 6) * (1/6)
template <>
struct StandardOp<jit::FuseAddHswishOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto sum = helper.add(operands[0], operands[1]);
auto const_3 = helper.const_val(3.f);
auto const_0 = helper.const_val(0.f);
auto const_6 = helper.const_val(6.f);
auto tmp = helper.add(sum, const_3);
return helper.div(
helper.mul(sum, helper.min(helper.max(tmp, const_0), const_6)),
const_6);
}
};
//! arctan
template <>
struct StandardOp<jit::Atan2Op> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return atan2_approx(helper, operands[0], operands[1]);
}
};
/////////////////////////// ternary op ///////////////////////////
//! x <= y ? z : ctype(0)
template <>
struct StandardOp<jit::CondLeqMovOp> {
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(helper.le(operands[0], operands[1]), operands[2],
helper.const_val(0.f));
}
};
//! x * y + z mlir::Value lower_typecvt_to_std(mlir::Operation* op,
template <> mlir::OpBuilder& builder,
struct StandardOp<jit::FuseMulAdd3Op> { mlir::Location loc,
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value input);
ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.add(helper.mul(operands[0], operands[1]), operands[2]);
}
};
} // namespace jit } // namespace jit
} // namespace mgb } // namespace mgb
#endif // MGB_JIT_MLIR #endif // MGB_JIT && MGB_JIT_MLIR
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
/**
* \file src/jit/impl/mlir/ir/interfaces.td
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#ifndef MGB_MLIR_INTERFACES
#define MGB_MLIR_INTERFACES
#ifndef OP_BASE
include "mlir/IR/OpBase.td"
#endif
def GenericBuilderInterface : OpInterface<"GenericBuilder"> {
let methods = [
StaticInterfaceMethod<"TODO", "Type", "getResultType", (ins "ArrayRef<Value>":$operands)>,
StaticInterfaceMethod<"TODO", "Operation*", "create", (ins
"OpBuilder*":$builder,
"Location":$loc,
"ArrayRef<Value>":$operands
)>,
];
}
def ElemwiseOpInterface : OpInterface<"ElemwiseOp">;
#endif
...@@ -13,18 +13,19 @@ ...@@ -13,18 +13,19 @@
#include "megbrain_build_config.h" #include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR #if MGB_JIT && MGB_JIT_MLIR
#include "./common.h"
#include "./each_mode.h"
#include "megbrain/common.h" #include "megbrain/common.h"
#include "megbrain/jit/mlir/ir/dialect.h" #include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/passes.h" #include "megbrain/jit/mlir/ir/passes.h"
#include "megbrain/jit/mlir/ir/utils.h" #include "megbrain/jit/mlir/ir/utils.h"
#include "./each_mode.h"
#include <llvm/ADT/Sequence.h> #include <llvm/ADT/Sequence.h>
#include <mlir/Dialect/Affine/IR/AffineOps.h> #include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/Pass/Pass.h> #include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h> #include <mlir/Transforms/DialectConversion.h>
#include "mlir/IR/StandardTypes.h"
using namespace mgb; using namespace mgb;
using namespace jit; using namespace jit;
...@@ -57,41 +58,10 @@ void lower_op_to_loops(Operation* op, ValueRange operands, ...@@ -57,41 +58,10 @@ void lower_op_to_loops(Operation* op, ValueRange operands,
rewriter.replaceOp(op, alloc); rewriter.replaceOp(op, alloc);
} }
template <typename Op, typename LoweredOp> struct ElemwiseLowering : public ConversionPattern {
struct UnaryOpLowering : public ConversionPattern { ElemwiseLowering(MLIRContext* ctx)
UnaryOpLowering(MLIRContext* ctx) : ConversionPattern(mgb::dialect::Elemwise::getOperationName(), 1,
: ConversionPattern(Op::getOperationName(), 1, ctx) {} ctx) {}
LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op->getLoc();
lower_op_to_loops(
op, operands, rewriter,
[loc](OpBuilder& builder, ValueRange memref_operands,
ValueRange loop_ivs) {
typename Op::Adaptor binary_adaptor(memref_operands);
LoweredOp lower_op;
auto loaded_lhs = get_operand<AffineLoadOp>(
builder, loc, binary_adaptor.lhs(), loop_ivs);
return lower_op(builder, loc, {loaded_lhs});
});
return success();
}
};
#define cb(_op, _) \
using _op##Lowering = UnaryOpLowering<jit::_op, jit::StandardOp<jit::_op>>;
MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)
#undef cb
template <typename Op, typename LoweredOp>
struct BinaryOpLowering : public ConversionPattern {
BinaryOpLowering(MLIRContext* ctx)
: ConversionPattern(Op::getOperationName(), 1, ctx) {}
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands, Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
...@@ -101,83 +71,51 @@ struct BinaryOpLowering : public ConversionPattern { ...@@ -101,83 +71,51 @@ struct BinaryOpLowering : public ConversionPattern {
dst_layout.init_contiguous_stride(); dst_layout.init_contiguous_stride();
lower_op_to_loops( lower_op_to_loops(
op, operands, rewriter, op, operands, rewriter,
[dst_layout, loc, this](OpBuilder& builder, [dst_layout, loc, op](OpBuilder& builder,
ValueRange memref_operands, ValueRange memref_operands,
ValueRange loop_ivs) { ValueRange loop_ivs) {
typename Op::Adaptor binary_adaptor(memref_operands); auto inputs = llvm::to_vector<4>(llvm::map_range(
LoweredOp lower_op; memref_operands, [&](mlir::Value val) {
return get_affine_load_op(builder, loc, val,
auto loaded_lhs = get_affine_load_op(builder, loc, loop_ivs, dst_layout);
binary_adaptor.lhs(), }));
loop_ivs, dst_layout); return lower_elemwise_to_std(op, builder, loc, inputs);
auto loaded_rhs = get_affine_load_op(builder, loc,
binary_adaptor.rhs(),
loop_ivs, dst_layout);
return lower_op(builder, loc, {loaded_lhs, loaded_rhs});
}); });
return success(); return success();
} }
}; };
#define cb(_op, _) \ struct TypeCvtLowering : public ConversionPattern {
using _op##Lowering = BinaryOpLowering<jit::_op, jit::StandardOp<jit::_op>>; TypeCvtLowering(MLIRContext* ctx)
MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) : ConversionPattern(mgb::dialect::TypeCvt::getOperationName(), 1,
#undef cb ctx) {}
template <typename Op, typename LoweredOp>
struct TernaryOpLowering : public ConversionPattern {
TernaryOpLowering(MLIRContext* ctx)
: ConversionPattern(Op::getOperationName(), 1, ctx) {}
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands, Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto loc = op->getLoc(); auto loc = op->getLoc();
auto dst_memref_type = (*op->result_type_begin()).cast<MemRefType>();
megdnn::TensorLayout dst_layout = mlir_type_to_layout(dst_memref_type);
dst_layout.init_contiguous_stride();
lower_op_to_loops( lower_op_to_loops(
op, operands, rewriter, op, operands, rewriter,
[dst_layout, loc](OpBuilder& builder, [loc, op](OpBuilder& builder, ValueRange memref_operands,
ValueRange memref_operands, ValueRange loop_ivs) {
ValueRange loop_ivs) { mlir::Value input = get_operand<AffineLoadOp>(
typename Op::Adaptor ternary_adaptor(memref_operands); builder, loc, memref_operands[0], loop_ivs);
LoweredOp lower_op; return lower_typecvt_to_std(op, builder, loc, input);
auto loaded_x = get_affine_load_op(builder, loc,
ternary_adaptor.x(),
loop_ivs, dst_layout);
auto loaded_y = get_affine_load_op(builder, loc,
ternary_adaptor.y(),
loop_ivs, dst_layout);
auto loaded_z = get_affine_load_op(builder, loc,
ternary_adaptor.z(),
loop_ivs, dst_layout);
return lower_op(builder, loc,
{loaded_x, loaded_y, loaded_z});
}); });
return success(); return success();
} }
}; };
#define cb(_op, _) \
using _op##Lowering = \
TernaryOpLowering<jit::_op, jit::StandardOp<jit::_op>>;
MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
#undef cb
struct AssignOpLowering : public ConversionPattern { struct AssignOpLowering : public ConversionPattern {
AssignOpLowering(MLIRContext* ctx) AssignOpLowering(MLIRContext* ctx)
: ConversionPattern(jit::AssignOp::getOperationName(), 1, ctx) {} : ConversionPattern(dialect::AssignOp::getOperationName(), 1, ctx) {
}
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands, Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto loc = op->getLoc(); auto loc = op->getLoc();
auto memref_type = operands[0].getType().cast<MemRefType>(); auto memref_type = operands[0].getType().cast<MemRefType>();
AssignOpAdaptor assign_adaptor(operands); dialect::AssignOpAdaptor assign_adaptor(operands);
llvm::SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0); llvm::SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
llvm::SmallVector<int64_t, 4> steps(memref_type.getRank(), 1); llvm::SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
...@@ -195,10 +133,10 @@ struct AssignOpLowering : public ConversionPattern { ...@@ -195,10 +133,10 @@ struct AssignOpLowering : public ConversionPattern {
} }
}; };
struct ReturnOpLowering : public OpRewritePattern<jit::ReturnOp> { struct ReturnOpLowering : public OpRewritePattern<dialect::ReturnOp> {
using OpRewritePattern<jit::ReturnOp>::OpRewritePattern; using OpRewritePattern<dialect::ReturnOp>::OpRewritePattern;
LogicalResult matchAndRewrite(jit::ReturnOp op, LogicalResult matchAndRewrite(dialect::ReturnOp op,
PatternRewriter& rewriter) const final { PatternRewriter& rewriter) const final {
// We lower "mgb.return" directly to "std.return". // We lower "mgb.return" directly to "std.return".
rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op); rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op);
...@@ -207,12 +145,12 @@ struct ReturnOpLowering : public OpRewritePattern<jit::ReturnOp> { ...@@ -207,12 +145,12 @@ struct ReturnOpLowering : public OpRewritePattern<jit::ReturnOp> {
}; };
struct ConstantScalarOpLowering struct ConstantScalarOpLowering
: public OpRewritePattern<jit::ConstantScalarOp> { : public OpRewritePattern<dialect::ConstantScalarOp> {
using OpRewritePattern<jit::ConstantScalarOp>::OpRewritePattern; using OpRewritePattern<dialect::ConstantScalarOp>::OpRewritePattern;
LogicalResult matchAndRewrite(jit::ConstantScalarOp op, LogicalResult matchAndRewrite(dialect::ConstantScalarOp op,
PatternRewriter& rewriter) const final { PatternRewriter& rewriter) const final {
ConstantScalarOpAdaptor constant_scalar_adaptor(op); dialect::ConstantScalarOpAdaptor constant_scalar_adaptor(op);
rewriter.replaceOpWithNewOp<mlir::ConstantOp>( rewriter.replaceOpWithNewOp<mlir::ConstantOp>(
op, constant_scalar_adaptor.value()); op, constant_scalar_adaptor.value());
return success(); return success();
...@@ -234,14 +172,9 @@ public: ...@@ -234,14 +172,9 @@ public:
target.addIllegalDialect<MgbDialect>(); target.addIllegalDialect<MgbDialect>();
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
#define cb(_op, _) _op##Lowering, patterns.insert<ElemwiseLowering, TypeCvtLowering, ReturnOpLowering,
patterns.insert<MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(
cb) MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
ReturnOpLowering,
AssignOpLowering, ConstantScalarOpLowering>( AssignOpLowering, ConstantScalarOpLowering>(
&getContext()); &getContext());
#undef cb
if (failed(applyPartialConversion(getFunction(), target, patterns))) { if (failed(applyPartialConversion(getFunction(), target, patterns))) {
signalPassFailure(); signalPassFailure();
......
...@@ -13,12 +13,19 @@ ...@@ -13,12 +13,19 @@
#include "megbrain_build_config.h" #include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR #if MGB_JIT && MGB_JIT_MLIR
#include "./common.h"
#include "./each_mode.h" #include "./each_mode.h"
#include "megbrain/common.h" #include "megbrain/common.h"
#include "megbrain/jit/mlir/ir/dialect.h" #include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/passes.h" #include "megbrain/jit/mlir/ir/passes.h"
#include "megbrain/jit/mlir/ir/utils.h" #include "megbrain/jit/mlir/ir/utils.h"
#include <llvm/ADT/PointerUnion.h>
#include <llvm/ADT/Sequence.h>
#include <llvm/ADT/SetVector.h>
#include <llvm/ADT/Twine.h>
#include <llvm/IR/Type.h>
#include <mlir/Dialect/GPU/GPUDialect.h> #include <mlir/Dialect/GPU/GPUDialect.h>
#include <mlir/Dialect/SCF/SCF.h> #include <mlir/Dialect/SCF/SCF.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h> #include <mlir/Dialect/StandardOps/IR/Ops.h>
...@@ -27,12 +34,6 @@ ...@@ -27,12 +34,6 @@
#include <mlir/Pass/Pass.h> #include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h> #include <mlir/Transforms/DialectConversion.h>
#include <llvm/ADT/PointerUnion.h>
#include <llvm/ADT/Sequence.h>
#include <llvm/ADT/SetVector.h>
#include <llvm/ADT/Twine.h>
#include <llvm/IR/Type.h>
using namespace mgb; using namespace mgb;
using namespace jit; using namespace jit;
...@@ -59,7 +60,7 @@ megdnn::TensorLayout output_layout(gpu::LaunchOp& launch_op) { ...@@ -59,7 +60,7 @@ megdnn::TensorLayout output_layout(gpu::LaunchOp& launch_op) {
block_iter++) { block_iter++) {
for (auto op_iter = block_iter->rbegin(); op_iter != block_iter->rend(); for (auto op_iter = block_iter->rbegin(); op_iter != block_iter->rend();
op_iter++) { op_iter++) {
auto op = llvm::dyn_cast_or_null<AssignOp>(&(*op_iter)); auto op = llvm::dyn_cast_or_null<dialect::AssignOp>(&(*op_iter));
if (op && op.getNumOperands() > 0) { if (op && op.getNumOperands() > 0) {
return mlir_type_to_layout(*(op.operand_type_begin())); return mlir_type_to_layout(*(op.operand_type_begin()));
} }
...@@ -81,64 +82,27 @@ std::vector<mlir::Value> get_multidim_tid(ConversionPatternRewriter& rewriter, ...@@ -81,64 +82,27 @@ std::vector<mlir::Value> get_multidim_tid(ConversionPatternRewriter& rewriter,
idxs.resize(dst.ndim); idxs.resize(dst.ndim);
mlir::Value dim_index = index; mlir::Value dim_index = index;
for (int i = dst.ndim - 1; i >= 0; i--) { for (int i = dst.ndim - 1; i >= 0; i--) {
auto cur_index = helper.modI(dim_index, helper.constI(dst[i])); auto cur_index = helper.modI(dim_index, helper.const_i32(dst[i]));
idxs[i] = cur_index; idxs[i] = cur_index;
dim_index = helper.divI(dim_index, helper.constI(dst[i])); dim_index = helper.divI(dim_index, helper.const_i32(dst[i]));
} }
megdnn::TensorLayout src_layout = mlir_type_to_layout(type); megdnn::TensorLayout src_layout = mlir_type_to_layout(type);
src_layout.init_contiguous_stride(); src_layout.init_contiguous_stride();
for (int i = 0; i < type.getRank(); ++i) { for (int i = 0; i < type.getRank(); ++i) {
if (src_layout[i] == 1) { if (src_layout[i] == 1) {
idxs[i] = helper.constI(0); idxs[i] = helper.const_i32(0);
} }
} }
return idxs; return idxs;
} else { } else {
return {index}; return {index};
} }
} }
template <typename Op, typename LoweredOp> struct ElemwiseLowering : public ConversionPattern {
struct UnaryOpLowering : public ConversionPattern { ElemwiseLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
UnaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op) : ConversionPattern(dialect::Elemwise::getOperationName(), 1, ctx),
: ConversionPattern(Op::getOperationName(), 1, ctx),
m_launch_op{launch_op} {}
LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op->getLoc();
typename Op::Adaptor binary_adaptor(operands);
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto dst_layout = output_layout(m_launch_op);
auto index = get_multidim_tid(rewriter, loc, binary_adaptor.lhs(),
dst_layout);
auto loaded_lhs =
get_operand<LoadOp>(rewriter, loc, binary_adaptor.lhs(), index);
LoweredOp lower_op;
rewriter.replaceOp(op, lower_op(rewriter, loc, {loaded_lhs}));
return success();
}
private:
gpu::LaunchOp& m_launch_op;
};
#define cb(_op, _) \
using _op##Lowering = UnaryOpLowering<jit::_op, jit::StandardOp<jit::_op>>;
MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)
#undef cb
template <typename Op, typename LoweredOp>
struct BinaryOpLowering : public ConversionPattern {
BinaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
: ConversionPattern(Op::getOperationName(), 1, ctx),
m_launch_op{launch_op} {} m_launch_op{launch_op} {}
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
...@@ -146,23 +110,18 @@ struct BinaryOpLowering : public ConversionPattern { ...@@ -146,23 +110,18 @@ struct BinaryOpLowering : public ConversionPattern {
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto loc = op->getLoc(); auto loc = op->getLoc();
typename Op::Adaptor binary_adaptor(operands);
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front())); rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto dst_layout = output_layout(m_launch_op); auto dst_layout = output_layout(m_launch_op);
auto lhs_index = get_multidim_tid(rewriter, loc, binary_adaptor.lhs(), auto inputs = llvm::to_vector<4>(
dst_layout); llvm::map_range(operands, [&](mlir::Value val) {
auto rhs_index = get_multidim_tid(rewriter, loc, binary_adaptor.rhs(), auto index =
dst_layout); get_multidim_tid(rewriter, loc, val, dst_layout);
auto loaded_lhs = get_operand<LoadOp>(rewriter, loc, return get_operand<LoadOp>(rewriter, loc, val, index);
binary_adaptor.lhs(), lhs_index); }));
auto loaded_rhs = get_operand<LoadOp>(rewriter, loc,
binary_adaptor.rhs(), rhs_index);
LoweredOp lower_op;
rewriter.replaceOp(op, rewriter.replaceOp(op,
lower_op(rewriter, loc, {loaded_lhs, loaded_rhs})); lower_elemwise_to_std(op, rewriter, loc, inputs));
return success(); return success();
} }
...@@ -170,43 +129,22 @@ private: ...@@ -170,43 +129,22 @@ private:
gpu::LaunchOp& m_launch_op; gpu::LaunchOp& m_launch_op;
}; };
#define cb(_op, _) \ struct TypeCvtLowering : public ConversionPattern {
using _op##Lowering = BinaryOpLowering<jit::_op, jit::StandardOp<jit::_op>>; TypeCvtLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) : ConversionPattern(dialect::TypeCvt::getOperationName(), 1, ctx),
#undef cb
template <typename Op, typename LoweredOp>
struct TernaryOpLowering : public ConversionPattern {
TernaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
: ConversionPattern(Op::getOperationName(), 1, ctx),
m_launch_op{launch_op} {} m_launch_op{launch_op} {}
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands, Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto loc = op->getLoc(); auto loc = op->getLoc();
typename Op::Adaptor ternary_adaptor(operands);
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front())); rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto dst_layout = output_layout(m_launch_op); auto dst_layout = output_layout(m_launch_op);
auto index_x = get_multidim_tid(rewriter, loc, ternary_adaptor.x(), auto index = get_multidim_tid(rewriter, loc, operands[0], dst_layout);
dst_layout); auto input = get_operand<LoadOp>(rewriter, loc, operands[0], index);
auto index_y = get_multidim_tid(rewriter, loc, ternary_adaptor.y(),
dst_layout); rewriter.replaceOp(op, lower_typecvt_to_std(op, rewriter, loc, input));
auto index_z = get_multidim_tid(rewriter, loc, ternary_adaptor.z(),
dst_layout);
auto loaded_x = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.x(),
index_x);
auto loaded_y = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.y(),
index_y);
auto loaded_z = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.z(),
index_z);
LoweredOp lower_op;
rewriter.replaceOp(
op, lower_op(rewriter, loc, {loaded_x, loaded_y, loaded_z}));
return success(); return success();
} }
...@@ -214,15 +152,9 @@ private: ...@@ -214,15 +152,9 @@ private:
gpu::LaunchOp& m_launch_op; gpu::LaunchOp& m_launch_op;
}; };
#define cb(_op, _) \
using _op##Lowering = \
TernaryOpLowering<jit::_op, jit::StandardOp<jit::_op>>;
MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
#undef cb
struct ReturnOpLowering : public ConversionPattern { struct ReturnOpLowering : public ConversionPattern {
ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op) ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
: ConversionPattern(jit::ReturnOp::getOperationName(), 1, ctx), : ConversionPattern(dialect::ReturnOp::getOperationName(), 1, ctx),
m_launch_op{launch_op} {} m_launch_op{launch_op} {}
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
...@@ -270,14 +202,14 @@ private: ...@@ -270,14 +202,14 @@ private:
}; };
struct ConstantScalarOpLowering struct ConstantScalarOpLowering
: public OpRewritePattern<jit::ConstantScalarOp> { : public OpRewritePattern<dialect::ConstantScalarOp> {
ConstantScalarOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op) ConstantScalarOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
: OpRewritePattern<jit::ConstantScalarOp>(ctx), : OpRewritePattern<dialect::ConstantScalarOp>(ctx),
m_launch_op{launch_op} {} m_launch_op{launch_op} {}
LogicalResult matchAndRewrite(jit::ConstantScalarOp op, LogicalResult matchAndRewrite(dialect::ConstantScalarOp op,
PatternRewriter& rewriter) const final { PatternRewriter& rewriter) const final {
ConstantScalarOpAdaptor constant_scalar_adaptor(op); dialect::ConstantScalarOpAdaptor constant_scalar_adaptor(op);
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front())); rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
rewriter.replaceOpWithNewOp<mlir::ConstantOp>( rewriter.replaceOpWithNewOp<mlir::ConstantOp>(
...@@ -291,7 +223,7 @@ private: ...@@ -291,7 +223,7 @@ private:
struct AssignOpLowering : public ConversionPattern { struct AssignOpLowering : public ConversionPattern {
AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op) AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
: ConversionPattern(jit::AssignOp::getOperationName(), 2, ctx), : ConversionPattern(dialect::AssignOp::getOperationName(), 2, ctx),
m_launch_op{launch_op} {} m_launch_op{launch_op} {}
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
...@@ -299,7 +231,7 @@ struct AssignOpLowering : public ConversionPattern { ...@@ -299,7 +231,7 @@ struct AssignOpLowering : public ConversionPattern {
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto loc = op->getLoc(); auto loc = op->getLoc();
AssignOpAdaptor assign_adaptor(operands); dialect::AssignOpAdaptor assign_adaptor(operands);
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front())); rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto dst_layout = output_layout(m_launch_op); auto dst_layout = output_layout(m_launch_op);
...@@ -343,14 +275,9 @@ public: ...@@ -343,14 +275,9 @@ public:
target.addLegalDialect<gpu::GPUDialect>(); target.addLegalDialect<gpu::GPUDialect>();
target.addIllegalDialect<MgbDialect>(); target.addIllegalDialect<MgbDialect>();
#define cb(_op, _) _op##Lowering, patterns.insert<ElemwiseLowering, TypeCvtLowering, ReturnOpLowering,
patterns.insert<MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(
cb) MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
ReturnOpLowering,
ConstantScalarOpLowering, AssignOpLowering>( ConstantScalarOpLowering, AssignOpLowering>(
&getContext(), launch_op); &getContext(), launch_op);
#undef cb
if (failed(applyPartialConversion(func_op, target, patterns))) { if (failed(applyPartialConversion(func_op, target, patterns))) {
signalPassFailure(); signalPassFailure();
......
...@@ -22,7 +22,7 @@ mlir::Value polynomial(ValueBuilderHelper& helper, mlir::Value x, ...@@ -22,7 +22,7 @@ mlir::Value polynomial(ValueBuilderHelper& helper, mlir::Value x,
std::vector<mlir::Value>& coeff) { std::vector<mlir::Value>& coeff) {
size_t n = coeff.size(); size_t n = coeff.size();
if (n == 0) { if (n == 0) {
return helper.const_val(0); return helper.const_f32(0);
} }
mlir::Value r = coeff[0]; mlir::Value r = coeff[0];
...@@ -40,23 +40,23 @@ mlir::Value atan2_approx(ValueBuilderHelper& helper, mlir::Value y, ...@@ -40,23 +40,23 @@ mlir::Value atan2_approx(ValueBuilderHelper& helper, mlir::Value y,
mlir::Value x) { mlir::Value x) {
auto atan_poly = [&](mlir::Value t) { auto atan_poly = [&](mlir::Value t) {
std::vector<mlir::Value> coeff = { std::vector<mlir::Value> coeff = {
helper.const_val(2.90188402868807315826416015625E-3), helper.const_f32(2.90188402868807315826416015625E-3),
helper.const_val(-1.62907354533672332763671875E-2), helper.const_f32(-1.62907354533672332763671875E-2),
helper.const_val(4.3082617223262786865234375E-2), helper.const_f32(4.3082617223262786865234375E-2),
helper.const_val(-7.5408883392810821533203125E-2), helper.const_f32(-7.5408883392810821533203125E-2),
helper.const_val(0.1066047251224517822265625), helper.const_f32(0.1066047251224517822265625),
helper.const_val(-0.14209578931331634521484375), helper.const_f32(-0.14209578931331634521484375),
helper.const_val(0.19993579387664794921875), helper.const_f32(0.19993579387664794921875),
helper.const_val(-0.3333314359188079833984375)}; helper.const_f32(-0.3333314359188079833984375)};
auto t2 = helper.mul(t, t); auto t2 = helper.mul(t, t);
auto p = polynomial(helper, t2, coeff); auto p = polynomial(helper, t2, coeff);
return helper.add(helper.mul(helper.mul(p, t2), t), t); return helper.add(helper.mul(helper.mul(p, t2), t), t);
}; };
// constants // constants
auto zero = helper.const_val(0); auto zero = helper.const_f32(0);
auto pi = helper.const_val(3.141592653589793); auto pi = helper.const_f32(3.141592653589793);
auto pi_over_2 = helper.const_val(1.570796326794897); auto pi_over_2 = helper.const_f32(1.570796326794897);
// transform the angle into interval [0, pi/4] // transform the angle into interval [0, pi/4]
auto ax = helper.abs(x); auto ax = helper.abs(x);
...@@ -83,23 +83,23 @@ mlir::Value atan2_approx(ValueBuilderHelper& helper, mlir::Value y, ...@@ -83,23 +83,23 @@ mlir::Value atan2_approx(ValueBuilderHelper& helper, mlir::Value y,
// original book: // original book:
// Numerical Recipes in Fortran 77: The Art of Scientific Computing // Numerical Recipes in Fortran 77: The Art of Scientific Computing
mlir::Value erf_approx(ValueBuilderHelper& helper, mlir::Value x) { mlir::Value erf_approx(ValueBuilderHelper& helper, mlir::Value x) {
auto zero = helper.const_val(0); auto zero = helper.const_f32(0);
auto one = helper.const_val(1); auto one = helper.const_f32(1);
auto half = helper.const_val(0.5); auto half = helper.const_f32(0.5);
auto t = helper.div(one, helper.add(one, helper.mul(half, helper.abs(x)))); auto t = helper.div(one, helper.add(one, helper.mul(half, helper.abs(x))));
std::vector<mlir::Value> coeff = { std::vector<mlir::Value> coeff = {
helper.const_val(0.17087277), helper.const_f32(0.17087277),
helper.const_val(-0.82215223), helper.const_f32(-0.82215223),
helper.const_val(1.48851587), helper.const_f32(1.48851587),
helper.const_val(-1.13520398), helper.const_f32(-1.13520398),
helper.const_val(0.27886807), helper.const_f32(0.27886807),
helper.const_val(-0.18628806), helper.const_f32(-0.18628806),
helper.const_val(0.09678418), helper.const_f32(0.09678418),
helper.const_val(0.37409196), helper.const_f32(0.37409196),
helper.const_val(1.00002368), helper.const_f32(1.00002368),
helper.const_val(-1.26551223)}; helper.const_f32(-1.26551223)};
auto p = polynomial(helper, t, coeff); auto p = polynomial(helper, t, coeff);
auto r = helper.mul(t, helper.exp(helper.sub(p, helper.mul(x, x)))); auto r = helper.mul(t, helper.exp(helper.sub(p, helper.mul(x, x))));
...@@ -130,25 +130,25 @@ mlir::Value ndtri_approx(ValueBuilderHelper& helper, mlir::Value x) { ...@@ -130,25 +130,25 @@ mlir::Value ndtri_approx(ValueBuilderHelper& helper, mlir::Value x) {
// polynomial P // polynomial P
auto P = [&](mlir::Value i, mlir::Value cond) { auto P = [&](mlir::Value i, mlir::Value cond) {
std::vector<mlir::Value> coeff0 = { std::vector<mlir::Value> coeff0 = {
helper.const_val(4.05544892305962419923E0), helper.const_f32(4.05544892305962419923E0),
helper.const_val(3.15251094599893866154E1), helper.const_f32(3.15251094599893866154E1),
helper.const_val(5.71628192246421288162E1), helper.const_f32(5.71628192246421288162E1),
helper.const_val(4.40805073893200834700E1), helper.const_f32(4.40805073893200834700E1),
helper.const_val(1.46849561928858024014E1), helper.const_f32(1.46849561928858024014E1),
helper.const_val(2.18663306850790267539E0), helper.const_f32(2.18663306850790267539E0),
helper.const_val(-1.40256079171354495875E-1), helper.const_f32(-1.40256079171354495875E-1),
helper.const_val(-3.50424626827848203418E-2), helper.const_f32(-3.50424626827848203418E-2),
helper.const_val(-8.57456785154685413611E-4)}; helper.const_f32(-8.57456785154685413611E-4)};
std::vector<mlir::Value> coeff1 = { std::vector<mlir::Value> coeff1 = {
helper.const_val(3.23774891776946035970E0), helper.const_f32(3.23774891776946035970E0),
helper.const_val(6.91522889068984211695E0), helper.const_f32(6.91522889068984211695E0),
helper.const_val(3.93881025292474443415E0), helper.const_f32(3.93881025292474443415E0),
helper.const_val(1.33303460815807542389E0), helper.const_f32(1.33303460815807542389E0),
helper.const_val(2.01485389549179081538E-1), helper.const_f32(2.01485389549179081538E-1),
helper.const_val(1.23716634817820021358E-2), helper.const_f32(1.23716634817820021358E-2),
helper.const_val(3.01581553508235416007E-4), helper.const_f32(3.01581553508235416007E-4),
helper.const_val(2.65806974686737550832E-6), helper.const_f32(2.65806974686737550832E-6),
helper.const_val(6.23974539184983293730E-9)}; helper.const_f32(6.23974539184983293730E-9)};
return helper.select(cond, return helper.select(cond,
polynomial(helper, i, coeff0), polynomial(helper, i, coeff0),
polynomial(helper, i, coeff1)); polynomial(helper, i, coeff1));
...@@ -157,25 +157,25 @@ mlir::Value ndtri_approx(ValueBuilderHelper& helper, mlir::Value x) { ...@@ -157,25 +157,25 @@ mlir::Value ndtri_approx(ValueBuilderHelper& helper, mlir::Value x) {
// polynomial Q // polynomial Q
auto Q = [&](mlir::Value i, mlir::Value cond) { auto Q = [&](mlir::Value i, mlir::Value cond) {
std::vector<mlir::Value> coeff0 = { std::vector<mlir::Value> coeff0 = {
helper.const_val(1.f), helper.const_f32(1.f),
helper.const_val(1.57799883256466749731E1), helper.const_f32(1.57799883256466749731E1),
helper.const_val(4.53907635128879210584E1), helper.const_f32(4.53907635128879210584E1),
helper.const_val(4.13172038254672030440E1), helper.const_f32(4.13172038254672030440E1),
helper.const_val(1.50425385692907503408E1), helper.const_f32(1.50425385692907503408E1),
helper.const_val(2.50464946208309415979E0), helper.const_f32(2.50464946208309415979E0),
helper.const_val(-1.42182922854787788574E-1), helper.const_f32(-1.42182922854787788574E-1),
helper.const_val(-3.80806407691578277194E-2), helper.const_f32(-3.80806407691578277194E-2),
helper.const_val(-9.33259480895457427372E-4)}; helper.const_f32(-9.33259480895457427372E-4)};
std::vector<mlir::Value> coeff1 = { std::vector<mlir::Value> coeff1 = {
helper.const_val(1.f), helper.const_f32(1.f),
helper.const_val(6.02427039364742014255E0), helper.const_f32(6.02427039364742014255E0),
helper.const_val(3.67983563856160859403E0), helper.const_f32(3.67983563856160859403E0),
helper.const_val(1.37702099489081330271E0), helper.const_f32(1.37702099489081330271E0),
helper.const_val(2.16236993594496635890E-1), helper.const_f32(2.16236993594496635890E-1),
helper.const_val(1.34204006088543189037E-2), helper.const_f32(1.34204006088543189037E-2),
helper.const_val(3.28014464682127739104E-4), helper.const_f32(3.28014464682127739104E-4),
helper.const_val(2.89247864745380683936E-6), helper.const_f32(2.89247864745380683936E-6),
helper.const_val(6.79019408009981274425E-9)}; helper.const_f32(6.79019408009981274425E-9)};
return helper.select(cond, return helper.select(cond,
polynomial(helper, i, coeff0), polynomial(helper, i, coeff0),
polynomial(helper, i, coeff1)); polynomial(helper, i, coeff1));
...@@ -184,37 +184,37 @@ mlir::Value ndtri_approx(ValueBuilderHelper& helper, mlir::Value x) { ...@@ -184,37 +184,37 @@ mlir::Value ndtri_approx(ValueBuilderHelper& helper, mlir::Value x) {
// polynomial R // polynomial R
auto R = [&](mlir::Value i) { auto R = [&](mlir::Value i) {
std::vector<mlir::Value> coeff = { std::vector<mlir::Value> coeff = {
helper.const_val(-5.99633501014107895267E1), helper.const_f32(-5.99633501014107895267E1),
helper.const_val(9.80010754185999661536E1), helper.const_f32(9.80010754185999661536E1),
helper.const_val(-5.66762857469070293439E1), helper.const_f32(-5.66762857469070293439E1),
helper.const_val(1.39312609387279679503E1), helper.const_f32(1.39312609387279679503E1),
helper.const_val(-1.23916583867381258016E0)}; helper.const_f32(-1.23916583867381258016E0)};
return polynomial(helper, i, coeff); return polynomial(helper, i, coeff);
}; };
// polynomial S // polynomial S
auto S = [&](mlir::Value i) { auto S = [&](mlir::Value i) {
std::vector<mlir::Value> coeff = { std::vector<mlir::Value> coeff = {
helper.const_val(1.f), helper.const_f32(1.f),
helper.const_val(1.95448858338141759834E0), helper.const_f32(1.95448858338141759834E0),
helper.const_val(4.67627912898881538453E0), helper.const_f32(4.67627912898881538453E0),
helper.const_val(8.63602421390890590575E1), helper.const_f32(8.63602421390890590575E1),
helper.const_val(-2.25462687854119370527E2), helper.const_f32(-2.25462687854119370527E2),
helper.const_val(2.00260212380060660359E2), helper.const_f32(2.00260212380060660359E2),
helper.const_val(-8.20372256168333339912E1), helper.const_f32(-8.20372256168333339912E1),
helper.const_val(1.59056225126211695515E1), helper.const_f32(1.59056225126211695515E1),
helper.const_val(-1.18331621121330003142E0)}; helper.const_f32(-1.18331621121330003142E0)};
return polynomial(helper, i, coeff); return polynomial(helper, i, coeff);
}; };
// constants // constants
auto zero = helper.const_val(0); auto zero = helper.const_f32(0);
auto one = helper.const_val(1); auto one = helper.const_f32(1);
auto half = helper.const_val(0.5); auto half = helper.const_f32(0.5);
auto eight = helper.const_val(8); auto eight = helper.const_f32(8);
auto minus_2 = helper.const_val(-2); auto minus_2 = helper.const_f32(-2);
auto exp_minus_2 = helper.const_val(0.135335283236); // exp(-2) auto exp_minus_2 = helper.const_f32(0.135335283236); // exp(-2)
auto sqrt_2pi = helper.const_val(2.506628274631); // sqrt(2pi) auto sqrt_2pi = helper.const_f32(2.506628274631); // sqrt(2pi)
// conditions // conditions
auto case1 = helper.lt(x, exp_minus_2); // x < exp(-2) auto case1 = helper.lt(x, exp_minus_2); // x < exp(-2)
......
/**
* \file src/jit/impl/mlir/ir/ops.td
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#ifndef MGB_MLIR_OPS
#define MGB_MLIR_OPS
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "./interfaces.td"
include "./predicates.td"
def Mgb_Dialect : Dialect {
let name = "mgb";
let cppNamespace = "mgb::jit";
}
class ElemwiseBuilderImpl {
code ElemwiseBuilderImpl_create = [{
static Operation* create(OpBuilder* builder, Location loc, ValueRange operands) {
OperationState state(loc, getOperationName());
state.addOperands(operands);
state.addTypes(getResultType(operands));
return builder->createOperation(state);
}
}];
}
class ElemwiseOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> :
Op<Mgb_Dialect, mnemonic, !listconcat(traits, [ElemwiseOpInterface,
GenericBuilderInterface])>, ElemwiseBuilderImpl;
class GenericOp<string mnemonic, list<OpTrait> traits = []> :
Op<Mgb_Dialect, mnemonic, traits>;
class ElemwiseUnaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> :
ElemwiseOp<mnemonic, traits> {
let arguments = (ins F32MemRef:$lhs);
let results = (outs F32MemRef);
let builders = [OpBuilder<
"Builder* builder, OperationState& result, ValueRange operands", [{
result.addOperands(operands);
result.addTypes(getResultType(operands));
}]>, OpBuilder <
"OpBuilder& builder, OperationState& result, Value lhs", [{
result.addOperands(lhs);
result.addTypes(getResultType({lhs}));
}]
>];
let extraClassDeclaration = [{
static Type getResultType(ValueRange operands) {
return deduce_result_type(operands);
}
}] # ElemwiseBuilderImpl_create;
}
def ReluOp : ElemwiseUnaryOp<"relu", [NoSideEffect]>;
def AbsOp : ElemwiseUnaryOp<"abs", [NoSideEffect]>;
def NegOp : ElemwiseUnaryOp<"negate", [NoSideEffect]>;
def AcosOp : ElemwiseUnaryOp<"acos", [NoSideEffect]>;
def AsinOp : ElemwiseUnaryOp<"asin", [NoSideEffect]>;
def CeilOp : ElemwiseUnaryOp<"ceil", [NoSideEffect]>;
def CosOp : ElemwiseUnaryOp<"cos", [NoSideEffect]>;
def ExpOp : ElemwiseUnaryOp<"exp", [NoSideEffect]>;
def ExpM1Op : ElemwiseUnaryOp<"expm1", [NoSideEffect]>;
def FloorOp : ElemwiseUnaryOp<"floor", [NoSideEffect]>;
def LogOp : ElemwiseUnaryOp<"log", [NoSideEffect]>;
def Log1POp : ElemwiseUnaryOp<"log1p", [NoSideEffect]>;
def SigmoidOp: ElemwiseUnaryOp<"sigmoid", [NoSideEffect]>;
def SinOp : ElemwiseUnaryOp<"sin", [NoSideEffect]>;
def TanhOp : ElemwiseUnaryOp<"tanh", [NoSideEffect]>;
def FastTanhOp : ElemwiseUnaryOp<"fast_tanh", [NoSideEffect]>;
def HswishOp : ElemwiseUnaryOp<"hswish", [NoSideEffect]>;
def RoundOp : ElemwiseUnaryOp<"round", [NoSideEffect]>;
def ErfOp : ElemwiseUnaryOp<"erf", [NoSideEffect]>;
def ErfInvOp : ElemwiseUnaryOp<"erfinv", [NoSideEffect]>;
def ErfCOp : ElemwiseUnaryOp<"erfc", [NoSideEffect]>;
def ErfCInvOp : ElemwiseUnaryOp<"erfcinv", [NoSideEffect]>;
class ElemwiseBinaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> :
ElemwiseOp<mnemonic, traits> {
let arguments = (ins ElemwiseFloatAny:$lhs, ElemwiseFloatAny:$rhs);
let results = (outs F32MemRef);
let builders = [OpBuilder<
"Builder* builder, OperationState& result, ValueRange operands", [{
result.addOperands(operands);
result.addTypes(getResultType(operands));
}]
>, OpBuilder <
"OpBuilder& builder, OperationState& result, Value lhs, Value rhs", [{
result.addOperands(lhs);
result.addOperands(rhs);
result.addTypes(getResultType({lhs, rhs}));
}]
>];
let extraClassDeclaration = [{
static Type getResultType(ValueRange operands) {
return deduce_result_type(operands);
}
}] # ElemwiseBuilderImpl_create;
}
def AbsGradOp : ElemwiseBinaryOp<"abs_grad", [NoSideEffect]>;
def AddOp : ElemwiseBinaryOp<"add", [Commutative, NoSideEffect]>;
def FloorDivOp : ElemwiseBinaryOp<"floor_div", [NoSideEffect]>;
def MaxOp : ElemwiseBinaryOp<"max", [Commutative, NoSideEffect]>;
def MinOp : ElemwiseBinaryOp<"min", [Commutative, NoSideEffect]>;
def ModOp : ElemwiseBinaryOp<"mod", [NoSideEffect]>;
def MulOp : ElemwiseBinaryOp<"mul", [Commutative, NoSideEffect]>;
def SubOp : ElemwiseBinaryOp<"sub", [NoSideEffect]>;
def SigmoidGradOp : ElemwiseBinaryOp<"sigmoid_grad", [NoSideEffect]>;
def SwishGt0Op : ElemwiseBinaryOp<"switch_gt0", [NoSideEffect]>;
def TanhGradOp : ElemwiseBinaryOp<"tanh_grad", [NoSideEffect]>;
def LtOp : ElemwiseBinaryOp<"lt", [NoSideEffect]>;
def LeqOp : ElemwiseBinaryOp<"leq", [NoSideEffect]>;
def EqOp : ElemwiseBinaryOp<"eq", [Commutative, NoSideEffect]>;
def FuseAddReluOp : ElemwiseBinaryOp<"fuse_add_relu", [NoSideEffect]>;
def TrueDivOp : ElemwiseBinaryOp<"true_div", [NoSideEffect]>;
def PowOp : ElemwiseBinaryOp<"pow", [NoSideEffect]>;
def LogSumExpOp : ElemwiseBinaryOp<"log_sum_exp", [Commutative, NoSideEffect]>;
def FuseAddTanhOp : ElemwiseBinaryOp<"fuse_add_tanh", [NoSideEffect]>;
def FastTanhGradOp : ElemwiseBinaryOp<"fast_tanh_grad", [NoSideEffect]>;
def FuseAddSigmoidOp : ElemwiseBinaryOp<"fuse_add_sigmoid", [NoSideEffect]>;
def HswishGradOp : ElemwiseBinaryOp<"hswish_grad", [NoSideEffect]>;
def FuseAddHswishOp : ElemwiseBinaryOp<"fuse_add_hswish", [NoSideEffect]>;
def Atan2Op : ElemwiseBinaryOp<"atan2", [NoSideEffect]>;
class ElemwiseTernaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> :
ElemwiseOp<mnemonic, traits> {
let arguments = (ins ElemwiseFloatAny:$x, ElemwiseFloatAny:$y, ElemwiseFloatAny:$z);
let results = (outs F32MemRef);
let builders = [OpBuilder<
"Builder* builder, OperationState& result, ValueRange operands", [{
result.addOperands(operands);
result.addTypes(getResultType(operands));
}]
>, OpBuilder <
"OpBuilder& builder, OperationState& result, Value x, Value y, Value z", [{
result.addOperands(x);
result.addOperands(y);
result.addOperands(z);
result.addTypes(getResultType({x, y, z}));
}]
>];
let extraClassDeclaration = [{
static Type getResultType(ValueRange operands) {
return deduce_result_type(operands);
}
}] # ElemwiseBuilderImpl_create;
}
def CondLeqMovOp: ElemwiseTernaryOp<"cond_leq_mov", [NoSideEffect]>;
def FuseMulAdd3Op: ElemwiseTernaryOp<"fuse_mul_add3", [NoSideEffect]>;
def ReturnOp : GenericOp<"return",
[NoSideEffect, HasParent<"FuncOp">, Terminator]> {
let summary = "return operation";
let description = [{
The "return" operation represents a return operation within a function.
The operation takes an no tensor operand and produces no results.
}];
// The return operation takes an optional input operand to return. This
// value must match the return type of the enclosing function.
let arguments = (ins);
// The return operation only emits the input in the format if it is present.
let assemblyFormat = "attr-dict";
}
def ConstantScalarOp: GenericOp<"sconst", [NoSideEffect]> {
let summary = "scalar constant";
let arguments = (ins AnyAttr:$value);
let results = (outs F32:$result);
let builders = [OpBuilder<
"Builder* builder, OperationState& result, float value", [{
result.addAttribute("value", builder->getF32FloatAttr(value));
result.addTypes(builder->getF32Type());
}]
>];
let extraClassDeclaration = [{
Attribute getValue() { return getAttr("value"); }
FloatAttr getFloatAttr() { return getAttrOfType<FloatAttr>("value"); }
}];
}
def AssignOp : GenericOp<"assign", []> {
let summary = "assign op";
let description = [{
assign rhs to lhs without results
}];
let arguments = (ins F32MemRef:$lhs, F32MemRef:$rhs);
}
#endif
/**
* \file src/jit/impl/mlir/ir/predicates.td
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#ifndef MGB_MLIR_PREDICATES
#define MGB_MLIR_PREDICATES
#ifndef OP_BASE
include "mlir/IR/OpBase.td"
#endif
def ElemwiseFloatAny : TypeConstraint<
CPred<"is_elemwise_float($_self)">, "elemwise-float">;
#endif
/**
* \file src/jit/impl/mlir/ir/types.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR
#include "./types.h"
#include "megbrain/common.h"
#include "megbrain/exception.h"
#include "megbrain/jit/mlir/ir/utils.h"
namespace mgb {
namespace jit {
mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type,
mlir::MLIRContext* ctx) {
switch (type.enumv()) {
case megdnn::DTypeEnum::Float32:
return mlir::FloatType::getF32(ctx);
case megdnn::DTypeEnum::Uint8:
return mlir::IntegerType::get(8, ctx);
case megdnn::DTypeEnum::Int8:
return mlir::IntegerType::get(8, ctx);
case megdnn::DTypeEnum::Int16:
return mlir::IntegerType::get(16, ctx);
case megdnn::DTypeEnum::Int32:
return mlir::IntegerType::get(32, ctx);
case megdnn::DTypeEnum::IntB1:
return mlir::IntegerType::get(1, ctx);
case megdnn::DTypeEnum::IntB2:
return mlir::IntegerType::get(2, ctx);
case megdnn::DTypeEnum::IntB4:
return mlir::IntegerType::get(4, ctx);
case megdnn::DTypeEnum::Byte:
return mlir::IntegerType::get(8, ctx);
case megdnn::DTypeEnum::Float16:
return mlir::FloatType::getF16(ctx);
case megdnn::DTypeEnum::UintB4:
return mlir::IntegerType::get(4, ctx);
case megdnn::DTypeEnum::BFloat16:
return mlir::FloatType::getBF16(ctx);
case megdnn::DTypeEnum::Bool:
return mlir::IntegerType::get(1, ctx);
default:
mgb_throw(InternalError, "Unsupported MegDNN dtype: %s",
type.name());
}
}
megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type) {
mlir::Type element_type = type;
if (auto cast = type.dyn_cast_or_null<mlir::MemRefType>()) {
element_type = cast.getElementType();
}
megdnn::DTypeEnum enumv;
if (element_type.isF32()) {
enumv = megdnn::DTypeEnum::Float32;
} else if (element_type.isSignlessInteger(1)) {
enumv = megdnn::DTypeEnum::IntB1;
} else if (element_type.isSignlessInteger(2)) {
enumv = megdnn::DTypeEnum::IntB2;
} else if (element_type.isSignlessInteger(4)) {
enumv = megdnn::DTypeEnum::IntB4;
} else if (element_type.isSignlessInteger(8)) {
enumv = megdnn::DTypeEnum::Int8;
} else if (element_type.isSignlessInteger(16)) {
enumv = megdnn::DTypeEnum::Int16;
} else if (element_type.isSignlessInteger(32)) {
enumv = megdnn::DTypeEnum::Int32;
} else if (element_type.isF16()) {
enumv = megdnn::DTypeEnum::Float16;
} else if (element_type.isBF16()) {
enumv = megdnn::DTypeEnum::BFloat16;
} else if (element_type.isSignlessInteger(1)) {
enumv = megdnn::DTypeEnum::Bool;
} else {
mgb_throw(InternalError, "Unsupported MLIR Type: %s",
mlir_type_to_string(element_type).c_str());
}
return megdnn::DType::from_enum(enumv);
}
bool is_signed_int_dtype(megdnn::DType type) {
auto enumv = type.enumv();
return enumv == megdnn::DTypeEnum::Int8 or
enumv == megdnn::DTypeEnum::Int16 or
enumv == megdnn::DTypeEnum::Int32 or
enumv == megdnn::DTypeEnum::IntB1 or
enumv == megdnn::DTypeEnum::IntB2 or
enumv == megdnn::DTypeEnum::IntB4;
}
bool is_unsigned_int_dtype(megdnn::DType type) {
auto enumv = type.enumv();
return enumv == megdnn::DTypeEnum::Uint8 or
enumv == megdnn::DTypeEnum::UintB4;
}
} // namespace jit
} // namespace mgb
#endif // MGB_JIT && MGB_JIT_MLIR
// vim: syntax=cpp.doxygen
...@@ -14,22 +14,33 @@ ...@@ -14,22 +14,33 @@
#include "megbrain_build_config.h" #include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR #if MGB_JIT && MGB_JIT_MLIR
#include "megdnn/dtype.h"
#include <mlir/IR/StandardTypes.h> #include <mlir/IR/StandardTypes.h>
namespace mgb { namespace mgb {
namespace jit { namespace jit {
inline bool is_elemwise_float(const mlir::Type& dt) { #define FOR_EACH_DNN_DTYPE(cb) \
if (auto cast = dt.dyn_cast_or_null<mlir::MemRefType>()) { cb(Float32, dt_float32); \
if (cast.getElementType().isF32()) { cb(Uint8, dt_uint8); \
return true; cb(Int8, dt_int8); \
} cb(Int16, dt_int16); \
} cb(Int32, dt_int32); \
if (dt.isa<mlir::FloatType>()) { cb(Byte, dt_byte); \
return true; MEGDNN_INC_FLOAT16(cb(Float16, dt_float16)); \
} MEGDNN_INC_FLOAT16(cb(BFloat16, dt_bfloat16)); \
return false; cb(Bool, dt_bool);
}
mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type,
mlir::MLIRContext* ctx);
megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type);
bool is_signed_int_dtype(megdnn::DType type);
bool is_unsigned_int_dtype(megdnn::DType type);
} // namespace jit } // namespace jit
} // namespace mgb } // namespace mgb
......
...@@ -13,11 +13,14 @@ ...@@ -13,11 +13,14 @@
#include "megbrain_build_config.h" #include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR #if MGB_JIT && MGB_JIT_MLIR
#include "megbrain/jit/mlir/ir/utils.h"
#include "./types.h"
#include "megbrain/common.h" #include "megbrain/common.h"
#include "megbrain/exception.h" #include "megbrain/exception.h"
#include "megbrain/jit/mlir/ir/utils.h"
#include "megdnn/oprs/general.h"
#include "megdnn/basic_types.h" #include "megdnn/basic_types.h"
#include "megdnn/oprs/general.h"
#include <mlir/Dialect/Affine/IR/AffineOps.h> #include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/IR/Builders.h> #include <mlir/IR/Builders.h>
...@@ -44,7 +47,7 @@ mlir::Value jit::insert_alloc_and_dealloc(mlir::MemRefType type, ...@@ -44,7 +47,7 @@ mlir::Value jit::insert_alloc_and_dealloc(mlir::MemRefType type,
return alloc; return alloc;
} }
mlir::Type jit::deduce_result_type(mlir::ValueRange operands) { mlir::Type jit::deduce_elemwise_res_type(mlir::ValueRange operands) {
megdnn::TensorShapeArray srcs; megdnn::TensorShapeArray srcs;
megdnn::TensorShape dst; megdnn::TensorShape dst;
megdnn::DType dst_type; megdnn::DType dst_type;
...@@ -59,8 +62,8 @@ mlir::Type jit::deduce_result_type(mlir::ValueRange operands) { ...@@ -59,8 +62,8 @@ mlir::Type jit::deduce_result_type(mlir::ValueRange operands) {
} }
megdnn::Elemwise::deduce_shape(srcs, dst); megdnn::Elemwise::deduce_shape(srcs, dst);
mlir::Builder builder(operands[0].getContext()); mlir::Builder builder(operands[0].getContext());
return layout_to_mlir_type({dst, mlir_type_to_dtype(operands[0].getType())}, return layout_to_mlir_type(
builder); {dst, mlir_type_to_megdnn_dtype(operands[0].getType())}, builder);
} }
megdnn::TensorLayout jit::mlir_type_to_layout(mlir::Type type) { megdnn::TensorLayout jit::mlir_type_to_layout(mlir::Type type) {
...@@ -72,41 +75,21 @@ megdnn::TensorLayout jit::mlir_type_to_layout(mlir::Type type) { ...@@ -72,41 +75,21 @@ megdnn::TensorLayout jit::mlir_type_to_layout(mlir::Type type) {
for (size_t i = 0; i < ret.ndim; i++) { for (size_t i = 0; i < ret.ndim; i++) {
ret.shape[i] = real_type.getDimSize(i); ret.shape[i] = real_type.getDimSize(i);
} }
ret.dtype = mlir_type_to_dtype(real_type.getElementType()); ret.dtype = mlir_type_to_megdnn_dtype(real_type.getElementType());
} }
return ret; return ret;
} }
megdnn::DType jit::mlir_type_to_dtype(mlir::Type type) {
mlir::Type element_type = type;
if (auto cast = type.dyn_cast_or_null<mlir::MemRefType>()) {
element_type = cast.getElementType();
}
if (element_type.isF32()) {
return megdnn::dtype::Float32{};
} else {
mgb_throw(InternalError,
"Unsupport mlir type for MemRefType, got: %s\n",
mlir_type_to_string(type).c_str());
}
return {};
}
mlir::MemRefType jit::layout_to_mlir_type(const megdnn::TensorLayout& layout, mlir::MemRefType jit::layout_to_mlir_type(const megdnn::TensorLayout& layout,
mlir::Builder& builder) { mlir::Builder& builder) {
std::vector<int64_t> shape; std::vector<int64_t> shape;
for (size_t i = 0; i < layout.ndim; i++) { for (size_t i = 0; i < layout.ndim; i++) {
shape.push_back(layout[i]); shape.push_back(layout[i]);
} }
switch (layout.dtype.enumv()) { mlir::Type type = megdnn_dtype_to_mlir_type(layout.dtype, builder.getContext());
case megdnn::DTypeEnum::Float32: return mlir::MemRefType::get(shape, type);
return mlir::MemRefType::get(shape, builder.getF32Type());
default:
mgb_throw(InternalError, "No supported dtype: %s",
layout.dtype.name());
}
} }
#endif // MGB_JIT_MLIR #endif // MGB_JIT && MGB_JIT_MLIR
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "./mlir_gen.h" #include "./mlir_gen.h"
#include "./ir/each_mode.h" #include "./ir/each_mode.h"
#include "./ir/types.h"
#include "megbrain/jit/mlir/ir/dialect.h" #include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/utils.h" #include "megbrain/jit/mlir/ir/utils.h"
...@@ -116,9 +117,9 @@ private: ...@@ -116,9 +117,9 @@ private:
return nullptr; return nullptr;
} }
jit::ReturnOp return_op; dialect::ReturnOp return_op;
if (!return_op) { if (!return_op) {
m_builder.create<jit::ReturnOp>(m_builder.getUnknownLoc()); m_builder.create<dialect::ReturnOp>(m_builder.getUnknownLoc());
} }
std::string op_content = mlir_type_to_string(func_op); std::string op_content = mlir_type_to_string(func_op);
func_op.setName( func_op.setName(
...@@ -135,9 +136,7 @@ private: ...@@ -135,9 +136,7 @@ private:
cg::DepOprIter{[&](cg::OperatorNodeBase* opr) { cg::DepOprIter{[&](cg::OperatorNodeBase* opr) {
if (opr->same_type<JITPlaceholder>()) { if (opr->same_type<JITPlaceholder>()) {
return; return;
} } else if (opr->same_type<opr::ImmutableTensor>()) {
if (opr->same_type<opr::ImmutableTensor>()) {
auto imm = SymbolVar{opr->output(0)}.as_immutable_scalar(); auto imm = SymbolVar{opr->output(0)}.as_immutable_scalar();
if (imm.valid()) { if (imm.valid()) {
auto dtype = imm->dtype(); auto dtype = imm->dtype();
...@@ -150,59 +149,53 @@ private: ...@@ -150,59 +149,53 @@ private:
"dtype, but got %s", "dtype, but got %s",
dtype.name()); dtype.name());
} }
auto&& out = m_builder.create<jit::ConstantScalarOp>( auto&& out = m_builder.create<dialect::ConstantScalarOp>(
m_builder.getUnknownLoc(), m_builder.getF32Type(), m_builder.getUnknownLoc(), m_builder.getF32Type(),
m_builder.getF32FloatAttr(scalar_value)); m_builder.getF32FloatAttr(scalar_value));
mgb_assert(mlir::succeeded( mgb_assert(mlir::succeeded(
declare(opr->output(0)->name(), out))); declare(opr->output(0)->name(), out)));
} }
} } else if (opr->same_type<opr::Elemwise>()) {
auto&& out = gen_elemwise(opr->cast_final<opr::Elemwise>());
if (opr->same_type<opr::Elemwise>()) { mgb_assert(
auto&& out = gen_op(opr->cast_final<opr::Elemwise>()); mlir::succeeded(declare(opr->output(0)->name(), out)));
return;
} else if (opr->same_type<opr::TypeCvt>()) {
auto&& out = gen_typecvt(opr->cast_final<opr::TypeCvt>());
mgb_assert( mgb_assert(
mlir::succeeded(declare(opr->output(0)->name(), out))); mlir::succeeded(declare(opr->output(0)->name(), out)));
} }
}} }}
.add(internal_graph.output()); .add(internal_graph.output());
m_builder.create<AssignOp>(m_builder.getUnknownLoc(), m_builder.create<dialect::AssignOp>(m_builder.getUnknownLoc(),
get(internal_graph.output()), get(internal_graph.output()),
get(args.outputs[0].from)); get(args.outputs[0].from));
return mlir::success(); return mlir::success();
} }
mlir::Value gen_op(const opr::Elemwise& opr) { mlir::Value gen_elemwise(const opr::Elemwise& opr) {
switch (opr.param().mode) { llvm::SmallVector<mlir::Value, 4> operands;
#define cb(mlir_op, mgb_mode) \ for (size_t i = 0; i < opr.input().size(); i++) {
case opr::Elemwise::Mode::mgb_mode: \ operands.push_back(get(opr.input(i)));
return m_builder.create<jit::mlir_op>(m_builder.getUnknownLoc(), \
get(opr.input(0)), \
get(opr.input(1))); \
break;
MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
#undef cb
#define cb(mlir_op, mgb_mode) \
case opr::Elemwise::Mode::mgb_mode: \
return m_builder.create<jit::mlir_op>(m_builder.getUnknownLoc(), \
get(opr.input(0))); \
break;
MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)
#undef cb
#define cb(mlir_op, mgb_mode) \
case opr::Elemwise::Mode::mgb_mode: \
return m_builder.create<jit::mlir_op>( \
m_builder.getUnknownLoc(), get(opr.input(0)), \
get(opr.input(1)), get(opr.input(2))); \
break;
MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
#undef cb
default:
return nullptr;
} }
return nullptr; mlir::Type res_type = deduce_elemwise_res_type(operands);
return m_builder.create<dialect::Elemwise>(
m_builder.getUnknownLoc(), res_type, mlir::ValueRange(operands),
opr.param().mode);
}
mlir::Value gen_typecvt(const opr::TypeCvt& opr) {
auto shape = get(opr.input(0))
.getType()
.dyn_cast_or_null<mlir::MemRefType>()
.getShape();
auto res_type = mlir::MemRefType::get(
shape,
megdnn_dtype_to_mlir_type(opr.param(), m_builder.getContext()));
return m_builder.create<dialect::TypeCvt>(
m_builder.getUnknownLoc(), res_type, get(opr.input(0)),
opr.input(0)->dtype(), opr.param());
} }
mlir::Type get_type(const TensorLayout& layout) { mlir::Type get_type(const TensorLayout& layout) {
......
# mgb_dialect
set(LLVM_TARGET_DEFINITIONS mgb_dialect.td)
tablegen(MLIR mgb_dialect.h.inc ${MGE_IR_INCLUDE_DIRS} "--gen-op-decls")
tablegen(MLIR mgb_dialect.cpp.inc ${MGE_IR_INCLUDE_DIRS} "--gen-op-defs")
add_custom_target(mgb_dialect DEPENDS mgb_dialect.h.inc mgb_dialect.cpp.inc)
add_dependencies(mgb_dialect param_defs_tblgen)
/** /**
* \file src/jit/impl/mlir/ir/dialect.h * \file src/jit/include/megbrain/jit/mlir/ir/dialect.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...@@ -15,8 +15,7 @@ ...@@ -15,8 +15,7 @@
#include "megbrain_build_config.h" #include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR #if MGB_JIT && MGB_JIT_MLIR
#include "megbrain/jit/mlir/ir/interfaces.h" #include "megdnn/opr_param_defs.h"
#include "megbrain/jit/mlir/ir/utils.h"
#include <mlir/IR/Dialect.h> #include <mlir/IR/Dialect.h>
#include <mlir/IR/Function.h> #include <mlir/IR/Function.h>
...@@ -39,7 +38,7 @@ public: ...@@ -39,7 +38,7 @@ public:
#define GET_OP_CLASSES #define GET_OP_CLASSES
using namespace mlir; using namespace mlir;
#include "megbrain/jit/mlir/ir/ops.h.inc" #include "megbrain/jit/mlir/ir/mgb_dialect.h.inc"
#endif // MGB_JIT && MGB_JIT_MLIR #endif // MGB_JIT && MGB_JIT_MLIR
......
/**
* \file src/jit/include/mlir/ir/interfaces.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megbrain_build_config.h"
#if MGB_JIT_MLIR
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/Types.h>
namespace mlir {
/// Include the auto-generated declarations.
#include "megbrain/jit/mlir/ir/interfaces.h.inc"
}
#endif // MGB_JIT_MLIR
// vim: syntax=cpp.doxygen
/** /**
* \file src/jit/impl/mlir/ir/passes.h * \file src/jit/include/megbrain/jit/mlir/ir/passes.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...@@ -11,8 +11,8 @@ ...@@ -11,8 +11,8 @@
*/ */
#pragma once #pragma once
#include "megbrain_build_config.h"
#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR #if MGB_JIT && MGB_JIT_MLIR
#include <memory> #include <memory>
......
/** /**
* \file src/jit/include/megbrain/mlir/ir/utils.h * \file src/jit/include/megbrain/jit/mlir/ir/utils.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...@@ -35,15 +35,19 @@ std::string mlir_type_to_string(T&& t) { ...@@ -35,15 +35,19 @@ std::string mlir_type_to_string(T&& t) {
mlir::Value insert_alloc_and_dealloc(mlir::MemRefType type, mlir::Location loc, mlir::Value insert_alloc_and_dealloc(mlir::MemRefType type, mlir::Location loc,
mlir::PatternRewriter& rewriter); mlir::PatternRewriter& rewriter);
mlir::Type deduce_result_type(mlir::ValueRange operands); mlir::Type deduce_elemwise_res_type(mlir::ValueRange operands);
/** /**
* \brief convert mlir type to TensorShape * \brief convert MLIR Type to TensorLayout
*/ */
megdnn::TensorLayout mlir_type_to_layout(mlir::Type type); megdnn::TensorLayout mlir_type_to_layout(mlir::Type type);
megdnn::DType mlir_type_to_dtype(mlir::Type type);
/**
* \brief convert TensorLayout to MLIR Type
*/
mlir::MemRefType layout_to_mlir_type(const megdnn::TensorLayout& layout, mlir::MemRefType layout_to_mlir_type(const megdnn::TensorLayout& layout,
mlir::Builder& builder); mlir::Builder& builder);
} // namespace jit } // namespace jit
} // namespace mgb } // namespace mgb
......
...@@ -267,6 +267,8 @@ void run_mlir_mode(CompNode cn) { ...@@ -267,6 +267,8 @@ void run_mlir_mode(CompNode cn) {
} // anonymous namespace } // anonymous namespace
/* ===================== TestJITHalideCodeGenCude ===================== */
#if MGB_JIT_HALIDE #if MGB_JIT_HALIDE
template <typename tag> template <typename tag>
class TestJITHalideCodeGenCuda : public ::testing::Test {}; class TestJITHalideCodeGenCuda : public ::testing::Test {};
...@@ -277,6 +279,8 @@ TYPED_TEST(TestJITHalideCodeGenCuda, run) { ...@@ -277,6 +279,8 @@ TYPED_TEST(TestJITHalideCodeGenCuda, run) {
} }
#endif #endif
/* ===================== TestJITNvrtcCodeGen ===================== */
template <typename tag> template <typename tag>
class TestJITNvrtcCodeGen : public ::testing::Test {}; class TestJITNvrtcCodeGen : public ::testing::Test {};
TYPED_TEST_CASE(TestJITNvrtcCodeGen, test_types); TYPED_TEST_CASE(TestJITNvrtcCodeGen, test_types);
...@@ -285,6 +289,8 @@ TYPED_TEST(TestJITNvrtcCodeGen, run) { ...@@ -285,6 +289,8 @@ TYPED_TEST(TestJITNvrtcCodeGen, run) {
run<TypeParam>(Backend::NVRTC, CompNode::load("gpu0")); run<TypeParam>(Backend::NVRTC, CompNode::load("gpu0"));
} }
/* ===================== TestJITMlirCodeGen ===================== */
#if MGB_JIT_MLIR #if MGB_JIT_MLIR
TEST(TestJITMlirCodeGen, Basic) { TEST(TestJITMlirCodeGen, Basic) {
auto cn = CompNode::load("cpu0"); auto cn = CompNode::load("cpu0");
...@@ -299,7 +305,8 @@ TEST(TestJITMlirCodeGen, BasicGPU) { ...@@ -299,7 +305,8 @@ TEST(TestJITMlirCodeGen, BasicGPU) {
run_mlir_broadcast(cn); run_mlir_broadcast(cn);
} }
///////////////////////// unary /////////////////////////////// /* ===================== TestJITMlirUnaryElemwise ===================== */
// clang-format off // clang-format off
#define FOREACH_UNARY_MODE(cb) \ #define FOREACH_UNARY_MODE(cb) \
cb(RELU) \ cb(RELU) \
...@@ -365,7 +372,8 @@ TYPED_TEST(TestJITMlirUnaryElemwise, runGpu) { ...@@ -365,7 +372,8 @@ TYPED_TEST(TestJITMlirUnaryElemwise, runGpu) {
run_mlir_mode<TypeParam, 1>(cn); run_mlir_mode<TypeParam, 1>(cn);
} }
///////////////////////// binary /////////////////////////////// /* ===================== TestJITMlirBinaryElemwise ===================== */
// clang-format off // clang-format off
#define FOREACH_BINARY_MODE(cb) \ #define FOREACH_BINARY_MODE(cb) \
cb(ADD) \ cb(ADD) \
...@@ -422,7 +430,8 @@ TYPED_TEST(TestJITMlirBinaryElemwise, runGpu) { ...@@ -422,7 +430,8 @@ TYPED_TEST(TestJITMlirBinaryElemwise, runGpu) {
run_mlir_mode<TypeParam, 2>(cn); run_mlir_mode<TypeParam, 2>(cn);
} }
///////////////////////// ternary /////////////////////////////// /* ===================== TestJITMlirTenaryElemwise ===================== */
// clang-format off // clang-format off
#define FOREACH_TERNARY_MODE(cb) \ #define FOREACH_TERNARY_MODE(cb) \
cb(COND_LEQ_MOV) \ cb(COND_LEQ_MOV) \
...@@ -456,6 +465,81 @@ TYPED_TEST(TestJITMlirTernaryElemwise, runGpu) { ...@@ -456,6 +465,81 @@ TYPED_TEST(TestJITMlirTernaryElemwise, runGpu) {
#undef SKIP_MODE #undef SKIP_MODE
/* ===================== TestJITMlirTypeCvt ===================== */
template <typename itype, typename otype>
void run_typecvt(CompNode cn) {
set_backend(Backend::MLIR);
auto graph = ComputingGraph::make();
HostTensorGenerator<itype, RandomDistribution::UNIFORM> gen(-10, 10);
auto host_x = gen({23, 42}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
auto y = opr::TypeCvt::make(x, otype());
auto ig_gen = std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
for (auto i : get_rev_topo_order(y)) {
if (!i->template same_type<opr::Host2DeviceCopy>()) {
ig_gen->add_opr(i);
}
}
auto igraph = ig_gen->generate();
auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
HostTensorND host_y, host_y_jit;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_jit, host_y_jit)});
func->execute();
MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
};
#define add_typecvt_gtest(itype, otype) \
TEST(TestJITMlirTypeCvt, itype##_to_##otype) { \
run_typecvt<dtype::itype, dtype::otype>(CompNode::load("cpu0")); \
} \
TEST(TestJITMlirTypeCvt, itype##_to_##otype##_GPU) { \
REQUIRE_GPU(1); \
run_typecvt<dtype::itype, dtype::otype>(CompNode::load("gpu0")); \
}
#if !MEGDNN_DISABLE_FLOAT16
// TODO: the support for f16 and bf16 is currently not complete in mlir
// FPExtOp
// add_typecvt_gtest(Float16, Float32);
// add_typecvt_gtest(BFloat16, Float32);
// add_typecvt_gtest(Float16, BFloat16);
// FPTruncOp
// add_typecvt_gtest(Float32, Float16);
// add_typecvt_gtest(Float32, BFloat16);
// add_typecvt_gtest(Float16, BFloat16);
#endif
// FPToSIOp
add_typecvt_gtest(Float32, Int8);
add_typecvt_gtest(Float32, Int16);
add_typecvt_gtest(Float32, Int32);
// FPToUIOp
add_typecvt_gtest(Float32, Uint8);
// SIToFPOp
add_typecvt_gtest(Int8, Float32);
add_typecvt_gtest(Int16, Float32);
add_typecvt_gtest(Int32, Float32);
// UIToFPOp
add_typecvt_gtest(Uint8, Float32);
#undef add_typecvt_gtest
#endif // MGB_JIT_MLIR #endif // MGB_JIT_MLIR
#endif // MGB_JIT #endif // MGB_JIT
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// RUN: mgb-opt --mgb-convert-to-affine --mgb-codegen-convert-affine-to-llvm --split-input-file -canonicalize -cse %s // RUN: mgb-opt --mgb-convert-to-affine --mgb-codegen-convert-affine-to-llvm --split-input-file -canonicalize -cse %s
func @add_dim1(%lhs: memref<2xf32>, %rhs: memref<2xf32>, %res: memref<2xf32>) -> () { func @add_dim1(%lhs: memref<2xf32>, %rhs: memref<2xf32>, %res: memref<2xf32>) -> () {
%0 = "mgb.add"(%lhs, %rhs) {name = "add.f"} : %0 = "mgb.Elemwise"(%lhs, %rhs) {name = "add.f", mode = 16 : i32} :
(memref<2xf32>, memref<2xf32>) -> memref<2xf32> (memref<2xf32>, memref<2xf32>) -> memref<2xf32>
"mgb.assign"(%0, %res) : (memref<2xf32>, memref<2xf32>) -> () "mgb.assign"(%0, %res) : (memref<2xf32>, memref<2xf32>) -> ()
mgb.return mgb.return
...@@ -24,7 +24,7 @@ func @add_dim1(%lhs: memref<2xf32>, %rhs: memref<2xf32>, %res: memref<2xf32>) -> ...@@ -24,7 +24,7 @@ func @add_dim1(%lhs: memref<2xf32>, %rhs: memref<2xf32>, %res: memref<2xf32>) ->
// CHECK: } // CHECK: }
func @add_dim4(%lhs: memref<4x3x64x64xf32>, %rhs: memref<4x3x64x64xf32>, %res: memref<4x3x64x64xf32>) -> () { func @add_dim4(%lhs: memref<4x3x64x64xf32>, %rhs: memref<4x3x64x64xf32>, %res: memref<4x3x64x64xf32>) -> () {
%0 = "mgb.add"(%lhs, %rhs) {name = "add.f"} : %0 = "mgb.Elemwise"(%lhs, %rhs) {name = "add.f", mode = 16 : i32} :
(memref<4x3x64x64xf32>, memref<4x3x64x64xf32>) -> memref<4x3x64x64xf32> (memref<4x3x64x64xf32>, memref<4x3x64x64xf32>) -> memref<4x3x64x64xf32>
"mgb.assign"(%0, %res) : (memref<4x3x64x64xf32>, memref<4x3x64x64xf32>) -> () "mgb.assign"(%0, %res) : (memref<4x3x64x64xf32>, memref<4x3x64x64xf32>) -> ()
mgb.return mgb.return
...@@ -55,4 +55,4 @@ func @add_dim4(%lhs: memref<4x3x64x64xf32>, %rhs: memref<4x3x64x64xf32>, %res: m ...@@ -55,4 +55,4 @@ func @add_dim4(%lhs: memref<4x3x64x64xf32>, %rhs: memref<4x3x64x64xf32>, %res: m
// CHECK: } // CHECK: }
// CHECK: dealloc %0 : memref<4x3x64x64xf32> // CHECK: dealloc %0 : memref<4x3x64x64xf32>
// CHECK: return // CHECK: return
// CHECK: } // CHECK: }
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册