提交 0f739c11 编写于 作者: M Megvii Engine Team

fix(mgb/gopt): fix NormalizeArithChainPass to process sub-div chains

GitOrigin-RevId: d71debbfded8520209486e80eaa09adad4818225
上级 c53abcdf
......@@ -322,6 +322,15 @@ NormalizeArithChainPass::Impl::AddTrait::extract_coeff(
return AbstractOpr::make_coeff(
i1.node(), i0v->get_cast<dt_max_float>());
}
if (mode == Mode::TRUE_DIV) {
SymbolVar i0 = opr->input(0), i1 = opr->input(1),
i1r = opr::powf(i1, -1);
auto i1rv = i1r.as_immutable_scalar_require_shape();
if (!i1rv.valid())
return None;
return AbstractOpr::make_coeff(
i0.node(), i1rv->get_cast<dt_max_float>());
}
return None;
}
......
......@@ -548,6 +548,19 @@ TEST(TestNormalizeArithChainPass, PowcCExpand2) {
graph->compile({make_callback_copy(grad, host_g)}));
}
TEST_PASS(NormalizeArithChainPass, SubDiv) {
auto x = mkvar("x"), y = mkvar("y"), z = mkvar("z"),
a0_ = x - y / 2.f,
a1 = x + (-0.5f) * y,
b0_ = x - ((y - (z / 5.f)) / 2.f),
b1 = x + (-0.5f) * y + 0.1f * z;
SymbolVar a0, b0;
unpack_vector(run_opt({a0_, b0_}), a0, b0);
EXPECT_EQ(a1, a0);
EXPECT_EQ(b1, b0);
}
TEST_PASS(ReorderArithChainPass, 0) {
auto chk = [this](SymbolVar inp, SymbolVar expect) {
check(expect, inp, gopt::ConstVarType::IMMUTABLE_AND_PARAM);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册