diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc
index 06f911c6be2008626a5eaa525999f78cbb0dcdcd..d41d3e3c4a2d0af31e77028c356d06dfa9357172 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc
@@ -23,6 +23,18 @@
 namespace mindspore {
 namespace opt {
+namespace {
+bool CheckShapeDimInfo(const std::vector<size_t> &shape) {
+  if (shape.empty()) {
+    return false;
+  }
+  if (shape.size() == 1 && shape[0] % kCubeSize != 0) {
+    return false;
+  }
+  return !(shape.size() >= 2 && (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0));
+}
+}  // namespace
+
 const BaseRef ReshapeTransposeFusion::DefinePattern() const {
   const auto prim_reshape = std::make_shared<Primitive>(prim::kPrimReshape->name());
   VectorRef reshape({prim_reshape, input_varptr_});
@@ -38,6 +50,11 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph,
   MS_EXCEPTION_IF_NULL(transpose_cnode);
   auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum);
   MS_EXCEPTION_IF_NULL(reshape_cnode);
+  std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0);
+  std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0);
+  if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) {
+    return nullptr;
+  }
   auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
   std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
   auto new_node = func_graph->NewCNode(inputs);
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc
index d991a1cd4a322463076ed0b64b30e87c27bdc1a3..138fa331288b42357649022c730a01f991d56951 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc
@@ -23,6 +23,18 @@
 namespace mindspore {
 namespace opt {
+namespace {
+bool CheckShapeDimInfo(const std::vector<size_t> &shape) {
+  if (shape.empty()) {
+    return false;
+  }
+  if (shape.size() == 1 && shape[0] % kCubeSize != 0) {
+    return false;
+  }
+  return !(shape.size() >= 2 && (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0));
+}
+}  // namespace
+
 const BaseRef TransposeReshapeFusion::DefinePattern() const {
   const auto prim_reshape = std::make_shared<Primitive>(prim::kPrimReshape->name());
   VectorRef transpose({prim::kPrimTranspose, input_varptr_});
@@ -38,6 +50,11 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph,
   MS_EXCEPTION_IF_NULL(reshape_cnode);
   auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum);
   MS_EXCEPTION_IF_NULL(transpose_cnode);
+  std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0);
+  std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0);
+  if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) {
+    return nullptr;
+  }
   auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
   std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
   auto new_node = func_graph->NewCNode(inputs);
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/reshape_transpose_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/reshape_transpose_fusion_test.cc
index 3478e92968e9d02a83e296382c12cd6ccf748b41..59140e91a184082a58dcf4c2a7ca7205bf64b391 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/reshape_transpose_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/reshape_transpose_fusion_test.cc
@@ -39,7 +39,7 @@ TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_fusion) {
    * return transpose
    */
   FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "before");
-  std::vector<int> shp{2, 4, 8, 16};
+  std::vector<int> shp{2, 2, 16, 16};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list{x_abstract};
   auto kg = GetKernelGraph(g, args_spec_list);
@@ -59,5 +59,26 @@ TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_fusion) {
   FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "after");
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }
+
+TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_no_fusion) {
+  /*
+   * def before(input0, input1):
+   *     reshape = Reshape(input0, input1)
+   *     transpose = Transpose(reshape)
+   *     return transpose
+   */
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "before");
+  std::vector<int> shp{2, 4, 8, 16};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list{x_abstract};
+  auto kg = GetKernelGraph(g, args_spec_list);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::ReshapeTransposeFusion>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(kg);
+  EXPECT_TRUE(CheckEqualGraph(kg, new_graph));
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_reshape_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_reshape_fusion_test.cc
index 8f855b9a6e5cb0c8d6474e9d1e2018c9c337f6d9..3290acd42f3dca631287ada5bc5c828eeb607679 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_reshape_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_reshape_fusion_test.cc
@@ -39,7 +39,7 @@ TEST_F(TestHWTransposeReshapeFusion, test_transpose_reshape_fusion) {
    * return transpose
    */
   FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_reshape_fusion", "before");
-  std::vector<int> shp{2, 4, 8, 16};
+  std::vector<int> shp{2, 2, 16, 16};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list{x_abstract};
   auto kg = GetKernelGraph(g, args_spec_list);
@@ -61,5 +61,26 @@ TEST_F(TestHWTransposeReshapeFusion, test_transpose_reshape_fusion) {
   FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_transpose_reshape_fusion", "after");
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }
+
+TEST_F(TestHWTransposeReshapeFusion, test_transpose_reshape_no_fusion) {
+  /*
+   * def before(input0, input1):
+   *     reshape = Reshape(input0, input1)
+   *     transpose = Transpose(reshape)
+   *     return transpose
+   */
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_reshape_fusion", "before");
+  std::vector<int> shp{2, 4, 8, 16};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list{x_abstract};
+  auto kg = GetKernelGraph(g, args_spec_list);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::TransposeReshapeFusion>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(kg);
+  EXPECT_TRUE(CheckEqualGraph(kg, new_graph));
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/reshape_transpose_fusion_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/reshape_transpose_fusion_test.py
index 0afd547da0bd8fe768bfdd529418684e1416239f..c440deffcaefcf464b247b3f39717be00167ffa6 100644
--- a/tests/ut/cpp/python_input/gtest_input/pre_activate/reshape_transpose_fusion_test.py
+++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/reshape_transpose_fusion_test.py
@@ -36,7 +36,7 @@ def test_reshape_transpose_fusion(tag):
 
     @fns
     def before(input0):
-        reshape = Reshape(input0, (2, 4, 8, 16))
+        reshape = Reshape(input0, (2, 2, 16, 16))
         transpose = Transpose(reshape, (1, 0, 2, 3))
         return transpose
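
For context, here is a minimal standalone sketch (not part of the patch) of the divisibility gate that both passes now share. It assumes kCubeSize is 16, matching the Ascend cube tile size, and copies the helper's logic; the two shapes exercised are the ones the updated unit tests use. The intent, presumably, is that ConfusionTransposeD only replaces patterns whose trailing dimensions align with the cube size.

    // Standalone sketch: mirrors the CheckShapeDimInfo gate added above.
    #include <cstddef>
    #include <iostream>
    #include <vector>

    namespace {
    constexpr size_t kCubeSize = 16;  // assumption: Ascend cube tile size

    // Reject empty shapes and any shape whose last (and, for rank >= 2,
    // second-to-last) dimension is not a multiple of kCubeSize.
    bool CheckShapeDimInfo(const std::vector<size_t> &shape) {
      if (shape.empty()) {
        return false;
      }
      if (shape.size() == 1 && shape[0] % kCubeSize != 0) {
        return false;
      }
      return !(shape.size() >= 2 &&
               (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0));
    }
    }  // namespace

    int main() {
      std::cout << CheckShapeDimInfo({2, 2, 16, 16}) << "\n";  // 1: fusion still applies (updated fusion UT shape)
      std::cout << CheckShapeDimInfo({2, 4, 8, 16}) << "\n";   // 0: pass bails out (new no-fusion UT shape)
      return 0;
    }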