diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index 3ea8a300e37589b9800ad1d0ff140900258360bf..dcca95fbc02a2f6d75acccdd3e6b4b4a211931af 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -125,6 +125,10 @@ void AddAscendIRFusionRulesPass(PassManager *ir_fusion_pm) { ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.cc index 189ac94546a4cefe39da836a3e3dfdbd8a4f36a5..9164484b30c6557d665feaebc6590181b2084187 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.cc @@ -15,30 +15,9 @@ */ #include "backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h" #include "backend/optimizer/common/helper.h" +#include "backend/session/anf_runtime_algorithm.h" namespace mindspore { namespace opt { -AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(equiv); - auto prim = std::make_shared(kAdamApplyOneOpName); - std::vector new_node_inputs = {NewValueNode(prim)}; - for (const auto &input_var : input_vars_) { - auto input_node = utils::cast((*equiv)[input_var]); - MS_EXCEPTION_IF_NULL(input_node); - new_node_inputs.push_back(input_node); - } - for (const auto &mul_x_input_var : mul_x_input_vars_) { - auto mul_x_input_node = utils::cast((*equiv)[mul_x_input_var]); - MS_EXCEPTION_IF_NULL(mul_x_input_node); - new_node_inputs.push_back(mul_x_input_node); - } - auto add2_y_node = utils::cast((*equiv)[add2_y_]); - MS_EXCEPTION_IF_NULL(add2_y_node); - new_node_inputs.push_back(add2_y_node); - auto new_node = func_graph->NewCNode(new_node_inputs); - return new_node; -} - const BaseRef AdamApplyOneFusion::DefinePattern() const { const auto prim_sqrt = std::make_shared(kSqrtOpName); const auto prim_real_div = std::make_shared(kRealDivOpName); @@ -104,16 +83,152 @@ const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const { return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); } +const BaseRef AdamApplyOneAssignFusion::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + const auto prim_real_div = std::make_shared(kRealDivOpName); + VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); + VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); + VectorRef sqrt0 = VectorRef({prim_sqrt, add1}); + VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); + VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); + VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); + VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); + VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); + VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0}); + VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); + VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1}); + return VectorRef({prim::kPrimDepend, depend1, assign2}); +} + +const BaseRef AdamApplyOneAssignCond1Fusion::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + const auto prim_real_div = std::make_shared(kRealDivOpName); + VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); + VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); + VectorRef sqrt0 = VectorRef({prim_sqrt, add1}); + VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); + VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})}); + VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); + VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); + VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); + VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0}); + VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); + VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1}); + return VectorRef({prim::kPrimDepend, depend1, assign2}); +} + +const BaseRef AdamApplyOneAssignCond2Fusion::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + const auto prim_real_div = std::make_shared(kRealDivOpName); + VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, VectorRef({prim::kPrimSquare, input_vars_[0]}), mul_x_input_vars_[3]}); + VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); + VectorRef sqrt0 = VectorRef({prim_sqrt, add1}); + VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); + VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); + VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); + VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); + VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); + VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0}); + VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); + VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1}); + return VectorRef({prim::kPrimDepend, depend1, assign2}); +} + +const BaseRef AdamApplyOneAssignCond3Fusion::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + const auto prim_real_div = std::make_shared(kRealDivOpName); + VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); + VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); + VectorRef sqrt0 = VectorRef({prim_sqrt, add1}); + VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); + VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); + VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); + VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); + VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); + VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0}); + VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); + VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1}); + return VectorRef({prim::kPrimDepend, depend1, assign2}); +} + +const BaseRef AdamApplyOneAssignCond4Fusion::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + const auto prim_real_div = std::make_shared(kRealDivOpName); + VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); + VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); + VectorRef sqrt0 = VectorRef({prim_sqrt, add1}); + VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); + VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})}); + VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); + VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); + VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); + VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0}); + VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); + VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1}); + return VectorRef({prim::kPrimDepend, depend1, assign2}); +} + +AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, + const AnfNodePtr &final_node) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(equiv); + PrimitivePtr prim = nullptr; + if (AnfAlgo::CheckPrimitiveType(final_node, prim::kPrimDepend)) { + prim = std::make_shared(kAdamApplyOneAssignOpName); + } else { + prim = std::make_shared(kAdamApplyOneOpName); + } + std::vector new_node_inputs = {NewValueNode(prim)}; + for (const auto &input_var : input_vars_) { + auto input_node = utils::cast((*equiv)[input_var]); + MS_EXCEPTION_IF_NULL(input_node); + new_node_inputs.push_back(input_node); + } + for (const auto &mul_x_input_var : mul_x_input_vars_) { + auto mul_x_input_node = utils::cast((*equiv)[mul_x_input_var]); + MS_EXCEPTION_IF_NULL(mul_x_input_node); + new_node_inputs.push_back(mul_x_input_node); + } + auto add2_y_node = utils::cast((*equiv)[add2_y_]); + MS_EXCEPTION_IF_NULL(add2_y_node); + new_node_inputs.push_back(add2_y_node); + auto new_node = func_graph->NewCNode(new_node_inputs); + return new_node; +} + const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + auto sub0 = node; + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) { + auto iter_sub0 = (*equiv).find(sub0_var_); + if (iter_sub0 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the sub0 var after matched."; + } + sub0 = utils::cast(iter_sub0->second); + } + MS_EXCEPTION_IF_NULL(sub0); + if (!CheckSupportDataType(sub0, kFloatDataTypeSet)) { return nullptr; } - auto new_node = CreateAdamApplyOneNode(func_graph, equiv); + auto new_node = CreateAdamApplyOneNode(func_graph, equiv, node); MS_EXCEPTION_IF_NULL(new_node); - new_node->set_scope(node->scope()); + new_node->set_scope(sub0->scope()); // Set abstract of new node AbstractBasePtrList new_node_abstract_list; auto iter_add0 = (*equiv).find(add0_var_); @@ -130,7 +245,7 @@ const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, con MS_EXCEPTION_IF_NULL(add1); new_node_abstract_list.push_back(add1->abstract()); new_node_abstract_list.push_back(add0->abstract()); - new_node_abstract_list.push_back(node->abstract()); + new_node_abstract_list.push_back(sub0->abstract()); auto abstract_tuple = std::make_shared(new_node_abstract_list); new_node->set_abstract(abstract_tuple); // Create tuple_getitem node for outputs diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h index 2fe2d32fa09f646b366b05150721e7da1cdbd764..4c5649f978853053c9a9cb6042bd13a5df893d67 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h @@ -40,6 +40,7 @@ class AdamApplyOneFusion : public PatternProcessPass { add2_y_ = std::make_shared(); add0_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); add1_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); + sub0_var_ = std::make_shared(std::make_shared(prim::kPrimSub->name())); } ~AdamApplyOneFusion() override = default; @@ -47,12 +48,14 @@ class AdamApplyOneFusion : public PatternProcessPass { const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; protected: - AnfNodePtr CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const; + AnfNodePtr CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, + const AnfNodePtr &final_node) const; std::vector input_vars_; std::vector mul_x_input_vars_; VarPtr add2_y_; VarPtr add0_var_; VarPtr add1_var_; + VarPtr sub0_var_; }; class AdamApplyOneCond1Fusion : public AdamApplyOneFusion { @@ -90,6 +93,51 @@ class AdamApplyOneCond4Fusion : public AdamApplyOneFusion { ~AdamApplyOneCond4Fusion() override = default; const BaseRef DefinePattern() const override; }; + +class AdamApplyOneAssignFusion : public AdamApplyOneFusion { + public: + explicit AdamApplyOneAssignFusion(bool multigraph = true) + : AdamApplyOneFusion("adam_apply_one_assign_fusion", multigraph) {} + + ~AdamApplyOneAssignFusion() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneAssignCond1Fusion : public AdamApplyOneFusion { + public: + explicit AdamApplyOneAssignCond1Fusion(bool multigraph = true) + : AdamApplyOneFusion("adam_apply_one_assign_cond1_fusion", multigraph) {} + + ~AdamApplyOneAssignCond1Fusion() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneAssignCond2Fusion : public AdamApplyOneFusion { + public: + explicit AdamApplyOneAssignCond2Fusion(bool multigraph = true) + : AdamApplyOneFusion("adam_apply_one_assign_cond2_fusion", multigraph) {} + + ~AdamApplyOneAssignCond2Fusion() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneAssignCond3Fusion : public AdamApplyOneFusion { + public: + explicit AdamApplyOneAssignCond3Fusion(bool multigraph = true) + : AdamApplyOneFusion("adam_apply_one_assign_cond3_fusion", multigraph) {} + + ~AdamApplyOneAssignCond3Fusion() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneAssignCond4Fusion : public AdamApplyOneFusion { + public: + explicit AdamApplyOneAssignCond4Fusion(bool multigraph = true) + : AdamApplyOneFusion("adam_apply_one_assign_cond4_fusion", multigraph) {} + + ~AdamApplyOneAssignCond4Fusion() override = default; + const BaseRef DefinePattern() const override; +}; } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 6be32a3df59d133a35808bcba086d25e5fe2f358..c1f551258b6d2d81b5b02563be95849c34255c44 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -119,6 +119,7 @@ constexpr auto kAdamApplyOneWithDecayOpName = "AdamApplyOneWithDecay"; constexpr auto kBatchNormGradOpName = "BatchNormGrad"; constexpr auto kBNInferOpName = "BNInfer"; constexpr auto kAdamApplyOneOpName = "AdamApplyOne"; +constexpr auto kAdamApplyOneAssignOpName = "AdamApplyOneAssign"; constexpr auto kResizeNearestNeighborGradOpName = "ResizeNearestNeighborGrad"; constexpr auto kFusedMulAddOpName = "FusedMulAdd"; constexpr auto kFusedMulAddNOpName = "FusedMulAddN"; diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc index 2759864037970bbbdbd316530fb2e87bf9c39783..add4ef8e41ad8f049421f1c9107fec4ee0f7dfac 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc @@ -217,5 +217,105 @@ TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond4_fusion) { FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after"); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); } + +TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_assign_fusion) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "before"); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 10; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_assign_cond1_fusion) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "before_cond1"); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 10; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_assign_cond2_fusion) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "before_cond2"); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 10; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_assign_cond3_fusion) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "before_cond3"); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 10; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_assign_cond4_fusion) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "before_cond4"); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 10; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} } // namespace opt } // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_fusion_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_fusion_test.py index 27cfe38b8f00f0051716ed5f3fe14ad7ae99fb9c..654b922c251060308248f74f2abdfa0c9a77d7ff 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_fusion_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_fusion_test.py @@ -14,6 +14,7 @@ # ============================================================================ from mindspore.ops import Primitive from mindspore.ops import operations as P +from mindspore.ops import functional as F Add = P.TensorAdd() Sub = P.Sub() @@ -21,9 +22,11 @@ Mul = P.Mul() RealDiv = P.RealDiv() Sqrt = P.Sqrt() Square = P.Square() +Assign = P.Assign() make_tuple = Primitive('make_tuple') tuple_getitem = Primitive('tuple_getitem') AdamApplyOne = Primitive('AdamApplyOne') +AdamApplyOneAssign = Primitive('AdamApplyOneAssign') class FnDict: @@ -139,3 +142,138 @@ def test_adam_apply_one_fusion(tag): return make_tuple(output) return fns[tag] + + +def test_adam_apply_one_assign_fusion(tag): + fns = FnDict() + + @fns + def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): + square0 = Square(input0) + mul1 = Mul(mul1_x, input0) + mul0 = Mul(mul0_x, input2) + mul2 = Mul(mul2_x, input1) + mul3 = Mul(mul3_x, square0) + add0 = Add(mul0, mul1) + add1 = Add(mul2, mul3) + sqrt0 = Sqrt(add1) + add2 = Add(sqrt0, add2_y) + true_div0 = RealDiv(add0, add2) + mul4 = Mul(input4, true_div0) + sub0 = Sub(input3, mul4) + assign0 = Assign(input3, sub0) + depend0 = F.depend(sub0, assign0) + assign1 = Assign(input2, add0) + depend1 = F.depend(depend0, assign1) + assign2 = Assign(input1, add1) + depend2 = F.depend(depend1, assign2) + outputs = make_tuple(add1, add0, depend2) + output = tuple_getitem(outputs, 0) + return output + + @fns + def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): + square0 = Square(input0) + mul1 = Mul(mul1_x, input0) + mul0 = Mul(mul0_x, input2) + mul2 = Mul(mul2_x, input1) + mul3 = Mul(mul3_x, square0) + add0 = Add(mul0, mul1) + add1 = Add(mul2, mul3) + sqrt0 = Sqrt(add1) + add2 = Add(add2_y, sqrt0) + true_div0 = RealDiv(add0, add2) + mul4 = Mul(input4, true_div0) + sub0 = Sub(input3, mul4) + assign0 = Assign(input3, sub0) + depend0 = F.depend(sub0, assign0) + assign1 = Assign(input2, add0) + depend1 = F.depend(depend0, assign1) + assign2 = Assign(input1, add1) + depend2 = F.depend(depend1, assign2) + outputs = make_tuple(add1, add0, depend2) + output = tuple_getitem(outputs, 0) + return output + + @fns + def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): + square0 = Square(input0) + mul1 = Mul(mul1_x, input0) + mul0 = Mul(mul0_x, input2) + mul2 = Mul(mul2_x, input1) + mul3 = Mul(square0, mul3_x) + add0 = Add(mul0, mul1) + add1 = Add(mul2, mul3) + sqrt0 = Sqrt(add1) + add2 = Add(sqrt0, add2_y) + true_div0 = RealDiv(add0, add2) + mul4 = Mul(true_div0, input4) + sub0 = Sub(input3, mul4) + assign0 = Assign(input3, sub0) + depend0 = F.depend(sub0, assign0) + assign1 = Assign(input2, add0) + depend1 = F.depend(depend0, assign1) + assign2 = Assign(input1, add1) + depend2 = F.depend(depend1, assign2) + outputs = make_tuple(add1, add0, depend2) + output = tuple_getitem(outputs, 0) + return output + + @fns + def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): + square0 = Square(input0) + mul1 = Mul(mul1_x, input0) + mul0 = Mul(mul0_x, input2) + mul2 = Mul(mul2_x, input1) + mul3 = Mul(mul3_x, square0) + add0 = Add(mul0, mul1) + add1 = Add(mul2, mul3) + sqrt0 = Sqrt(add1) + add2 = Add(sqrt0, add2_y) + true_div0 = RealDiv(add0, add2) + mul4 = Mul(true_div0, input4) + sub0 = Sub(input3, mul4) + assign0 = Assign(input3, sub0) + depend0 = F.depend(sub0, assign0) + assign1 = Assign(input2, add0) + depend1 = F.depend(depend0, assign1) + assign2 = Assign(input1, add1) + depend2 = F.depend(depend1, assign2) + outputs = make_tuple(add1, add0, depend2) + output = tuple_getitem(outputs, 0) + return output + + @fns + def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): + square0 = Square(input0) + mul1 = Mul(mul1_x, input0) + mul0 = Mul(mul0_x, input2) + mul2 = Mul(mul2_x, input1) + mul3 = Mul(mul3_x, square0) + add0 = Add(mul0, mul1) + add1 = Add(mul2, mul3) + sqrt0 = Sqrt(add1) + add2 = Add(add2_y, sqrt0) + true_div0 = RealDiv(add0, add2) + mul4 = Mul(true_div0, input4) + sub0 = Sub(input3, mul4) + assign0 = Assign(input3, sub0) + depend0 = F.depend(sub0, assign0) + assign1 = Assign(input2, add0) + depend1 = F.depend(depend0, assign1) + assign2 = Assign(input1, add1) + depend2 = F.depend(depend1, assign2) + outputs = make_tuple(add1, add0, depend2) + output = tuple_getitem(outputs, 0) + return output + + @fns + def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): + adam_apply_one_assign = AdamApplyOneAssign(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, + mul3_x, add2_y) + outputs = make_tuple(tuple_getitem(adam_apply_one_assign, 0), tuple_getitem(adam_apply_one_assign, 1), + tuple_getitem(adam_apply_one_assign, 2)) + output = tuple_getitem(outputs, 0) + return make_tuple(output) + + return fns[tag]