diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index effec812e4b268023d3a129c4a0f8069ac8d176f..9482af377c75495e869f8c26bdb846ffd06bc0fb 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -98,6 +98,9 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc index 21bbf1a5d78ac5b4bf231b0218a115b69251bb05..dc723f3052e4fe116c865715e640451ccb16d99e 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc @@ -163,5 +163,128 @@ const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph } return nullptr; } + +const BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const { + const auto prim_rsqrt = std::make_shared(kRsqrtOpName); + MS_EXCEPTION_IF_NULL(prim_rsqrt); + VarPtr Xs = std::make_shared(); + VarPtr Ys = std::make_shared(); + VarPtr Zs = std::make_shared(); + MS_EXCEPTION_IF_NULL(Xs); + MS_EXCEPTION_IF_NULL(Ys); + MS_EXCEPTION_IF_NULL(Zs); + VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); + VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); + VectorRef mul4 = VectorRef({mul4_var_, Zs}); + + VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1}); + VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); + VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); + VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); + return add3; +} + +const BaseRef LambNextMVWithDecayRuleCond1::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + MS_EXCEPTION_IF_NULL(prim_sqrt); + const auto prim_deal_div = std::make_shared(kRealDivOpName); + MS_EXCEPTION_IF_NULL(prim_deal_div); + VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[1], constant_mul_input_vars_[2]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, input_vars_[0], constant_mul_input_vars_[3]}); + VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); + VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); + VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); + VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); + VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]}); + VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); + VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); + VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); + VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); + return add5; +} + +const BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const { + const auto prim_rsqrt = std::make_shared(kRsqrtOpName); + MS_EXCEPTION_IF_NULL(prim_rsqrt); + VarPtr Xs = std::make_shared(); + VarPtr Ys = std::make_shared(); + VarPtr Zs = std::make_shared(); + MS_EXCEPTION_IF_NULL(Xs); + MS_EXCEPTION_IF_NULL(Ys); + MS_EXCEPTION_IF_NULL(Zs); + VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); + VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); + VectorRef mul4 = VectorRef({mul4_var_, Zs}); + + VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1}); + VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); + VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); + VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); + return add3; +} + +const BaseRef LambNextMVWithDecayRuleCond2::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + MS_EXCEPTION_IF_NULL(prim_sqrt); + const auto prim_deal_div = std::make_shared(kRealDivOpName); + MS_EXCEPTION_IF_NULL(prim_deal_div); + VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]}); + VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); + VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); + VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); + VectorRef add4 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, sqrt1}); + VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}); + VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); + VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); + VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); + VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); + return add5; +} + +const BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const { + const auto prim_rsqrt = std::make_shared(kRsqrtOpName); + MS_EXCEPTION_IF_NULL(prim_rsqrt); + VarPtr Xs = std::make_shared(); + VarPtr Ys = std::make_shared(); + VarPtr Zs = std::make_shared(); + MS_EXCEPTION_IF_NULL(Xs); + MS_EXCEPTION_IF_NULL(Ys); + MS_EXCEPTION_IF_NULL(Zs); + VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); + VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); + VectorRef mul4 = VectorRef({mul4_var_, Zs}); + + VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_}); + VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); + VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); + VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); + return add3; +} + +const BaseRef LambNextMVWithDecayRuleCond3::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + MS_EXCEPTION_IF_NULL(prim_sqrt); + const auto prim_deal_div = std::make_shared(kRealDivOpName); + MS_EXCEPTION_IF_NULL(prim_deal_div); + VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[1], constant_mul_input_vars_[2]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]}); + VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); + VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); + VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); + VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); + VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]}); + VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); + VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); + VectorRef mul4 = VectorRef({mul4_var_, input_vars_[6], constant_mul_input_vars_[4]}); + VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); + return add5; +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h index afc48bb8c643fb58a0c14580ff2f3a822974306b..4d52451a076d77a6cc3ef7e05644196dbfc3de21 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h @@ -74,6 +74,36 @@ class LambNextMVWithDecayRule : public PatternProcessPass { VarPtr add0_var_; VarPtr add1_var_; }; + +class LambNextMVWithDecayRuleCond1 : public LambNextMVWithDecayRule { + public: + explicit LambNextMVWithDecayRuleCond1(bool multigraph = true) + : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond1", multigraph) {} + + ~LambNextMVWithDecayRuleCond1() override = default; + const BaseRef DefinePattern() const override; + const BaseRef DefineAnotherPattern() const override; +}; + +class LambNextMVWithDecayRuleCond2 : public LambNextMVWithDecayRule { + public: + explicit LambNextMVWithDecayRuleCond2(bool multigraph = true) + : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond2", multigraph) {} + + ~LambNextMVWithDecayRuleCond2() override = default; + const BaseRef DefinePattern() const override; + const BaseRef DefineAnotherPattern() const override; +}; + +class LambNextMVWithDecayRuleCond3 : public LambNextMVWithDecayRule { + public: + explicit LambNextMVWithDecayRuleCond3(bool multigraph = true) + : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond3", multigraph) {} + + ~LambNextMVWithDecayRuleCond3() override = default; + const BaseRef DefinePattern() const override; + const BaseRef DefineAnotherPattern() const override; +}; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.cc index f152f1873177795f89b9787372340b01c215768f..01f6effa4ff64c99ba2695d3fe746bc64834d859 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.cc @@ -24,6 +24,8 @@ #include "pre_activate/common/helper.h" namespace mindspore { +namespace opt { +namespace { bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_t *mul_index) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(add); @@ -36,6 +38,14 @@ bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_ MS_EXCEPTION_IF_NULL(cnode); if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimMul->name()) { if (!opt::IsUsedByOthers(graph, cnode)) { + auto full_name = cnode->fullname_with_scope(); + // exclude lamb and adam, and only work in bert + if (std::string::npos != full_name.find("adam") || std::string::npos != full_name.find("lamb") || + std::string::npos == full_name.find("bert")) { + MS_LOG(INFO) << "Mul is in adam or lamb or not a bert network, quit fusion"; + return false; + } + *mul = cnode; *mul_index = index; return true; @@ -45,8 +55,7 @@ bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_ } return false; } - -namespace opt { +} // namespace const BaseRef MulAddFusion::DefinePattern() const { VarPtr x = std::make_shared(); VarPtr y = std::make_shared(); @@ -74,7 +83,12 @@ const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodeP for (size_t index = 1; index < mul->size(); ++index) { inputs.push_back(mul->input(index)); } - inputs.push_back(add->input(add->size() - mul_index)); + auto another_input_node = add->input(add->size() - mul_index); + if (IsUsedByOthers(graph, another_input_node)) { + MS_LOG(INFO) << "Add's another input node is used by others, do not fuse"; + return nullptr; + } + inputs.push_back(another_input_node); auto fusion_node = graph->NewCNode(inputs); fusion_node->set_scope(add->scope()); fusion_node->set_abstract(add->abstract()); diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc index f114c772166e2e1d6acae88385e57280efc52218..53ebbd6f2f003c668f8faaf02363f7f711e2c577 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc @@ -253,5 +253,134 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after"); EXPECT_FALSE(CheckEqualGraph(g_after, new_graph)); } + +TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond1) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "before"); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 13; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond1_un_match) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "un_match"); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 13; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + auto origin_graph = std::make_shared(*fg); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph)); + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "un_match"); + EXPECT_FALSE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond2) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "before"); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 13; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + DumpIR("fg.ir", fg, true); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "after"); + DumpIR("g_after.ir", g_after, true); + DumpIR("new_graph.ir", new_graph, true); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond2_un_match) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "un_match"); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 13; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + auto origin_graph = std::make_shared(*fg); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph)); + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "un_match"); + EXPECT_FALSE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond3) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "before"); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 13; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + DumpIR("fg.ir", fg, true); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "after"); + DumpIR("g_after.ir", g_after, true); + DumpIR("new_graph.ir", new_graph, true); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond3_un_match) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "un_match"); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 13; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + auto origin_graph = std::make_shared(*fg); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph)); + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "un_match"); + EXPECT_FALSE(CheckEqualGraph(g_after, new_graph)); +} } // namespace opt } // namespace mindspore diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_add_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_add_fusion_test.cc index 50221fca19432307ef4f905241f67b0806454fbd..87bb21f89a614397f3b75132d34e6d2a4a8f8826 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_add_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_add_fusion_test.cc @@ -37,6 +37,10 @@ TEST_F(TestHWMulAddFusion, test_mul_add_fusion1) { args_spec_list.push_back(x_abstract); } auto fg = GetKernelGraph(g, args_spec_list); + auto scope = std::make_shared("bert"); + for (auto nd : fg->execution_order()) { + nd->set_scope(scope); + } auto optimizer = std::make_shared(); auto pm = std::make_shared(); @@ -57,6 +61,10 @@ TEST_F(TestHWMulAddFusion, test_mul_add_fusion2) { args_spec_list.push_back(x_abstract); } auto fg = GetKernelGraph(g, args_spec_list); + auto scope = std::make_shared("bert"); + for (auto nd : fg->execution_order()) { + nd->set_scope(scope); + } auto optimizer = std::make_shared(); auto pm = std::make_shared(); diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_with_decay_rule_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_with_decay_rule_test.py index d9d1ab5c394fb50d276b81a9ce6875d9b35abb30..d2931cce36f9d1e775b7e4ac366a161df1f92117 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_with_decay_rule_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_with_decay_rule_test.py @@ -174,3 +174,201 @@ def test_lamb_next_mv_with_decay_rule(tag): return output return fns[tag] + +def test_lamb_next_mv_with_decay_rule_cond1(tag): + fns = FnDict() + + @fns + def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, + constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): + mul1 = Mul(input3, constant_mul1_sub) + mul0 = Mul(input4, constant_mul0_x) + add0 = Add(mul0, mul1) + mul2 = Mul(input1, constant_mul2_x) + mul3 = Mul(input0, constant_mul3_sub1) + add1 = Add(mul2, mul3) + real_div1 = RealDiv(add1, input2) + add2 = Add(constant_add2_y, real_div1) + sqrt1 = Sqrt(real_div1) + real_div0 = RealDiv(add0, input5) + add4 = Add(sqrt1, constant_add2_y) + sqrt0 = Rsqrt(add2) + mul4 = Mul(constant_mul4_x, input6) + real_div4 = RealDiv(real_div0, add4) + real_div2 = Mul(sqrt0, real_div0) + add5 = Add(mul4, real_div4) + add3 = Add(mul4, real_div2) + outputs = make_tuple(add3, add0, add1, add5) + output = tuple_getitem(outputs, 0) + return output + + @fns + def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, + constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): + lamb_next_mv_with_decay = LambNextMVWithDecay(input0, input1, input2, input3, input4, input5, input6, + constant_mul0_x, constant_mul1_sub, + constant_mul2_x, constant_mul3_sub1, constant_mul4_x, + constant_add2_y) + outputs = make_tuple(tuple_getitem(lamb_next_mv_with_decay, 0), tuple_getitem(lamb_next_mv_with_decay, 1), + tuple_getitem(lamb_next_mv_with_decay, 2), tuple_getitem(lamb_next_mv_with_decay, 3)) + output = tuple_getitem(outputs, 0) + return make_tuple(output) + + @fns + def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, + constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): + mul1 = Mul(input3, constant_mul1_sub) + mul0 = Mul(input4, constant_mul0_x) + add0 = Add(mul0, mul1) + mul2 = Mul(input1, constant_mul2_x) + mul3 = Mul(input0, constant_mul3_sub1) + add1 = Add(mul2, mul3) + real_div1 = RealDiv(add1, input2) + add2 = Add(constant_add2_y, real_div1) + sqrt1 = Sqrt(real_div1) + real_div0 = RealDiv(add0, input5) + add4 = Add(sqrt1, constant_add2_y) + sqrt0 = Rsqrt(add2) + mul4 = Mul(constant_mul4_x, input6) + real_div4 = RealDiv(real_div0, add4) + real_div2 = Mul(sqrt0, real_div0) + add5 = Add(mul4, real_div4) + # un match + add3 = Add(real_div2, mul4) + outputs = make_tuple(add3, add0, add1, add5) + output = tuple_getitem(outputs, 0) + return output + + return fns[tag] + +def test_lamb_next_mv_with_decay_rule_cond2(tag): + fns = FnDict() + + @fns + def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, + constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): + mul1 = Mul(constant_mul1_sub, input3) + mul0 = Mul(constant_mul0_x, input4) + add0 = Add(mul0, mul1) + mul2 = Mul(constant_mul2_x, input1) + mul3 = Mul(constant_mul3_sub1, input0) + add1 = Add(mul2, mul3) + real_div1 = RealDiv(add1, input2) + add2 = Add(constant_add2_y, real_div1) + sqrt1 = Sqrt(real_div1) + real_div0 = RealDiv(add0, input5) + add4 = Add(constant_add2_y, sqrt1) + sqrt0 = Rsqrt(add2) + mul4 = Mul(constant_mul4_x, input6) + real_div4 = RealDiv(real_div0, add4) + real_div2 = Mul(sqrt0, real_div0) + add5 = Add(mul4, real_div4) + add3 = Add(mul4, real_div2) + outputs = make_tuple(add3, add0, add1, add5) + output = tuple_getitem(outputs, 0) + return output + + @fns + def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, + constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): + lamb_next_mv_with_decay = LambNextMVWithDecay(input0, input1, input2, input3, input4, input5, input6, + constant_mul0_x, constant_mul1_sub, + constant_mul2_x, constant_mul3_sub1, constant_mul4_x, + constant_add2_y) + outputs = make_tuple(tuple_getitem(lamb_next_mv_with_decay, 0), tuple_getitem(lamb_next_mv_with_decay, 1), + tuple_getitem(lamb_next_mv_with_decay, 2), tuple_getitem(lamb_next_mv_with_decay, 3)) + output = tuple_getitem(outputs, 0) + return make_tuple(output) + + @fns + def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, + constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): + mul1 = Mul(constant_mul1_sub, input3) + mul0 = Mul(constant_mul0_x, input4) + add0 = Add(mul0, mul1) + mul2 = Mul(constant_mul2_x, input1) + mul3 = Mul(constant_mul3_sub1, input0) + add1 = Add(mul2, mul3) + real_div1 = RealDiv(add1, input2) + add2 = Add(constant_add2_y, real_div1) + sqrt1 = Sqrt(real_div1) + real_div0 = RealDiv(add0, input5) + add4 = Add(constant_add2_y, sqrt1) + sqrt0 = Rsqrt(add2) + mul4 = Mul(constant_mul4_x, input6) + real_div4 = RealDiv(real_div0, add4) + real_div2 = Mul(sqrt0, real_div0) + add5 = Add(mul4, real_div4) + # un_match + add3 = Add(real_div2, mul4) + outputs = make_tuple(add3, add0, add1, add5) + output = tuple_getitem(outputs, 0) + return output + + return fns[tag] + +def test_lamb_next_mv_with_decay_rule_cond3(tag): + fns = FnDict() + + @fns + def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, + constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): + mul1 = Mul(input3, constant_mul1_sub) + mul0 = Mul(input4, constant_mul0_x) + add0 = Add(mul0, mul1) + mul2 = Mul(input1, constant_mul2_x) + mul3 = Mul(constant_mul3_sub1, input0) + add1 = Add(mul2, mul3) + real_div1 = RealDiv(add1, input2) + add2 = Add(real_div1, constant_add2_y) + sqrt1 = Sqrt(real_div1) + real_div0 = RealDiv(add0, input5) + add4 = Add(sqrt1, constant_add2_y) + sqrt0 = Rsqrt(add2) + mul4 = Mul(input6, constant_mul4_x) + real_div4 = RealDiv(real_div0, add4) + real_div2 = Mul(sqrt0, real_div0) + add5 = Add(mul4, real_div4) + add3 = Add(mul4, real_div2) + outputs = make_tuple(add3, add0, add1, add5) + output = tuple_getitem(outputs, 0) + return output + + @fns + def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, + constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): + lamb_next_mv_with_decay = LambNextMVWithDecay(input0, input1, input2, input3, input4, input5, input6, + constant_mul0_x, constant_mul1_sub, + constant_mul2_x, constant_mul3_sub1, constant_mul4_x, + constant_add2_y) + outputs = make_tuple(tuple_getitem(lamb_next_mv_with_decay, 0), tuple_getitem(lamb_next_mv_with_decay, 1), + tuple_getitem(lamb_next_mv_with_decay, 2), tuple_getitem(lamb_next_mv_with_decay, 3)) + output = tuple_getitem(outputs, 0) + return make_tuple(output) + + @fns + def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, + constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): + mul1 = Mul(input3, constant_mul1_sub) + mul0 = Mul(input4, constant_mul0_x) + add0 = Add(mul0, mul1) + mul2 = Mul(input1, constant_mul2_x) + mul3 = Mul(constant_mul3_sub1, input0) + add1 = Add(mul2, mul3) + real_div1 = RealDiv(add1, input2) + add2 = Add(real_div1, constant_add2_y) + sqrt1 = Sqrt(real_div1) + real_div0 = RealDiv(add0, input5) + add4 = Add(sqrt1, constant_add2_y) + sqrt0 = Rsqrt(add2) + mul4 = Mul(input6, constant_mul4_x) + real_div4 = RealDiv(real_div0, add4) + real_div2 = Mul(sqrt0, real_div0) + add5 = Add(mul4, real_div4) + # un match + add3 = Add(real_div2, mul4) + outputs = make_tuple(add3, add0, add1, add5) + output = tuple_getitem(outputs, 0) + return output + + return fns[tag]