diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc index 47098379bf67b426ac5b482a3130f53407e87e7a..caea9599c1c7fcecce3ed6019412e1890a4a0ffb 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc @@ -72,6 +72,38 @@ AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const An } return mul0; } + +bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &reduce_sum) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(mul0_anf); + MS_EXCEPTION_IF_NULL(reduce_sum); + if (!mul0_anf->isa()) { + return true; + } + auto mul0 = mul0_anf->cast(); + MS_EXCEPTION_IF_NULL(mul0); + + // when network is _VirtualDatasetCell, quit fusion + if (mul0->fullname_with_scope().find("network-_VirtualDatasetCell") != std::string::npos) { + return true; + } + + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (manager->node_users().find(reduce_sum) == manager->node_users().end()) { + MS_LOG(EXCEPTION) << "node has no output in manager"; + } + const AnfNodeIndexSet &outputs_set = manager->node_users()[reduce_sum]; + auto it = std::find_if(outputs_set.begin(), outputs_set.end(), [&mul0](const std::pair &node_index) { + return node_index.first == mul0->input(1) || node_index.first == mul0; + }); + if (it != outputs_set.end()) { + MS_LOG(INFO) << "ReduceSum's output node is mul0's input or mul0! If do fusion, graph will exist a circle"; + return true; + } + + return false; +} } // namespace const BaseRef ConfusionMulGradFusion::DefinePattern() const { @@ -90,9 +122,6 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons auto reduce_sum = node->cast(); MS_EXCEPTION_IF_NULL(reduce_sum); auto mul1 = reduce_sum->input(1); - if (mul1->fullname_with_scope().find("bert/encoder") == std::string::npos) { - return nullptr; - } if (IsUsedByOthers(graph, mul1)) { MS_LOG(INFO) << "Mul1 is used by others, quit fusion!"; return nullptr; @@ -102,6 +131,9 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons MS_LOG(INFO) << "Mul0 do not exist, quit fusion"; return nullptr; } + if (QuitFusion(graph, mul0, node)) { + return nullptr; + } auto fusion_node = CreateFusionNode(graph, reduce_sum, mul0, input3); std::vector fusion_node_outputs; diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc index 4b5d38d3757766d28aa78c7b4375c6fcb8db1d39..e3bf09d2cbd594face80e492367c386ae1199b1a 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc @@ -32,11 +32,6 @@ class TestHWOptimizeConfusionMulGradFusion : public BackendCommon { TEST_F(TestHWOptimizeConfusionMulGradFusion, test_fusion) { FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_confusion_mul_grad_fusion", "before"); EXPECT_NE(g, nullptr); - auto bert_scope = std::make_shared("bert/encoder"); - for (auto node : TopoSort(g->get_return())) { - node->set_scope(bert_scope); - } - std::vector shp{1, 1, 1, 1}; auto x_abstract = std::make_shared(kFloat32, shp); AbstractBasePtrList args_spec_list;