diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc
index 3f905fedf9f2a7edbc2f31acdf0daa752a8cd9a3..4645167191407baa13b6733ed6abeeb3e9e72da0 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc
@@ -42,17 +42,69 @@ AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_g
 const BaseRef AdamApplyOneFusion::DefinePattern() const {
   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
-  const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
+  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
   VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
   VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})});
   VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-  VectorRef true_div0 = VectorRef({prim_deal_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
+  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
   return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
 }
 
+const BaseRef AdamApplyOneCond1Fusion::DefinePattern() const {
+  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
+  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
+  VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
+  VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})});
+  VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
+  VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
+  VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
+  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
+  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
+  return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
+}
+
+const BaseRef AdamApplyOneCond2Fusion::DefinePattern() const {
+  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
+  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
+  VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
+  VectorRef mul3 = VectorRef({prim::kPrimMul, VectorRef({prim::kPrimSquare, input_vars_[0]}), mul_x_input_vars_[3]});
+  VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
+  VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
+  VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
+  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
+  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
+  return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
+}
+
+const BaseRef AdamApplyOneCond3Fusion::DefinePattern() const {
+  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
+  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
+  VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
+  VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})});
+  VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
+  VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
+  VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
+  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
+  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
+  return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
+}
+
+const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const {
+  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
+  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
+  VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
+  VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})});
+  VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
+  VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
+  VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
+  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
+  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
+  return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
+}
+
 const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                              const EquivPtr &equiv) const {
   MS_EXCEPTION_IF_NULL(func_graph);
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h
index 77f66414637652efdbc51f4434421062fabf2b99..5ee8a86cfbeaf70adf27f9ae989d288ca5b177e3 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h
@@ -18,21 +18,23 @@
 #include <vector>
 #include <memory>
+#include <string>
 #include "pre_activate/common/optimizer.h"
 #include "utils/utils.h"
 
 namespace mindspore {
 namespace opt {
-constexpr size_t kAdamApplyOneInputNum = 5;
-constexpr size_t kAdamApplyOneMulInputNum = 4;
+constexpr size_t kAdamApplyOneInputVarNum = 5;
+constexpr size_t kAdamApplyOneMulInputVarNum = 4;
 
 class AdamApplyOneFusion : public PatternProcessPass {
  public:
-  explicit AdamApplyOneFusion(bool multigraph = true) : PatternProcessPass("adam_apply_one_fusion", multigraph) {
-    for (size_t i = 0; i < kAdamApplyOneInputNum; ++i) {
+  explicit AdamApplyOneFusion(const std::string &name = "adam_apply_one_fusion", bool multigraph = true)
+      : PatternProcessPass(name, multigraph) {
+    for (size_t i = 0; i < kAdamApplyOneInputVarNum; ++i) {
       input_vars_.push_back(std::make_shared<Var>());
     }
-    for (size_t i = 0; i < kAdamApplyOneMulInputNum; ++i) {
+    for (size_t i = 0; i < kAdamApplyOneMulInputVarNum; ++i) {
       mul_x_input_vars_.push_back(std::make_shared<Var>());
     }
     add2_y_ = std::make_shared<Var>();
@@ -44,7 +46,7 @@ class AdamApplyOneFusion : public PatternProcessPass {
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 
- private:
+ protected:
   AnfNodePtr CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const;
   std::vector<VarPtr> input_vars_;
   std::vector<VarPtr> mul_x_input_vars_;
@@ -52,6 +54,42 @@ class AdamApplyOneFusion : public PatternProcessPass {
   VarPtr add0_var_;
   VarPtr add1_var_;
 };
+
+class AdamApplyOneCond1Fusion : public AdamApplyOneFusion {
+ public:
+  explicit AdamApplyOneCond1Fusion(bool multigraph = true)
+      : AdamApplyOneFusion("adam_apply_one_cond1_fusion", multigraph) {}
+
+  ~AdamApplyOneCond1Fusion() override = default;
+  const BaseRef DefinePattern() const override;
+};
+
+class AdamApplyOneCond2Fusion : public AdamApplyOneFusion {
+ public:
+  explicit AdamApplyOneCond2Fusion(bool multigraph = true)
+      : AdamApplyOneFusion("adam_apply_one_cond2_fusion", multigraph) {}
+
+  ~AdamApplyOneCond2Fusion() override = default;
+  const BaseRef DefinePattern() const override;
+};
+
+class AdamApplyOneCond3Fusion : public AdamApplyOneFusion {
+ public:
+  explicit AdamApplyOneCond3Fusion(bool multigraph = true)
+      : AdamApplyOneFusion("adam_apply_one_cond3_fusion", multigraph) {}
+
+  ~AdamApplyOneCond3Fusion() override = default;
+  const BaseRef DefinePattern() const override;
+};
+
+class AdamApplyOneCond4Fusion : public AdamApplyOneFusion {
+ public:
+  explicit AdamApplyOneCond4Fusion(bool multigraph = true)
+      : AdamApplyOneFusion("adam_apply_one_cond4_fusion", multigraph) {}
+
+  ~AdamApplyOneCond4Fusion() override = default;
+  const BaseRef DefinePattern() const override;
+};
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc
index f4e418bed112b6f2a58beb7f1b9537aa5ea0688a..c2ee7b6519bd880750170118a73cb7ced7ef0a19 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc
@@ -66,5 +66,156 @@ TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_fusion) {
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }
 
+TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond1_fusion) {
+  /*
+   * def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
+   *     square0 = Square(input0)
+   *     mul1 = Mul(mul1_x, input0)
+   *     mul0 = Mul(mul0_x, input2)
+   *     mul2 = Mul(mul2_x, input1)
+   *     mul3 = Mul(mul3_x, square0)
+   *     add0 = Add(mul0, mul1)
+   *     add1 = Add(mul2, mul3)
+   *     sqrt0 = Sqrt(add1)
+   *     add2 = Add(add2_y, sqrt0)
+   *     true_div0 = RealDiv(add0, add2)
+   *     mul4 = Mul(input4, true_div0)
+   *     sub0 = Sub(input3, mul4)
+   *     outputs = make_tuple(add1, add0, sub0)
+   *     output = tuple_getitem(outputs, 0)
+   *     return output
+   */
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond1");
+  std::vector<int> shp{2, 32, 224, 224};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list;
+  for (size_t i = 0; i < 10; ++i) {
+    args_spec_list.push_back(x_abstract);
+  }
+  auto fg = GetKernelGraph(g, args_spec_list);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::AdamApplyOneCond1Fusion>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(fg);
+
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
+
+TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond2_fusion) {
+  /*
+   * def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
+   *     square0 = Square(input0)
+   *     mul1 = Mul(mul1_x, input0)
+   *     mul0 = Mul(mul0_x, input2)
+   *     mul2 = Mul(mul2_x, input1)
+   *     mul3 = Mul(square0, mul3_x)
+   *     add0 = Add(mul0, mul1)
+   *     add1 = Add(mul2, mul3)
+   *     sqrt0 = Sqrt(add1)
+   *     add2 = Add(sqrt0, add2_y)
+   *     true_div0 = RealDiv(add0, add2)
+   *     mul4 = Mul(true_div0, input4)
+   *     sub0 = Sub(input3, mul4)
+   *     outputs = make_tuple(add1, add0, sub0)
+   *     output = tuple_getitem(outputs, 0)
+   *     return output
+   */
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond2");
+  std::vector<int> shp{2, 32, 224, 224};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list;
+  for (size_t i = 0; i < 10; ++i) {
+    args_spec_list.push_back(x_abstract);
+  }
+  auto fg = GetKernelGraph(g, args_spec_list);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::AdamApplyOneCond2Fusion>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(fg);
+
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
+
+TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond3_fusion) {
+  /*
+   * def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
+   *     square0 = Square(input0)
+   *     mul1 = Mul(mul1_x, input0)
+   *     mul0 = Mul(mul0_x, input2)
+   *     mul2 = Mul(mul2_x, input1)
+   *     mul3 = Mul(mul3_x, square0)
+   *     add0 = Add(mul0, mul1)
+   *     add1 = Add(mul2, mul3)
+   *     sqrt0 = Sqrt(add1)
+   *     add2 = Add(sqrt0, add2_y)
+   *     true_div0 = RealDiv(add0, add2)
+   *     mul4 = Mul(true_div0, input4)
+   *     sub0 = Sub(input3, mul4)
+   *     outputs = make_tuple(add1, add0, sub0)
+   *     output = tuple_getitem(outputs, 0)
+   *     return output
+   */
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond3");
+  std::vector<int> shp{2, 32, 224, 224};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list;
+  for (size_t i = 0; i < 10; ++i) {
+    args_spec_list.push_back(x_abstract);
+  }
+  auto fg = GetKernelGraph(g, args_spec_list);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::AdamApplyOneCond3Fusion>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(fg);
+
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
+
+TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond4_fusion) {
+  /*
+   * def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
+   *     square0 = Square(input0)
+   *     mul1 = Mul(mul1_x, input0)
+   *     mul0 = Mul(mul0_x, input2)
+   *     mul2 = Mul(mul2_x, input1)
+   *     mul3 = Mul(mul3_x, square0)
+   *     add0 = Add(mul0, mul1)
+   *     add1 = Add(mul2, mul3)
+   *     sqrt0 = Sqrt(add1)
+   *     add2 = Add(add2_y, sqrt0)
+   *     true_div0 = RealDiv(add0, add2)
+   *     mul4 = Mul(true_div0, input4)
+   *     sub0 = Sub(input3, mul4)
+   *     outputs = make_tuple(add1, add0, sub0)
+   *     output = tuple_getitem(outputs, 0)
+   *     return output
+   */
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond4");
+  std::vector<int> shp{2, 32, 224, 224};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list;
+  for (size_t i = 0; i < 10; ++i) {
+    args_spec_list.push_back(x_abstract);
+  }
+  auto fg = GetKernelGraph(g, args_spec_list);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::AdamApplyOneCond4Fusion>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(fg);
+
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_fusion_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_fusion_test.py
index b55764b18d9ae105bcb0861670e908f736822f48..225964ee387f45e82755d3b393b09cd727a62f5b 100644
--- a/tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_fusion_test.py
+++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_fusion_test.py
@@ -58,6 +58,78 @@ def test_adam_apply_one_fusion(tag):
         output = tuple_getitem(outputs, 0)
         return output
 
+    @fns
+    def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
+        square0 = Square(input0)
+        mul1 = Mul(mul1_x, input0)
+        mul0 = Mul(mul0_x, input2)
+        mul2 = Mul(mul2_x, input1)
+        mul3 = Mul(mul3_x, square0)
+        add0 = Add(mul0, mul1)
+        add1 = Add(mul2, mul3)
+        sqrt0 = Sqrt(add1)
+        add2 = Add(add2_y, sqrt0)
+        true_div0 = RealDiv(add0, add2)
+        mul4 = Mul(input4, true_div0)
+        sub0 = Sub(input3, mul4)
+        outputs = make_tuple(add1, add0, sub0)
+        output = tuple_getitem(outputs, 0)
+        return output
+
+    @fns
+    def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
+        square0 = Square(input0)
+        mul1 = Mul(mul1_x, input0)
+        mul0 = Mul(mul0_x, input2)
+        mul2 = Mul(mul2_x, input1)
+        mul3 = Mul(square0, mul3_x)
+        add0 = Add(mul0, mul1)
+        add1 = Add(mul2, mul3)
+        sqrt0 = Sqrt(add1)
+        add2 = Add(sqrt0, add2_y)
+        true_div0 = RealDiv(add0, add2)
+        mul4 = Mul(true_div0, input4)
+        sub0 = Sub(input3, mul4)
+        outputs = make_tuple(add1, add0, sub0)
+        output = tuple_getitem(outputs, 0)
+        return output
+
+    @fns
+    def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
+        square0 = Square(input0)
+        mul1 = Mul(mul1_x, input0)
+        mul0 = Mul(mul0_x, input2)
+        mul2 = Mul(mul2_x, input1)
+        mul3 = Mul(mul3_x, square0)
+        add0 = Add(mul0, mul1)
+        add1 = Add(mul2, mul3)
+        sqrt0 = Sqrt(add1)
+        add2 = Add(sqrt0, add2_y)
+        true_div0 = RealDiv(add0, add2)
+        mul4 = Mul(true_div0, input4)
+        sub0 = Sub(input3, mul4)
+        outputs = make_tuple(add1, add0, sub0)
+        output = tuple_getitem(outputs, 0)
+        return output
+
+    @fns
+    def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
+        square0 = Square(input0)
+        mul1 = Mul(mul1_x, input0)
+        mul0 = Mul(mul0_x, input2)
+        mul2 = Mul(mul2_x, input1)
+        mul3 = Mul(mul3_x, square0)
+        add0 = Add(mul0, mul1)
+        add1 = Add(mul2, mul3)
+        sqrt0 = Sqrt(add1)
+        add2 = Add(add2_y, sqrt0)
+        true_div0 = RealDiv(add0, add2)
+        mul4 = Mul(true_div0, input4)
+        sub0 = Sub(input3, mul4)
+        outputs = make_tuple(add1, add0, sub0)
+        output = tuple_getitem(outputs, 0)
+        return output
+
     @fns
     def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
         adam_apply_one = AdamApplyOne(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y)
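
Note (illustration only, not part of the patch): the four before_cond* graphs above all express the same computation and differ only in the operand order of the commutative Mul/Add nodes, which is why each order needs its own pattern class. A minimal Python sketch of what the matched subgraph (and hence the fused AdamApplyOne node) evaluates, using the variable names from the test graphs, is shown below; reading input0/input1/input2 as gradient, second moment and first moment is the usual Adam interpretation, but that naming is an assumption here.

def adam_apply_one_reference(input0, input1, input2, input3, input4,
                             mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
    import numpy as np
    # add0: updated first-moment estimate, add1: updated second-moment estimate
    add0 = mul0_x * input2 + mul1_x * input0
    add1 = mul2_x * input1 + mul3_x * np.square(input0)
    # sub0: updated parameter; add2_y is the additive constant in the
    # denominator (the epsilon term in the usual Adam reading)
    sub0 = input3 - input4 * (add0 / (np.sqrt(add1) + add2_y))
    # mirrors make_tuple(add1, add0, sub0) in the test graphs
    return add1, add0, sub0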