From 23437864f95a28154308c42b88a37c3309a54cf7 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 10 Sep 2020 19:30:32 +0800 Subject: [PATCH] fix(mgb/jit): mlir doesn't support broadcast GitOrigin-RevId: 08bfc4c34aa378b61835269a0c0de90d04602bb6 --- src/jit/impl/fusion_pass.cpp | 28 +++++++++++-- src/jit/impl/mlir/compiler.cpp | 3 -- .../ir/create_gpu_kernel_outlining_pass.cpp | 2 +- src/jit/impl/mlir/ir/types.h | 2 +- src/jit/test/fusion.cpp | 42 +++++++++++++++++++ 5 files changed, 69 insertions(+), 8 deletions(-) diff --git a/src/jit/impl/fusion_pass.cpp b/src/jit/impl/fusion_pass.cpp index 8a2d722ca..8f7284709 100644 --- a/src/jit/impl/fusion_pass.cpp +++ b/src/jit/impl/fusion_pass.cpp @@ -291,8 +291,27 @@ void JITFusionPass::Impl::process_opr(OperatorNodeBase* opr) { cond_cn = opr->output(0)->comp_node() == ig_gen->output()->comp_node(), cond_shp = check_shape(opr, ig_gen), - cond_nr_inp = ig_gen->get_cnt_input_if_add(opr) <= max_nr_input; - if (cond_readers && cond_cn && cond_shp && cond_nr_inp) { + cond_nr_inp = ig_gen->get_cnt_input_if_add(opr) <= max_nr_input, + cond_mlir_specific = true; + +#if MGB_JIT_MLIR + //! FIXME mlir doesn't support broadcast currently. + auto backend = MGB_GETENV("MGB_JIT_BACKEND"); + if (!strcmp(backend, "MLIR")) { + for (VarNode* var : opr->input()) { + if (!SymbolVar{var}.as_immutable_scalar().valid()) { + if (opr->node_prop().dep_map().at(var) & + DepType::DEV_VALUE) { + if (!var->shape().eq_shape(opr->output(0)->shape())) { + cond_mlir_specific = false; + } + } + } + } + } +#endif + if (cond_readers && cond_cn && cond_shp && cond_nr_inp && + cond_mlir_specific) { ig_gen->add_opr(opr); } else { if (opr->same_type()) { @@ -344,7 +363,10 @@ bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const { } //! 
As MLIR backend has some constraints - auto backend = MGB_GETENV("MGB_JIT_BACKEND"); + const char* backend = MGB_GETENV("MGB_JIT_BACKEND"); + if (!backend) { + backend = "DEFAULT"; + } // float elemwise if (auto elem = gopt::try_cast_as_op(opr)) { bool ret = true; diff --git a/src/jit/impl/mlir/compiler.cpp b/src/jit/impl/mlir/compiler.cpp index 5cfd1465e..1be70411f 100644 --- a/src/jit/impl/mlir/compiler.cpp +++ b/src/jit/impl/mlir/compiler.cpp @@ -222,9 +222,6 @@ void MLIRCompiler::run_lowering_pass(mlir::OwningModuleRef& module, std::unique_ptr MLIRCompiler::do_compile( const InternalGraph& graph, const JITExecutor::Args& args) { - MGB_MARK_USED_VAR(graph); - MGB_MARK_USED_VAR(args); - mlir::MLIRContext ctx; ctx.printStackTraceOnDiagnostic(true); ctx.printOpOnDiagnostic(true); diff --git a/src/jit/impl/mlir/ir/create_gpu_kernel_outlining_pass.cpp b/src/jit/impl/mlir/ir/create_gpu_kernel_outlining_pass.cpp index ec3666d88..d7c89e426 100644 --- a/src/jit/impl/mlir/ir/create_gpu_kernel_outlining_pass.cpp +++ b/src/jit/impl/mlir/ir/create_gpu_kernel_outlining_pass.cpp @@ -19,7 +19,7 @@ * implied. * * This file has been modified by Megvii ("Megvii Modifications"). - * All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights + * All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights * reserved. 
* */ diff --git a/src/jit/impl/mlir/ir/types.h b/src/jit/impl/mlir/ir/types.h index 548b5db42..9390c30f1 100644 --- a/src/jit/impl/mlir/ir/types.h +++ b/src/jit/impl/mlir/ir/types.h @@ -19,7 +19,7 @@ namespace mgb { namespace jit { -inline bool is_elemwise_float(const mlir::Type& dt) { +inline const bool is_elemwise_float(const mlir::Type& dt) { if (auto cast = dt.dyn_cast_or_null()) { if (cast.getElementType().getKind() == mlir::StandardTypes::F32) { return true; diff --git a/src/jit/test/fusion.cpp b/src/jit/test/fusion.cpp index 94ad73561..17dae4b12 100644 --- a/src/jit/test/fusion.cpp +++ b/src/jit/test/fusion.cpp @@ -1553,6 +1553,48 @@ TEST(TestJITExecutor, GradBehavior) { } } +#if MGB_JIT_MLIR + +void run_mlir(CompNode cn) { + set_backend(Backend::MLIR); + + HostTensorGenerator<> gen; + auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 1}, cn), + host_x2 = gen({1, 42}, cn), host_x3 = gen({23, 42}, cn), + host_x4 = gen({1, 42}, cn), host_x5 = gen({23, 1}, cn); + + auto make_dst = [&](ComputingGraph& graph) { + auto a = opr::Host2DeviceCopy::make(graph, host_x0), + b = opr::Host2DeviceCopy::make(graph, host_x1), + c = opr::Host2DeviceCopy::make(graph, host_x2), + d = opr::Host2DeviceCopy::make(graph, host_x3), + e = opr::Host2DeviceCopy::make(graph, host_x4); + return a + opr::max(b, c) + opr::max(d, e); + }; + HostTensorND host_y1, host_y2; + auto funcs = make_func_pair(host_y1, host_y2, make_dst, 2); + + funcs.first->execute(); + funcs.second->execute(); + MGB_ASSERT_TENSOR_EQ(host_y1, host_y2); + + JITExecutor* jit; + unpack_vector(find_oprs(*funcs.second), jit); + ASSERT_EQ(2u, find_oprs(*funcs.second).size()); + ASSERT_EQ(3u, jit->input().size()); +} + +TEST(TestJITExecutor, TestJITMlirFusion) { + run_mlir(CompNode::load("cpu0")); +} + +TEST(TestJITExecutor, TestJITMlirFusionGpu) { + REQUIRE_GPU(1); + run_mlir(CompNode::load("gpu0")); +} + +#endif // MGB_JIT_MLIR + #endif // MGB_JIT // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} -- 
GitLab