提交 23437864 编写于 作者: M Megvii Engine Team

fix(mgb/jit): mlir doesn't support broadcast

GitOrigin-RevId: 08bfc4c34aa378b61835269a0c0de90d04602bb6
上级 f87bba68
......@@ -291,8 +291,27 @@ void JITFusionPass::Impl::process_opr(OperatorNodeBase* opr) {
cond_cn = opr->output(0)->comp_node() ==
ig_gen->output()->comp_node(),
cond_shp = check_shape(opr, ig_gen),
cond_nr_inp = ig_gen->get_cnt_input_if_add(opr) <= max_nr_input;
if (cond_readers && cond_cn && cond_shp && cond_nr_inp) {
cond_nr_inp = ig_gen->get_cnt_input_if_add(opr) <= max_nr_input,
cond_mlir_specific = true;
#if MGB_JIT_MLIR
//! FIXME mlir does't support broadcast currently.
auto backend = MGB_GETENV("MGB_JIT_BACKEND");
if (!strcmp(backend, "MLIR")) {
for (VarNode* var : opr->input()) {
if (!SymbolVar{var}.as_immutable_scalar().valid()) {
if (opr->node_prop().dep_map().at(var) &
DepType::DEV_VALUE) {
if (!var->shape().eq_shape(opr->output(0)->shape())) {
cond_mlir_specific = false;
}
}
}
}
}
#endif
if (cond_readers && cond_cn && cond_shp && cond_nr_inp &&
cond_mlir_specific) {
ig_gen->add_opr(opr);
} else {
if (opr->same_type<opr::Dimshuffle>()) {
......@@ -344,7 +363,10 @@ bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const {
}
//! As MLIR backend has some contraints
auto backend = MGB_GETENV("MGB_JIT_BACKEND");
const char* backend = MGB_GETENV("MGB_JIT_BACKEND");
if (!backend) {
backend = "DEFAULT";
}
// float elemwise
if (auto elem = gopt::try_cast_as_op<opr::Elemwise>(opr)) {
bool ret = true;
......
......@@ -222,9 +222,6 @@ void MLIRCompiler::run_lowering_pass(mlir::OwningModuleRef& module,
std::unique_ptr<Executable> MLIRCompiler::do_compile(
const InternalGraph& graph, const JITExecutor::Args& args) {
MGB_MARK_USED_VAR(graph);
MGB_MARK_USED_VAR(args);
mlir::MLIRContext ctx;
ctx.printStackTraceOnDiagnostic(true);
ctx.printOpOnDiagnostic(true);
......
......@@ -19,7 +19,7 @@
* implied.
*
* This file has been modified by Megvii ("Megvii Modifications").
* All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights
* All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights
* reserved.
*
*/
......
......@@ -19,7 +19,7 @@
namespace mgb {
namespace jit {
inline bool is_elemwise_float(const mlir::Type& dt) {
inline const bool is_elemwise_float(const mlir::Type& dt) {
if (auto cast = dt.dyn_cast_or_null<mlir::MemRefType>()) {
if (cast.getElementType().getKind() == mlir::StandardTypes::F32) {
return true;
......
......@@ -1553,6 +1553,48 @@ TEST(TestJITExecutor, GradBehavior) {
}
}
#if MGB_JIT_MLIR
void run_mlir(CompNode cn) {
set_backend(Backend::MLIR);
HostTensorGenerator<> gen;
auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 1}, cn),
host_x2 = gen({1, 42}, cn), host_x3 = gen({23, 42}, cn),
host_x4 = gen({1, 42}, cn), host_x5 = gen({23, 1}, cn);
auto make_dst = [&](ComputingGraph& graph) {
auto a = opr::Host2DeviceCopy::make(graph, host_x0),
b = opr::Host2DeviceCopy::make(graph, host_x1),
c = opr::Host2DeviceCopy::make(graph, host_x2),
d = opr::Host2DeviceCopy::make(graph, host_x3),
e = opr::Host2DeviceCopy::make(graph, host_x4);
return a + opr::max(b, c) + opr::max(d, e);
};
HostTensorND host_y1, host_y2;
auto funcs = make_func_pair(host_y1, host_y2, make_dst, 2);
funcs.first->execute();
funcs.second->execute();
MGB_ASSERT_TENSOR_EQ(host_y1, host_y2);
JITExecutor* jit;
unpack_vector(find_oprs<JITExecutor>(*funcs.second), jit);
ASSERT_EQ(2u, find_oprs<opr::Elemwise>(*funcs.second).size());
ASSERT_EQ(3u, jit->input().size());
}
TEST(TestJITExecutor, TestJITMlirFusion) {
run_mlir(CompNode::load("cpu0"));
}
TEST(TestJITExecutor, TestJITMlirFusionGpu) {
REQUIRE_GPU(1);
run_mlir(CompNode::load("gpu0"));
}
#endif // MGB_JIT_MLIR
#endif // MGB_JIT
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册