提交 d2910f7e 编写于 作者: M Megvii Engine Team

fix(mgb/jit): add bind_shape feature to MLIRCompiler

GitOrigin-RevId: bec6796fbfe6dc63444dddf8bf8f6d230bac1fb3
上级 9de1ea6a
......@@ -33,7 +33,7 @@ public:
MLIRCompiler(CompNode::DeviceType device_type = CompNode::DeviceType::CPU);
Property property() const override {
using F = Property::Flag;
return Property{F::NEED_INPUT_COLLAPSE | F::BIND_NDIM,
return Property{F::BIND_NDIM | F::BIND_SHAPE,
JITFeatureBits::DIMSHUFFLE, 64};
}
......
......@@ -197,6 +197,41 @@ void run_mlir_broadcast(CompNode cn) {
MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
}
void run_mlir_different_shape(CompNode cn) {
set_backend(Backend::MLIR);
auto graph = ComputingGraph::make();
HostTensorGenerator<dtype::Float32> gen;
auto run = [&](TensorShape tshp) {
auto host_x = gen(tshp, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
auto y = x * 2;
auto ig_gen =
std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
for (auto i : get_rev_topo_order(y)) {
if (!i->same_type<opr::Host2DeviceCopy>()) {
ig_gen->add_opr(i);
}
}
auto igraph = ig_gen->generate();
auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
HostTensorND host_y, host_y_jit;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_jit, host_y_jit)});
func->execute();
MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
};
run({23, 42});
run({16, 31});
run({32, 56});
run({10});
}
struct MlirTestOpt {
float low;
float high;
......@@ -297,6 +332,7 @@ TEST(TestJITMlirCodeGen, Basic) {
auto cn = CompNode::load("cpu0");
run_mlir(cn);
run_mlir_broadcast(cn);
run_mlir_different_shape(cn);
}
TEST(TestJITMlirCodeGen, BasicGPU) {
......@@ -304,6 +340,7 @@ TEST(TestJITMlirCodeGen, BasicGPU) {
auto cn = CompNode::load("gpu0");
run_mlir(cn);
run_mlir_broadcast(cn);
run_mlir_different_shape(cn);
}
/* ===================== TestJITMlirUnaryElemwise ===================== */
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册