From 85ff90c23782c4b771117112197f38acf32fe999 Mon Sep 17 00:00:00 2001
From: yujianfeng
Date: Tue, 12 May 2020 20:45:28 +0800
Subject: [PATCH] Add input shape condition for transpose_reshape fusion pass

---
 .../ir_fusion/reshape_transpose_fusion.cc     | 17 ++++++++++++++
 .../ir_fusion/transpose_reshape_fusion.cc     | 17 ++++++++++++++
 .../reshape_transpose_fusion_test.cc          | 23 ++++++++++++++++++-
 .../transpose_reshape_fusion_test.cc          | 23 ++++++++++++++++++-
 .../reshape_transpose_fusion_test.py          |  2 +-
 5 files changed, 79 insertions(+), 3 deletions(-)

diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc
index 06f911c6b..d41d3e3c4 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc
@@ -23,6 +23,18 @@
 namespace mindspore {
 namespace opt {
+namespace {
+bool CheckShapeDimInfo(const std::vector<size_t> &shape) {
+  if (shape.empty()) {
+    return false;
+  }
+  if (shape.size() == 1 && shape[0] % kCubeSize != 0) {
+    return false;
+  }
+  return !(shape.size() >= 2 && (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0));
+}
+}  // namespace
+
 const BaseRef ReshapeTransposeFusion::DefinePattern() const {
   const auto prim_reshape = std::make_shared<Primitive>(prim::kPrimReshape->name());
   VectorRef reshape({prim_reshape, input_varptr_});
@@ -38,6 +50,11 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph,
   MS_EXCEPTION_IF_NULL(transpose_cnode);
   auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum);
   MS_EXCEPTION_IF_NULL(reshape_cnode);
+  std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0);
+  std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0);
+  if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) {
+    return nullptr;
+  }
   auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
   std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
   auto new_node = func_graph->NewCNode(inputs);
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc
index d991a1cd4..138fa3312 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc
@@ -23,6 +23,18 @@
 namespace mindspore {
 namespace opt {
+namespace {
+bool CheckShapeDimInfo(const std::vector<size_t> &shape) {
+  if (shape.empty()) {
+    return false;
+  }
+  if (shape.size() == 1 && shape[0] % kCubeSize != 0) {
+    return false;
+  }
+  return !(shape.size() >= 2 && (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0));
+}
+}  // namespace
+
 const BaseRef TransposeReshapeFusion::DefinePattern() const {
   const auto prim_reshape = std::make_shared<Primitive>(prim::kPrimReshape->name());
   VectorRef transpose({prim::kPrimTranspose, input_varptr_});
@@ -38,6 +50,11 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph,
   MS_EXCEPTION_IF_NULL(reshape_cnode);
   auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum);
   MS_EXCEPTION_IF_NULL(transpose_cnode);
+  std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0);
+  std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0);
+  if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) {
+    return nullptr;
+  }
   auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
   std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
   auto new_node = func_graph->NewCNode(inputs);
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/reshape_transpose_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/reshape_transpose_fusion_test.cc
index 3478e9296..59140e91a 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/reshape_transpose_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/reshape_transpose_fusion_test.cc
@@ -39,7 +39,7 @@ TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_fusion) {
    * return transpose
    */
   FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "before");
-  std::vector<int> shp{2, 4, 8, 16};
+  std::vector<int> shp{2, 2, 16, 16};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list{x_abstract};
   auto kg = GetKernelGraph(g, args_spec_list);
@@ -59,5 +59,26 @@ TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_fusion) {
   FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "after");
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }
+
+TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_no_fusion) {
+  /*
+   * def before(input0, input1):
+   *     reshape = Reshape(input0, input1)
+   *     transpose = Transpose(reshape)
+   *     return transpose
+   */
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "before");
+  std::vector<int> shp{2, 4, 8, 16};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list{x_abstract};
+  auto kg = GetKernelGraph(g, args_spec_list);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::ReshapeTransposeFusion>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(kg);
+  EXPECT_TRUE(CheckEqualGraph(kg, new_graph));
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_reshape_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_reshape_fusion_test.cc
index 8f855b9a6..3290acd42 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_reshape_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_reshape_fusion_test.cc
@@ -39,7 +39,7 @@ TEST_F(TestHWTransposeReshapeFusion, test_transpose_reshape_fusion) {
    * return transpose
    */
   FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_reshape_fusion", "before");
-  std::vector<int> shp{2, 4, 8, 16};
+  std::vector<int> shp{2, 2, 16, 16};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list{x_abstract};
   auto kg = GetKernelGraph(g, args_spec_list);
@@ -61,5 +61,26 @@ TEST_F(TestHWTransposeReshapeFusion, test_transpose_reshape_fusion) {
   FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_transpose_reshape_fusion", "after");
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }
+
+TEST_F(TestHWTransposeReshapeFusion, test_transpose_reshape_no_fusion) {
+  /*
+   * def before(input0, input1):
+   *     reshape = Reshape(input0, input1)
+   *     transpose = Transpose(reshape)
+   *     return transpose
+   */
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_reshape_fusion", "before");
+  std::vector<int> shp{2, 4, 8, 16};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list{x_abstract};
+  auto kg = GetKernelGraph(g, args_spec_list);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::TransposeReshapeFusion>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(kg);
+  EXPECT_TRUE(CheckEqualGraph(kg, new_graph));
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/reshape_transpose_fusion_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/reshape_transpose_fusion_test.py
index 0afd547da..c440deffc 100644
--- a/tests/ut/cpp/python_input/gtest_input/pre_activate/reshape_transpose_fusion_test.py
+++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/reshape_transpose_fusion_test.py
@@ -36,7 +36,7 @@ def test_reshape_transpose_fusion(tag):
 
     @fns
     def before(input0):
-        reshape = Reshape(input0, (2, 4, 8, 16))
+        reshape = Reshape(input0, (2, 2, 16, 16))
         transpose = Transpose(reshape, (1, 0, 2, 3))
         return transpose
-- 
GitLab