diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
index 16a82500b43f94acbd71f73567b7d492e693ae6f..bced5e6780682517bccd6365758abdf2fb846b20 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
@@ -201,6 +201,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_
     ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
     ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
+    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion>());
   }
   ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>());
   if (context_ptr->ir_fusion_flag()) {
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc
index 7641772d7a2588eb50ccf33b4faef6d0c42337df..de8a7d9d55b0b03f8162d08ffcdf762b1a182802 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc
@@ -277,5 +277,28 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c
   }
   return bn_training_update_outputs[0];
 }
+
+const BaseRef FusedBatchNormMixPrecisionFusion::DefinePattern() const {
+  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
+  VarPtr index0 = std::make_shared<CondVar>(IsC);
+  VarPtr index1 = std::make_shared<CondVar>(IsC);
+  VarPtr index2 = std::make_shared<CondVar>(IsC);
+  VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs});
+  VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0});
+  VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1});
+  VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2});
+  VectorRef cast_variable_input0 = VectorRef({prim::kPrimCast, variable_input0_var_});
+  VectorRef cast_variable_input1 = VectorRef({prim::kPrimCast, variable_input1_var_});
+  VectorRef sub0 = VectorRef({prim::kPrimSub, cast_variable_input0, tuple_getitem1});
+  VectorRef sub1 = VectorRef({prim::kPrimSub, cast_variable_input1, tuple_getitem2});
+  VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_});
+  VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_});
+  VectorRef cast2 = VectorRef({prim::kPrimCast, mul0});
+  VectorRef cast3 = VectorRef({prim::kPrimCast, mul1});
+  VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, cast2});
+  VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, cast3});
+  VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
+  return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h
index e6bf1dda554d7dd29065f8fba60841d4f2b3f2b8..e4b31ca5f4f74b517d01d5996b06b650fa4af1de 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h
@@ -18,6 +18,7 @@
 
 #include <vector>
 #include <memory>
+#include <string>
 #include "pre_activate/common/optimizer.h"
 #include "utils/utils.h"
@@ -25,8 +26,8 @@ namespace mindspore {
 namespace opt {
 class FusedBatchNormFusion : public PatternProcessPass {
  public:
-  explicit FusedBatchNormFusion(bool multigraph = true)
PatternProcessPass("fused_batch_norm_fusion", multigraph), + explicit FusedBatchNormFusion(const std::string &name = "fused_batch_norm_fusion", bool multigraph = true) + : PatternProcessPass(name, multigraph), data_input0_var_(std::make_shared()), data_input1_var_(std::make_shared()), data_input2_var_(std::make_shared()), @@ -39,7 +40,7 @@ class FusedBatchNormFusion : public PatternProcessPass { const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - private: + protected: AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv) const; void GetBNTrainingUpdateInputs(const EquivPtr &equiv, const std::vector &bn_training_reduce_outputs, @@ -59,6 +60,15 @@ class FusedBatchNormFusion : public PatternProcessPass { VarPtr constant_input1_var_; VarPtr batch_norm_var_; }; + +class FusedBatchNormMixPrecisionFusion : public FusedBatchNormFusion { + public: + explicit FusedBatchNormMixPrecisionFusion(bool multigraph = true) + : FusedBatchNormFusion("fused_batch_norm_mix_precision_fusion", multigraph) {} + + ~FusedBatchNormMixPrecisionFusion() override = default; + const BaseRef DefinePattern() const override; +}; } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_ diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc index 3d13f4a336e68ff64b41345f0a5d38243d776ee6..f023446698ceff7074e1947f338b9df513d84cb1 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc @@ -50,5 +50,28 @@ TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_fusion) { FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "after"); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); } + +TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_mix_precision_fusion) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "before_mix_precision"); + EXPECT_NE(g, nullptr); + std::vector shp_x{32, 64, 112, 112}; + auto x_abstract = std::make_shared(kFloat32, shp_x); + std::vector shp_y{64}; + auto y_abstract = std::make_shared(kFloat32, shp_y); + AbstractBasePtrList args_spec_list{x_abstract}; + for (size_t i = 0; i < 6; ++i) { + args_spec_list.push_back(y_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} } // namespace opt } // namespace mindspore \ No newline at end of file diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/fused_batch_norm_fusion_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/fused_batch_norm_fusion_test.py index ca93d40443644b9c7fe878cc5d3aa188280d3034..f510956b21edec1aa037da4ef4470b6915b6ea6e 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/fused_batch_norm_fusion_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/fused_batch_norm_fusion_test.py @@ -24,6 +24,7 @@ make_tuple = Primitive('make_tuple') tuple_getitem = 
 tuple_getitem = Primitive('tuple_getitem')
 depend = Primitive('depend')
 BatchNorm = P.BatchNorm()
+Cast = P.Cast()
 BNTrainingReduce = Primitive('BNTrainingReduce')
 BNTrainingUpdate = Primitive('BNTrainingUpdate')
 constant0 = Tensor(0.1, mstype.float32)
@@ -59,6 +60,21 @@ def test_fused_batch_norm_fusion(tag):
         output = tuple_getitem(outputs, 0)
         return output
 
+    @fns
+    def before_mix_precision(input0, input1, input2, input3, input4, var0, var1):
+        batch_norm = BatchNorm(input0, input1, input2, input3, input4)
+        sub0 = Sub(Cast(var0, mstype.float32), tuple_getitem(batch_norm, 1))
+        sub1 = Sub(Cast(var1, mstype.float32), tuple_getitem(batch_norm, 2))
+        mul0 = Mul(sub0, constant0)
+        mul1 = Mul(sub1, constant1)
+        assign_sub0 = AssignSub(var0, Cast(mul0, mstype.float32))
+        assign_sub1 = AssignSub(var1, Cast(mul1, mstype.float32))
+        depend0 = depend(tuple_getitem(batch_norm, 0), assign_sub0)
+        depend1 = depend(depend0, assign_sub1)
+        outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4))
+        output = tuple_getitem(outputs, 0)
+        return output
+
     @fns
     def after(input0, input1, input2, input3, input4, var0, var1):
         bn_training_reduce = BNTrainingReduce(input0)