From d2910f7ef5813bbaa991f1ea0665d40af40b9fe0 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 19 Dec 2020 15:34:55 +0800 Subject: [PATCH] fix(mgb/jit): add bind_shape feature to MLIRCompiler GitOrigin-RevId: bec6796fbfe6dc63444dddf8bf8f6d230bac1fb3 --- src/jit/impl/mlir/compiler.h | 2 +- src/jit/test/codegen.cpp | 37 ++++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/jit/impl/mlir/compiler.h b/src/jit/impl/mlir/compiler.h index 246b748ee..a48ce79f0 100644 --- a/src/jit/impl/mlir/compiler.h +++ b/src/jit/impl/mlir/compiler.h @@ -33,7 +33,7 @@ public: MLIRCompiler(CompNode::DeviceType device_type = CompNode::DeviceType::CPU); Property property() const override { using F = Property::Flag; - return Property{F::NEED_INPUT_COLLAPSE | F::BIND_NDIM, + return Property{F::BIND_NDIM | F::BIND_SHAPE, JITFeatureBits::DIMSHUFFLE, 64}; } diff --git a/src/jit/test/codegen.cpp b/src/jit/test/codegen.cpp index 1e07c87b9..e36a51725 100644 --- a/src/jit/test/codegen.cpp +++ b/src/jit/test/codegen.cpp @@ -197,6 +197,41 @@ void run_mlir_broadcast(CompNode cn) { MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit); } +void run_mlir_different_shape(CompNode cn) { + set_backend(Backend::MLIR); + auto graph = ComputingGraph::make(); + HostTensorGenerator gen; + + auto run = [&](TensorShape tshp) { + auto host_x = gen(tshp, cn); + auto x = opr::Host2DeviceCopy::make(*graph, host_x); + auto y = x * 2; + auto ig_gen = + std::make_unique(y.node()->owner_opr()); + + for (auto i : get_rev_topo_order(y)) { + if (!i->same_type()) { + ig_gen->add_opr(i); + } + } + + auto igraph = ig_gen->generate(); + auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps()); + + HostTensorND host_y, host_y_jit; + auto func = graph->compile({make_callback_copy(y, host_y), + make_callback_copy(y_jit, host_y_jit)}); + func->execute(); + + MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit); + }; + + run({23, 42}); + run({16, 31}); + run({32, 56}); + run({10}); +} + struct MlirTestOpt { float low; float high; @@ -297,6 +332,7 @@ TEST(TestJITMlirCodeGen, Basic) { auto cn = CompNode::load("cpu0"); run_mlir(cn); run_mlir_broadcast(cn); + run_mlir_different_shape(cn); } TEST(TestJITMlirCodeGen, BasicGPU) { @@ -304,6 +340,7 @@ TEST(TestJITMlirCodeGen, BasicGPU) { auto cn = CompNode::load("gpu0"); run_mlir(cn); run_mlir_broadcast(cn); + run_mlir_different_shape(cn); } /* ===================== TestJITMlirUnaryElemwise ===================== */ -- GitLab