diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index ecbdde0a096147f78f6fa43782b6bd25adbe0aeb..f01dd95f060be041f9bfc2ea46c376f7b8b12822 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -100,6 +100,7 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc index 78e6856d5a5e595336ca26892a5940b13c153bbc..d49b2d47f36b289ba44e08b7059760f49362793c 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc @@ -74,10 +74,21 @@ AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const An } bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &mul1_anf, - const AnfNodePtr &reduce_sum) { + const AnfNodePtr &reduce_sum, const AnfNodePtr &input2) { MS_EXCEPTION_IF_NULL(mul0_anf); MS_EXCEPTION_IF_NULL(mul1_anf); MS_EXCEPTION_IF_NULL(reduce_sum); + MS_EXCEPTION_IF_NULL(input2); + auto addn = input2->cast(); + if (addn == nullptr || AnfAlgo::GetCNodeName(addn) != prim::kPrimAddN->name()) { + MS_LOG(INFO) << "mul's second input is not addn"; + return true; + } + std::vector shape = AnfAlgo::GetOutputInferShape(addn, 0); + if (shape.size() != 2 || !(shape[1] == 1024 || shape[1] == 768)) { + MS_LOG(INFO) << "Addn's infer shape is not equal [x,1024] or [x,768]"; + return true; + } if (!mul0_anf->isa() || !mul1_anf->isa()) { return true; } @@ -86,11 +97,6 @@ bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const Anf auto mul0 = mul0_anf->cast(); MS_EXCEPTION_IF_NULL(mul0); - // when network is _VirtualDatasetCell, quit fusion - if (mul0->fullname_with_scope().find("network-_VirtualDatasetCell") != std::string::npos) { - return true; - } - if (IsDepend(graph, mul0->input(1), reduce_sum)) { MS_LOG(INFO) << "mul0->input(1) depends on reduce_sum, quit fusion"; return true; @@ -128,7 +134,7 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons MS_LOG(INFO) << "Mul0 do not exist, quit fusion"; return nullptr; } - if (QuitFusion(graph, mul0, mul1, node)) { + if (QuitFusion(graph, mul0, mul1, node, input2)) { return nullptr; } diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.cc index 01f6effa4ff64c99ba2695d3fe746bc64834d859..2536255fc1f316d46c5d6c386bc489683fd68e5b 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.cc @@ -84,8 +84,9 @@ const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodeP inputs.push_back(mul->input(index)); } auto another_input_node = add->input(add->size() - mul_index); - if (IsUsedByOthers(graph, another_input_node)) { - MS_LOG(INFO) << "Add's another input node is used by others, do not fuse"; + if (another_input_node->isa() && + AnfAlgo::GetCNodeName(another_input_node) == prim::kPrimTupleGetItem->name()) { + MS_LOG(INFO) << "Add's another input node has multiple outputs, do not fuse"; return nullptr; } inputs.push_back(another_input_node); diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc index e3bf09d2cbd594face80e492367c386ae1199b1a..20448578418fa05aea06ab28e85b31ae766bde1e 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc @@ -32,7 +32,7 @@ class TestHWOptimizeConfusionMulGradFusion : public BackendCommon { TEST_F(TestHWOptimizeConfusionMulGradFusion, test_fusion) { FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_confusion_mul_grad_fusion", "before"); EXPECT_NE(g, nullptr); - std::vector shp{1, 1, 1, 1}; + std::vector shp{10, 1024}; auto x_abstract = std::make_shared(kFloat32, shp); AbstractBasePtrList args_spec_list; for (size_t i = 0; i < 3; ++i) { @@ -49,6 +49,5 @@ TEST_F(TestHWOptimizeConfusionMulGradFusion, test_fusion) { FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_confusion_mul_grad_fusion", "after"); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); } - } // namespace opt } // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/confusion_mul_grad_fusion.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/confusion_mul_grad_fusion.py index 2c7f555c5dfb8775c1dad45f77b05a727403b322..2f834abe7773b8195673f8fdd1b618b3f9e94709 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/confusion_mul_grad_fusion.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/confusion_mul_grad_fusion.py @@ -15,12 +15,13 @@ from mindspore.ops import Primitive from mindspore.ops import operations as P +addn = P.AddN() mul = P.Mul() reduce_sum = P.ReduceSum() confusion_mul_grad = Primitive('ConfusionMulGrad') make_tuple = Primitive('make_tuple') tuple_getitem = Primitive('tuple_getitem') -axis = 2 +axis = 1 class FnDict: @@ -39,8 +40,10 @@ def test_confusion_mul_grad_fusion(tag): @fns def before(input1, input2, input3): - output1 = mul(input1, input2) - mul1 = mul(input3, input2) + addn0 = addn((input2, input2)) + + output1 = mul(input1, addn0) + mul1 = mul(input3, addn0) # input axis will be convert to attr in step ConstructKernelGraph output2 = reduce_sum(mul1, axis) res = make_tuple(output1, output2) @@ -48,7 +51,8 @@ def test_confusion_mul_grad_fusion(tag): @fns def after(input1, input2, input3): - res = confusion_mul_grad(input1, input2, input3) + addn0 = addn(input2, input2) + res = confusion_mul_grad(input1, addn0, input3) item0 = tuple_getitem(res, 0) item1 = tuple_getitem(res, 1) res = make_tuple(item0, item1)