From 9682db982084900084bb726627fa2a2ec6c06a23 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Sat, 10 Oct 2020 14:28:10 +0800
Subject: [PATCH] feat(mgb): add jit mlir elemwise broadcast

GitOrigin-RevId: 89d5e2f91eab46bc66fea014cf9170e49b5dfc4e
---
 src/jit/impl/fusion_pass.cpp                  | 16 ----
 src/jit/impl/mlir/executable_cuda.cpp         | 14 +--
 src/jit/impl/mlir/ir/common.cpp               | 47 ++++++++++
 src/jit/impl/mlir/ir/common.h                 | 14 ++-
 src/jit/impl/mlir/ir/lower_to_affine_pass.cpp | 49 ++++++----
 src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp    | 90 ++++++++++++++++---
 src/jit/impl/mlir/ir/utils.cpp                |  1 -
 src/jit/impl/mlir/mlir_gen.cpp                |  4 +-
 src/jit/include/megbrain/jit/mlir/ir/utils.h  |  1 -
 src/jit/test/codegen.cpp                      | 43 ++++++++-
 src/jit/test/fusion.cpp                       |  4 +-
 11 files changed, 219 insertions(+), 64 deletions(-)

diff --git a/src/jit/impl/fusion_pass.cpp b/src/jit/impl/fusion_pass.cpp
index 00ea7196..6355a9be 100644
--- a/src/jit/impl/fusion_pass.cpp
+++ b/src/jit/impl/fusion_pass.cpp
@@ -294,22 +294,6 @@ void JITFusionPass::Impl::process_opr(OperatorNodeBase* opr) {
          cond_nr_inp = ig_gen->get_cnt_input_if_add(opr) <= max_nr_input,
          cond_mlir_specific = true;
 
-#if MGB_JIT_MLIR
-    //! FIXME mlir does't support broadcast currently.
-    auto backend = MGB_GETENV("MGB_JIT_BACKEND");
-    if (backend && !strcmp(backend, "MLIR")) {
-        for (VarNode* var : opr->input()) {
-            if (!SymbolVar{var}.as_immutable_scalar().valid()) {
-                if (opr->node_prop().dep_map().at(var) &
-                    DepType::DEV_VALUE) {
-                    if (!var->shape().eq_shape(opr->output(0)->shape())) {
-                        cond_mlir_specific = false;
-                    }
-                }
-            }
-        }
-    }
-#endif
     if (cond_readers && cond_cn && cond_shp && cond_nr_inp &&
         cond_mlir_specific) {
         ig_gen->add_opr(opr);
diff --git a/src/jit/impl/mlir/executable_cuda.cpp b/src/jit/impl/mlir/executable_cuda.cpp
index 6448ca6b..e7c1845f 100644
--- a/src/jit/impl/mlir/executable_cuda.cpp
+++ b/src/jit/impl/mlir/executable_cuda.cpp
@@ -57,23 +57,23 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func,
         }
     };
     for (const auto& arg : args.inputs) {
-        set_params(arg.from->dev_tensor().raw_ptr(), arg.layout);
+        set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout());
    }
 
     int64_t nr_elements = 0;
     for (const auto& arg : args.outputs) {
         if (nr_elements == 0) {
-            nr_elements = arg.layout.total_nr_elems();
+            nr_elements = arg.from->layout().total_nr_elems();
         } else {
             mgb_assert(static_cast<size_t>(nr_elements) ==
                                arg.layout.total_nr_elems(),
                        "The number of elements of outputs mismatch, expected: "
                        "%zu got: %zu(%s)",
                        static_cast<size_t>(nr_elements),
-                       arg.layout.total_nr_elems(),
-                       arg.layout.to_string().c_str());
+                       arg.from->layout().total_nr_elems(),
+                       arg.from->layout().to_string().c_str());
         }
-        set_params(arg.from->dev_tensor().raw_ptr(), arg.layout);
+        set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout());
     }
     const CompNodeEnv& env =
             CompNodeEnv::from_comp_node(fusion_opr->comp_node());
@@ -134,8 +134,8 @@ void MLIRCUDAExecutable::FuncCache::exec(const JITExecutor* fusion_opr,
     mgb_assert(fusion_opr->args().outputs.size() == 1,
                "Currently only support 1 outputs, got %zu",
                fusion_opr->args().outputs.size());
-    int out_dim = fusion_opr->args().outputs[0].layout.ndim;
-    DType dtype = fusion_opr->args().outputs[0].layout.dtype;
+    int out_dim = fusion_opr->args().outputs[0].from->layout().ndim;
+    DType dtype = fusion_opr->args().outputs[0].from->layout().dtype;
 #define cb_outdim(_ndim, _dtype)                                        \
     if (_ndim == out_dim) {                                             \
         setup_and_launch<_ndim, _dtype>(fusion_opr, func->func,         \
diff --git a/src/jit/impl/mlir/ir/common.cpp b/src/jit/impl/mlir/ir/common.cpp
index 58d8c0a0..085bbff4 100644
--- a/src/jit/impl/mlir/ir/common.cpp
+++ b/src/jit/impl/mlir/ir/common.cpp
@@ -14,8 +14,10 @@
 #if MGB_JIT && MGB_JIT_MLIR
 
 #include "./common.h"
+#include "megbrain/jit/mlir/ir/utils.h"
 
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include <vector>
 
 using namespace mgb;
 using namespace jit;
@@ -28,9 +30,11 @@ cb(add, AddFOp);
 cb(sub, SubFOp);
 cb(mul, MulFOp);
 cb(div, DivFOp);
+cb(divI, SignedDivIOp);
 cb(mod, RemFOp);
 cb(bit_and, AndOp);
 cb(bit_or, OrOp);
+cb(modI, SignedRemIOp);
 #undef cb
 
 #define cb(name, mode)                                                  \
@@ -62,6 +66,11 @@ mlir::Value ValueBuilderHelper::const_val(float val) {
                                               m_builder.getF32FloatAttr(val));
 }
 
+mlir::Value ValueBuilderHelper::constI(int32_t val) {
+    return m_builder.create<mlir::ConstantOp>(m_location,
+                                              m_builder.getIndexAttr(val));
+}
+
 #define cb(name, op)                                                    \
     mlir::Value ValueBuilderHelper::name(mlir::Value lhs) {             \
         return m_builder.create<op>(m_location, lhs);                   \
@@ -97,6 +106,44 @@ mlir::Value ValueBuilderHelper::select(mlir::Value cond, mlir::Value true_val,
                                            false_val);
 }
 
+mlir::AffineMap jit::get_affinemap(mlir::OpBuilder& builder,
+                                   const mlir::Value& val,
+                                   const megdnn::TensorLayout& layout) {
+    auto type = val.getType().cast<mlir::MemRefType>();
+    mgb_assert(type, "currently only support MemRefType");
+    std::vector<mlir::AffineExpr> exprs;
+    for (int i = 0; i < type.getRank(); ++i) {
+        if (layout[i] == 1) {
+            exprs.push_back(builder.getAffineConstantExpr(0));
+        } else {
+            exprs.push_back(builder.getAffineDimExpr(i));
+        }
+    }
+    auto map = mlir::AffineMap::get(type.getRank(), 0, exprs,
+                                    builder.getContext());
+    return map;
+}
+
+mlir::Value jit::get_affine_load_op(mlir::OpBuilder& builder,
+                                    const mlir::Location& loc,
+                                    const mlir::Value& val,
+                                    const mlir::ValueRange& index,
+                                    const megdnn::TensorLayout& dst) {
+    if (val.getType().isa<mlir::MemRefType>()) {
+        auto type = val.getType().cast<mlir::MemRefType>();
+        megdnn::TensorLayout src_layout = mlir_type_to_layout(type);
+        src_layout.init_contiguous_stride();
+        if (src_layout.eq_shape(dst)) {
+            return builder.create<mlir::AffineLoadOp>(loc, val, index);
+        } else {
+            auto lhs_map = get_affinemap(builder, val, src_layout);
+            return builder.create<mlir::AffineLoadOp>(loc, val, lhs_map, index);
+        }
+    } else {
+        return val;
+    }
+}
+
 #endif  // MGB_JIT && MGB_JIT_MLIR
 
 // vim: syntax=cpp.doxygen
diff --git a/src/jit/impl/mlir/ir/common.h b/src/jit/impl/mlir/ir/common.h
index 7ea03cad..edbba330 100644
--- a/src/jit/impl/mlir/ir/common.h
+++ b/src/jit/impl/mlir/ir/common.h
@@ -14,7 +14,7 @@
 #include "megbrain_build_config.h"
 
 #if MGB_JIT && MGB_JIT_MLIR
-
+#include "megbrain/tensor.h"
 #include
 #include
 #include
@@ -39,9 +39,11 @@ public:
     cb(sub);
     cb(mul);
    cb(div);
+    cb(divI);
     cb(max);
     cb(min);
     cb(mod);
+    cb(modI);
     cb(gt);
     cb(ge);
     cb(lt);
@@ -51,6 +53,7 @@ public:
     cb(bit_or);
 #undef cb
     mlir::Value const_val(float val);
+    mlir::Value constI(int32_t val);
 
 #define cb(name)                                                              \
     mlir::Value name(mlir::ValueRange operands) { return name(operands[0]); } \
@@ -89,6 +92,15 @@ mlir::Value get_operand(mlir::OpBuilder& builder, const mlir::Location& loc,
     }
 }
 
+mlir::AffineMap get_affinemap(mlir::OpBuilder& builder, const mlir::Value& val,
+                              const TensorLayout& layout);
+
+mlir::Value get_affine_load_op(mlir::OpBuilder& builder,
+                               const mlir::Location& loc,
+                               const mlir::Value& val,
+                               const mlir::ValueRange& index,
+                               const TensorLayout& dst);
+
 }  // namespace jit
 }  // namespace mgb
diff --git a/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp b/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp
index 32c7bcff..dee486a8 100644
--- a/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp
+++ b/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp
@@ -42,8 +42,8 @@ void lower_op_to_loops(Operation* op, ValueRange operands,
     auto alloc =
             jit::insert_alloc_and_dealloc(memref_type, loc, rewriter);
 
-    SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
-    SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
+    llvm::SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
+    llvm::SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
     buildAffineLoopNest(
             rewriter, loc, lower_bounds, memref_type.getShape(), steps,
             [&](OpBuilder& nested_builder, Location loc, ValueRange ivs) {
@@ -96,17 +96,23 @@ struct BinaryOpLowering : public ConversionPattern {
             Operation* op, ArrayRef<Value> operands,
             ConversionPatternRewriter& rewriter) const final {
         auto loc = op->getLoc();
+        auto dst_memref_type = (*op->result_type_begin()).cast<MemRefType>();
+        megdnn::TensorLayout dst_layout = mlir_type_to_layout(dst_memref_type);
+        dst_layout.init_contiguous_stride();
         lower_op_to_loops(
                 op, operands, rewriter,
-                [loc](OpBuilder& builder, ValueRange memref_operands,
-                      ValueRange loop_ivs) {
+                [dst_layout, loc, this](OpBuilder& builder,
+                                        ValueRange memref_operands,
+                                        ValueRange loop_ivs) {
                     typename Op::Adaptor binary_adaptor(memref_operands);
                     LoweredOp lower_op;
-                    auto loaded_lhs = get_operand<AffineLoadOp>(
-                            builder, loc, binary_adaptor.lhs(), loop_ivs);
-                    auto loaded_rhs = get_operand<AffineLoadOp>(
-                            builder, loc, binary_adaptor.rhs(), loop_ivs);
+                    auto loaded_lhs = get_affine_load_op(builder, loc,
+                                                         binary_adaptor.lhs(),
+                                                         loop_ivs, dst_layout);
+                    auto loaded_rhs = get_affine_load_op(builder, loc,
+                                                         binary_adaptor.rhs(),
+                                                         loop_ivs, dst_layout);
 
                     return lower_op(builder, loc, {loaded_lhs, loaded_rhs});
                 });
 
@@ -128,19 +134,26 @@ struct TernaryOpLowering : public ConversionPattern {
             Operation* op, ArrayRef<Value> operands,
             ConversionPatternRewriter& rewriter) const final {
         auto loc = op->getLoc();
+        auto dst_memref_type = (*op->result_type_begin()).cast<MemRefType>();
+        megdnn::TensorLayout dst_layout = mlir_type_to_layout(dst_memref_type);
+        dst_layout.init_contiguous_stride();
         lower_op_to_loops(
                 op, operands, rewriter,
-                [loc](OpBuilder& builder, ValueRange memref_operands,
-                      ValueRange loop_ivs) {
+                [dst_layout, loc](OpBuilder& builder,
+                                  ValueRange memref_operands,
+                                  ValueRange loop_ivs) {
                     typename Op::Adaptor ternary_adaptor(memref_operands);
                     LoweredOp lower_op;
-                    auto loaded_x = get_operand<AffineLoadOp>(
-                            builder, loc, ternary_adaptor.x(), loop_ivs);
-                    auto loaded_y = get_operand<AffineLoadOp>(
-                            builder, loc, ternary_adaptor.y(), loop_ivs);
-                    auto loaded_z = get_operand<AffineLoadOp>(
-                            builder, loc, ternary_adaptor.z(), loop_ivs);
+                    auto loaded_x = get_affine_load_op(builder, loc,
+                                                       ternary_adaptor.x(),
+                                                       loop_ivs, dst_layout);
+                    auto loaded_y = get_affine_load_op(builder, loc,
+                                                       ternary_adaptor.y(),
+                                                       loop_ivs, dst_layout);
+                    auto loaded_z = get_affine_load_op(builder, loc,
+                                                       ternary_adaptor.z(),
+                                                       loop_ivs, dst_layout);
 
                     return lower_op(builder, loc,
                                     {loaded_x, loaded_y, loaded_z});
                 });
@@ -166,8 +179,8 @@ struct AssignOpLowering : public ConversionPattern {
         auto memref_type = operands[0].getType().cast<MemRefType>();
 
         AssignOpAdaptor assign_adaptor(operands);
-        SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
-        SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
+        llvm::SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
+        llvm::SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
         buildAffineLoopNest(
                 rewriter, loc, lower_bounds, memref_type.getShape(), steps,
                 [&](OpBuilder& nested_builder, Location loc, ValueRange ivs) {
diff --git a/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp b/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp
index e0b35b66..03683872 100644
--- a/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp
+++ b/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp
@@ -52,6 +52,54 @@ mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) {
     return index;
 }
 
+megdnn::TensorLayout output_layout(gpu::LaunchOp& launch_op) {
+    auto func_op = launch_op.getParentOfType<mlir::FuncOp>();
+    mgb_assert(func_op, "Unexpected launch op.");
+    for (auto block_iter = func_op.rbegin(); block_iter != func_op.rend();
+         block_iter++) {
+        for (auto op_iter = block_iter->rbegin();
+             op_iter != block_iter->rend(); op_iter++) {
+            auto op = llvm::dyn_cast_or_null<dialect::AssignOp>(&(*op_iter));
+            if (op && op.getNumOperands() > 0) {
+                return mlir_type_to_layout(*(op.operand_type_begin()));
+            }
+        }
+    }
+    mgb_throw(MegBrainError, "Unexpected launch op.");
+}
+
+std::vector<mlir::Value> get_multidim_tid(ConversionPatternRewriter& rewriter,
+                                          const Location& loc,
+                                          const mlir::Value& val,
+                                          const megdnn::TensorLayout& dst) {
+    Value index = get_tid(rewriter, loc);
+
+    auto type = val.getType().dyn_cast_or_null<mlir::MemRefType>();
+    if (type) {
+        ValueBuilderHelper helper(rewriter, loc);
+        std::vector<mlir::Value> idxs;
+        idxs.resize(dst.ndim);
+        mlir::Value dim_index = index;
+        for (int i = dst.ndim - 1; i >= 0; i--) {
+            auto cur_index = helper.modI(dim_index, helper.constI(dst[i]));
+            idxs[i] = cur_index;
+            dim_index = helper.divI(dim_index, helper.constI(dst[i]));
+        }
+
+        megdnn::TensorLayout src_layout = mlir_type_to_layout(type);
+        src_layout.init_contiguous_stride();
+        for (int i = 0; i < type.getRank(); ++i) {
+            if (src_layout[i] == 1) {
+                idxs[i] = helper.constI(0);
+            }
+        }
+        return idxs;
+    } else {
+        return {index};
+    }
+
+}
+
 template <typename Op, typename LoweredOp>
 struct UnaryOpLowering : public ConversionPattern {
     UnaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
@@ -66,7 +114,9 @@ struct UnaryOpLowering : public ConversionPattern {
         typename Op::Adaptor binary_adaptor(operands);
         rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
 
-        auto index = get_tid(rewriter, loc);
+        auto dst_layout = output_layout(m_launch_op);
+        auto index = get_multidim_tid(rewriter, loc, binary_adaptor.lhs(),
+                                      dst_layout);
 
         auto loaded_lhs =
                 get_operand<LoadOp>(rewriter, loc, binary_adaptor.lhs(), index);
@@ -99,11 +149,15 @@ struct BinaryOpLowering : public ConversionPattern {
         typename Op::Adaptor binary_adaptor(operands);
         rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
 
-        auto index = get_tid(rewriter, loc);
-        auto loaded_lhs =
-                get_operand<LoadOp>(rewriter, loc, binary_adaptor.lhs(), index);
-        auto loaded_rhs =
-                get_operand<LoadOp>(rewriter, loc, binary_adaptor.rhs(), index);
+        auto dst_layout = output_layout(m_launch_op);
+        auto lhs_index = get_multidim_tid(rewriter, loc, binary_adaptor.lhs(),
+                                          dst_layout);
+        auto rhs_index = get_multidim_tid(rewriter, loc, binary_adaptor.rhs(),
+                                          dst_layout);
+        auto loaded_lhs = get_operand<LoadOp>(rewriter, loc,
+                                              binary_adaptor.lhs(), lhs_index);
+        auto loaded_rhs = get_operand<LoadOp>(rewriter, loc,
+                                              binary_adaptor.rhs(), rhs_index);
 
         LoweredOp lower_op;
 
@@ -135,13 +189,19 @@ struct TernaryOpLowering : public ConversionPattern {
         typename Op::Adaptor ternary_adaptor(operands);
         rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
 
-        auto index = get_tid(rewriter, loc);
-        auto loaded_x =
-                get_operand<LoadOp>(rewriter, loc, ternary_adaptor.x(), index);
-        auto loaded_y =
-                get_operand<LoadOp>(rewriter, loc, ternary_adaptor.y(), index);
-        auto loaded_z =
-                get_operand<LoadOp>(rewriter, loc, ternary_adaptor.z(), index);
+        auto dst_layout = output_layout(m_launch_op);
+        auto index_x = get_multidim_tid(rewriter, loc, ternary_adaptor.x(),
+                                        dst_layout);
+        auto index_y = get_multidim_tid(rewriter, loc, ternary_adaptor.y(),
+                                        dst_layout);
+        auto index_z = get_multidim_tid(rewriter, loc, ternary_adaptor.z(),
+                                        dst_layout);
+        auto loaded_x = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.x(),
+                                            index_x);
+        auto loaded_y = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.y(),
+                                            index_y);
+        auto loaded_z = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.z(),
+                                            index_z);
 
         LoweredOp lower_op;
 
@@ -242,7 +302,9 @@ struct AssignOpLowering : public ConversionPattern {
         AssignOpAdaptor assign_adaptor(operands);
         rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
 
-        auto index = get_tid(rewriter, loc);
+        auto dst_layout = output_layout(m_launch_op);
+        auto index = get_multidim_tid(rewriter, loc, assign_adaptor.rhs(),
+                                      dst_layout);
 
         auto loaded_lhs = get_operand<LoadOp>(rewriter, loc,
                                               assign_adaptor.lhs(), index);
diff --git a/src/jit/impl/mlir/ir/utils.cpp b/src/jit/impl/mlir/ir/utils.cpp
index afdb25e8..794d7b6f 100644
--- a/src/jit/impl/mlir/ir/utils.cpp
+++ b/src/jit/impl/mlir/ir/utils.cpp
@@ -98,7 +98,6 @@ mlir::MemRefType jit::layout_to_mlir_type(const megdnn::TensorLayout& layout,
     for (size_t i = 0; i < layout.ndim; i++) {
         shape.push_back(layout[i]);
     }
-
     switch (layout.dtype.enumv()) {
         case megdnn::DTypeEnum::Float32:
             return mlir::MemRefType::get(shape, builder.getF32Type());
diff --git a/src/jit/impl/mlir/mlir_gen.cpp b/src/jit/impl/mlir/mlir_gen.cpp
index ab062661..b4d7f86d 100644
--- a/src/jit/impl/mlir/mlir_gen.cpp
+++ b/src/jit/impl/mlir/mlir_gen.cpp
@@ -73,10 +73,10 @@ private:
                                          m_symbol_table);
         std::vector<mlir::Type> func_args;
         for (auto&& arg : args.inputs) {
-            func_args.push_back(get_type(arg.layout));
+            func_args.push_back(get_type(arg.from->layout()));
         }
         for (auto&& arg : args.outputs) {
-            func_args.push_back(get_type(arg.layout));
+            func_args.push_back(get_type(arg.from->layout()));
         }
         //! the last arg is nr_elements
         func_args.push_back(m_builder.getIndexType());
diff --git a/src/jit/include/megbrain/jit/mlir/ir/utils.h b/src/jit/include/megbrain/jit/mlir/ir/utils.h
index b9fb5fb5..710dd57e 100644
--- a/src/jit/include/megbrain/jit/mlir/ir/utils.h
+++ b/src/jit/include/megbrain/jit/mlir/ir/utils.h
@@ -44,7 +44,6 @@ megdnn::TensorLayout mlir_type_to_layout(mlir::Type type);
 megdnn::DType mlir_type_to_dtype(mlir::Type type);
 
 mlir::MemRefType layout_to_mlir_type(const megdnn::TensorLayout& layout,
                                      mlir::Builder& builder);
-
 }  // namespace jit
 }  // namespace mgb
diff --git a/src/jit/test/codegen.cpp b/src/jit/test/codegen.cpp
index 0c676820..0d3aaeb1 100644
--- a/src/jit/test/codegen.cpp
+++ b/src/jit/test/codegen.cpp
@@ -130,8 +130,8 @@ void run_mlir(CompNode cn) {
     auto graph = ComputingGraph::make();
     HostTensorGenerator<> gen;
 
-    auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 42}, cn),
-         host_x2 = gen({23, 42}, cn), host_x3 = gen({23, 42}, cn);
+    auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 1}, cn),
+         host_x2 = gen({23, 42}, cn);
 
     auto a = opr::Host2DeviceCopy::make(*graph, host_x0),
          b = opr::Host2DeviceCopy::make(*graph, host_x1),
@@ -159,6 +159,43 @@ void run_mlir(CompNode cn) {
     MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
 }
 
+void run_mlir_broadcast(CompNode cn) {
+    set_backend(Backend::MLIR);
+    auto graph = ComputingGraph::make();
+    HostTensorGenerator<> gen;
+
+    auto host_x0 = gen({10, 20, 5, 6}, cn), host_x1 = gen({1, 20, 1, 1}, cn),
+         host_x2 = gen({10, 1, 5, 1}, cn), host_x3 = gen({10, 1, 1, 1}, cn);
+
+    auto a = opr::Host2DeviceCopy::make(*graph, host_x0),
+         b = opr::Host2DeviceCopy::make(*graph, host_x1),
+         c = opr::Host2DeviceCopy::make(*graph, host_x2),
+         d = opr::Host2DeviceCopy::make(*graph, host_x3);
+
+    auto y =
+            opr::Elemwise::make({a, b, c}, opr::Elemwise::Mode::FUSE_MUL_ADD3) +
+            opr::Elemwise::make({d}, opr::Elemwise::Mode::ABS) - 0.3f;
+
+    auto ig_gen =
+            std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
+
+    for (auto i : get_rev_topo_order(y)) {
+        if (!i->same_type<opr::Host2DeviceCopy>()) {
+            ig_gen->add_opr(i);
+        }
+    }
+
+    auto igraph = ig_gen->generate();
+    auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
+
+    HostTensorND host_y, host_y_jit;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_jit, host_y_jit)});
+    func->execute();
+
+    MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
+}
+
 struct MlirTestOpt {
     float low;
     float high;
@@ -252,12 +289,14 @@ TYPED_TEST(TestJITNvrtcCodeGen, run) {
 TEST(TestJITMlirCodeGen, Basic) {
     auto cn = CompNode::load("cpu0");
     run_mlir(cn);
+    run_mlir_broadcast(cn);
 }
 
 TEST(TestJITMlirCodeGen, BasicGPU) {
     REQUIRE_GPU(1);
     auto cn = CompNode::load("gpu0");
     run_mlir(cn);
+    run_mlir_broadcast(cn);
 }
 
 ///////////////////////// unary ///////////////////////////////
diff --git a/src/jit/test/fusion.cpp b/src/jit/test/fusion.cpp
index 17dae4b1..b60a8ef4 100644
--- a/src/jit/test/fusion.cpp
+++ b/src/jit/test/fusion.cpp
@@ -1580,8 +1580,8 @@ void run_mlir(CompNode cn) {
 
     JITExecutor* jit;
     unpack_vector(find_oprs<JITExecutor>(*funcs.second), jit);
-    ASSERT_EQ(2u, find_oprs<opr::Elemwise>(*funcs.second).size());
-    ASSERT_EQ(3u, jit->input().size());
+    ASSERT_EQ(0u, find_oprs<opr::Elemwise>(*funcs.second).size());
+    ASSERT_EQ(5u, jit->input().size());
 }
 
 TEST(TestJITExecutor, TestJITMlirFusion) {
--
GitLab
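
Note (editor's illustration, not part of the patch; git am ignores text after the trailer): on the CPU path, get_affinemap rewrites the index expression of every extent-1 dimension of a broadcast operand to the constant 0, so an AffineLoadOp on a {23, 1} operand addressed with output index (i, j) behaves like a load at (i, 0). The following standalone C++ sketch shows the equivalent offset arithmetic; contiguous_strides and src_offset are hypothetical helper names, not MegEngine API.

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Contiguous row-major strides for `shape`, mirroring what
    // init_contiguous_stride() computes in the patch.
    std::vector<size_t> contiguous_strides(const std::vector<size_t>& shape) {
        std::vector<size_t> strides(shape.size(), 1);
        for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i)
            strides[i] = strides[i + 1] * shape[i + 1];
        return strides;
    }

    // Flat source offset for output index `idx`: broadcast dims (extent 1)
    // contribute nothing, like the AffineConstantExpr(0) in get_affinemap.
    size_t src_offset(const std::vector<size_t>& idx,
                      const std::vector<size_t>& src_shape) {
        auto strides = contiguous_strides(src_shape);
        size_t offset = 0;
        for (size_t i = 0; i < idx.size(); ++i)
            offset += (src_shape[i] == 1 ? 0 : idx[i]) * strides[i];
        return offset;
    }

    int main() {
        // A {23, 1} source read at output index (5, 17) resolves to element
        // 5, i.e. the affine map (d0, d1) -> (d0, 0).
        std::printf("%zu\n", src_offset({5, 17}, {23, 1}));  // prints 5
        return 0;
    }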
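On the GPU path, get_multidim_tid delinearizes the flat thread id into per-dimension output coordinates via the new modI/divI chain, then forces the coordinate of every extent-1 input dimension to 0. A host-side C++ sketch of that index arithmetic, under the same caveat that the helper names are illustrative only:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Delinearize a flat thread id into coordinates of `dst` (last dim
    // varies fastest), mirroring the modI/divI loop in get_multidim_tid.
    std::vector<size_t> delinearize(size_t tid, const std::vector<size_t>& dst) {
        std::vector<size_t> idx(dst.size());
        for (int i = static_cast<int>(dst.size()) - 1; i >= 0; --i) {
            idx[i] = tid % dst[i];  // coordinate along dim i
            tid /= dst[i];          // strip dim i from the flat id
        }
        return idx;
    }

    // Clamp coordinates of broadcast (extent-1) dims to 0, as the pass does
    // for every input whose shape differs from the output's.
    std::vector<size_t> broadcast_index(std::vector<size_t> idx,
                                        const std::vector<size_t>& src) {
        for (size_t i = 0; i < src.size(); ++i)
            if (src[i] == 1) idx[i] = 0;
        return idx;
    }

    int main() {
        // Output {10, 20, 5, 6} with input {1, 20, 1, 1}: the shapes
        // exercised by run_mlir_broadcast in the tests above.
        std::vector<size_t> dst{10, 20, 5, 6}, src{1, 20, 1, 1};
        auto idx = delinearize(4321, dst);      // -> {7, 4, 0, 1}
        auto sidx = broadcast_index(idx, src);  // -> {0, 4, 0, 0}
        std::printf("dst index: %zu %zu %zu %zu\n", idx[0], idx[1], idx[2],
                    idx[3]);
        std::printf("src index: %zu %zu %zu %zu\n", sidx[0], sidx[1], sidx[2],
                    sidx[3]);
        return 0;
    }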