diff --git a/src/gopt/impl/basic_arith/chain.cpp b/src/gopt/impl/basic_arith/chain.cpp index 654b3a464ed9a682a8bfb825fea37c259cdc1496..bad9547b57e2574032560d8a077dee212c0edd8e 100644 --- a/src/gopt/impl/basic_arith/chain.cpp +++ b/src/gopt/impl/basic_arith/chain.cpp @@ -322,6 +322,15 @@ NormalizeArithChainPass::Impl::AddTrait::extract_coeff( return AbstractOpr::make_coeff( i1.node(), i0v->get_cast()); } + if (mode == Mode::TRUE_DIV) { + SymbolVar i0 = opr->input(0), i1 = opr->input(1), + i1r = opr::powf(i1, -1); + auto i1rv = i1r.as_immutable_scalar_require_shape(); + if (!i1rv.valid()) + return None; + return AbstractOpr::make_coeff( + i0.node(), i1rv->get_cast()); + } return None; } diff --git a/src/gopt/test/basic_arith.cpp b/src/gopt/test/basic_arith.cpp index 82e5708c4c9980f53a80f12e09298d5864f53c49..7f2160668f951455622e1969cd5103e985fe5229 100644 --- a/src/gopt/test/basic_arith.cpp +++ b/src/gopt/test/basic_arith.cpp @@ -548,6 +548,19 @@ TEST(TestNormalizeArithChainPass, PowcCExpand2) { graph->compile({make_callback_copy(grad, host_g)})); } +TEST_PASS(NormalizeArithChainPass, SubDiv) { + auto x = mkvar("x"), y = mkvar("y"), z = mkvar("z"), + a0_ = x - y / 2.f, + a1 = x + (-0.5f) * y, + b0_ = x - ((y - (z / 5.f)) / 2.f), + b1 = x + (-0.5f) * y + 0.1f * z; + + SymbolVar a0, b0; + unpack_vector(run_opt({a0_, b0_}), a0, b0); + EXPECT_EQ(a1, a0); + EXPECT_EQ(b1, b0); +} + TEST_PASS(ReorderArithChainPass, 0) { auto chk = [this](SymbolVar inp, SymbolVar expect) { check(expect, inp, gopt::ConstVarType::IMMUTABLE_AND_PARAM);