diff --git a/src/jit/impl/mlir/compiler.h b/src/jit/impl/mlir/compiler.h index 246b748eec4ec536ed9a889638a8a2b23cf52036..a48ce79f04bdae1009aaea47328ba94ddbcb8bec 100644 --- a/src/jit/impl/mlir/compiler.h +++ b/src/jit/impl/mlir/compiler.h @@ -33,7 +33,7 @@ public: MLIRCompiler(CompNode::DeviceType device_type = CompNode::DeviceType::CPU); Property property() const override { using F = Property::Flag; - return Property{F::NEED_INPUT_COLLAPSE | F::BIND_NDIM, + return Property{F::BIND_NDIM | F::BIND_SHAPE, JITFeatureBits::DIMSHUFFLE, 64}; } diff --git a/src/jit/test/codegen.cpp b/src/jit/test/codegen.cpp index 1e07c87b9f40d73672b3fd78b1e96e55cb83f3cc..e36a51725f5ffe9d71091b246e9c2a4fa601f9de 100644 --- a/src/jit/test/codegen.cpp +++ b/src/jit/test/codegen.cpp @@ -197,6 +197,41 @@ void run_mlir_broadcast(CompNode cn) { MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit); } +void run_mlir_different_shape(CompNode cn) { + set_backend(Backend::MLIR); + auto graph = ComputingGraph::make(); + HostTensorGenerator gen; + + auto run = [&](TensorShape tshp) { + auto host_x = gen(tshp, cn); + auto x = opr::Host2DeviceCopy::make(*graph, host_x); + auto y = x * 2; + auto ig_gen = + std::make_unique(y.node()->owner_opr()); + + for (auto i : get_rev_topo_order(y)) { + if (!i->same_type()) { + ig_gen->add_opr(i); + } + } + + auto igraph = ig_gen->generate(); + auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps()); + + HostTensorND host_y, host_y_jit; + auto func = graph->compile({make_callback_copy(y, host_y), + make_callback_copy(y_jit, host_y_jit)}); + func->execute(); + + MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit); + }; + + run({23, 42}); + run({16, 31}); + run({32, 56}); + run({10}); +} + struct MlirTestOpt { float low; float high; @@ -297,6 +332,7 @@ TEST(TestJITMlirCodeGen, Basic) { auto cn = CompNode::load("cpu0"); run_mlir(cn); run_mlir_broadcast(cn); + run_mlir_different_shape(cn); } TEST(TestJITMlirCodeGen, BasicGPU) { @@ -304,6 +340,7 @@ TEST(TestJITMlirCodeGen, BasicGPU) { auto cn = CompNode::load("gpu0"); run_mlir(cn); run_mlir_broadcast(cn); + run_mlir_different_shape(cn); } /* ===================== TestJITMlirUnaryElemwise ===================== */