提交 85ff90c2 编写于 作者: Y yujianfeng

Add input shape condition for transpose_reshape fusion pass

上级 298a7848
......@@ -23,6 +23,18 @@
namespace mindspore {
namespace opt {
namespace {
bool CheckShapeDimInfo(const std::vector<size_t> &shape) {
if (shape.empty()) {
return false;
}
if (shape.size() == 1 && shape[0] % kCubeSize != 0) {
return false;
}
return !(shape.size() >= 2 && (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0));
}
} // namespace
const BaseRef ReshapeTransposeFusion::DefinePattern() const {
const auto prim_reshape = std::make_shared<Primitive>(prim::kPrimReshape->name());
VectorRef reshape({prim_reshape, input_varptr_});
......@@ -38,6 +50,11 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph,
MS_EXCEPTION_IF_NULL(transpose_cnode);
auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum);
MS_EXCEPTION_IF_NULL(reshape_cnode);
std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0);
std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0);
if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) {
return nullptr;
}
auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
auto new_node = func_graph->NewCNode(inputs);
......
......@@ -23,6 +23,18 @@
namespace mindspore {
namespace opt {
namespace {
bool CheckShapeDimInfo(const std::vector<size_t> &shape) {
if (shape.empty()) {
return false;
}
if (shape.size() == 1 && shape[0] % kCubeSize != 0) {
return false;
}
return !(shape.size() >= 2 && (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0));
}
} // namespace
const BaseRef TransposeReshapeFusion::DefinePattern() const {
const auto prim_reshape = std::make_shared<Primitive>(prim::kPrimReshape->name());
VectorRef transpose({prim::kPrimTranspose, input_varptr_});
......@@ -38,6 +50,11 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph,
MS_EXCEPTION_IF_NULL(reshape_cnode);
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum);
MS_EXCEPTION_IF_NULL(transpose_cnode);
std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0);
std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0);
if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) {
return nullptr;
}
auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
auto new_node = func_graph->NewCNode(inputs);
......
......@@ -39,7 +39,7 @@ TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_fusion) {
* return transpose
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "before");
std::vector<int> shp{2, 4, 8, 16};
std::vector<int> shp{2, 2, 16, 16};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list{x_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
......@@ -59,5 +59,26 @@ TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_fusion) {
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_no_fusion) {
/*
* def before(input0, input1):
* reshape = Reshape(input0, input1)
* transpose = Transpose(reshape)
* return transpose
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "before");
std::vector<int> shp{2, 4, 8, 16};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list{x_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::ReshapeTransposeFusion>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(kg);
EXPECT_TRUE(CheckEqualGraph(kg, new_graph));
}
} // namespace opt
} // namespace mindspore
......@@ -39,7 +39,7 @@ TEST_F(TestHWTransposeReshapeFusion, test_transpose_reshape_fusion) {
* return transpose
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_reshape_fusion", "before");
std::vector<int> shp{2, 4, 8, 16};
std::vector<int> shp{2, 2, 16, 16};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list{x_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
......@@ -61,5 +61,26 @@ TEST_F(TestHWTransposeReshapeFusion, test_transpose_reshape_fusion) {
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_transpose_reshape_fusion", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWTransposeReshapeFusion, test_transpose_reshape_no_fusion) {
/*
* def before(input0, input1):
* reshape = Reshape(input0, input1)
* transpose = Transpose(reshape)
* return transpose
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_reshape_fusion", "before");
std::vector<int> shp{2, 4, 8, 16};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list{x_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::TransposeReshapeFusion>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(kg);
EXPECT_TRUE(CheckEqualGraph(kg, new_graph));
}
} // namespace opt
} // namespace mindspore
......@@ -36,7 +36,7 @@ def test_reshape_transpose_fusion(tag):
@fns
def before(input0):
reshape = Reshape(input0, (2, 4, 8, 16))
reshape = Reshape(input0, (2, 2, 16, 16))
transpose = Transpose(reshape, (1, 0, 2, 3))
return transpose
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册