From 9b7df3d09908abd95a8033954b87541f160b5a95 Mon Sep 17 00:00:00 2001 From: VectorSL Date: Fri, 21 Aug 2020 17:27:22 +0800 Subject: [PATCH] gpu optimize transpose --- .../gpu/insert_format_transform_op.cc | 36 ++++++++++++++++--- .../optimizer/gpu/replace_bn_cast_fusion.cc | 2 +- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/gpu/insert_format_transform_op.cc b/mindspore/ccsrc/backend/optimizer/gpu/insert_format_transform_op.cc index e1ab13bbd..00e75bea9 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/insert_format_transform_op.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/insert_format_transform_op.cc @@ -33,7 +33,25 @@ std::vector TransposeAxis(const std::string &src_format, const std::string } else if ((src_format == kOpFormat_NHWC) && (dst_format == kOpFormat_NCHW)) { return {0, 3, 1, 2}; } else { - MS_LOG(EXCEPTION) << "Invaild format transform, from " << src_format << " to " << dst_format; + MS_LOG(EXCEPTION) << "Invalid format transform, from " << src_format << " to " << dst_format; + } +} + +// Transpose can be replaced by nop reshape in some situations. +// 1. out_shape [x, 1, 1, y] with transpose perm {0, 2, 3, 1} +// 2. 
out_shape [x, y, 1, 1] with transpose perm {0, 3, 1, 2} +bool IsFakeTranspose(const std::vector &out_shape, const std::vector &transpose_perm) { + if (out_shape.size() != 4) { + MS_LOG(EXCEPTION) << "Invalid data shape, 4-D data was needed, but get " << out_shape.size() << "-D."; + } + std::vector perm1 = {0, 2, 3, 1}; + std::vector perm2 = {0, 3, 1, 2}; + if (transpose_perm == perm1) { + return (out_shape[1] == 1 && out_shape[2] == 1); + } else if (transpose_perm == perm2) { + return (out_shape[2] == 1 && out_shape[3] == 1); + } else { + return false; } } @@ -56,8 +74,16 @@ void SetTransposeOpBuildInfo(const std::string &input_format, const std::string CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node, int used_node_index, const std::vector &transpose_perm) { MS_EXCEPTION_IF_NULL(graph); - // 1.Create a transpose node. - auto transpose_prim = std::make_shared(prim::kPrimTranspose->name()); + // 0.Judge whether it is a fake transpose + auto transed_shape = AnfAlgo::GetInputDeviceShape(used_node, used_node_index); + bool is_fake = IsFakeTranspose(transed_shape, transpose_perm); + // 1.Create a transpose node or a fake transpose node:reshape. + mindspore::PrimitivePtr transpose_prim; + if (is_fake) { + transpose_prim = std::make_shared(prim::kPrimReshape->name()); + } else { + transpose_prim = std::make_shared(prim::kPrimTranspose->name()); + } MS_EXCEPTION_IF_NULL(transpose_prim); // 2.Set the input of transpose. 
std::vector transpose_input = {NewValueNode(transpose_prim), node}; @@ -66,7 +92,9 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co auto transpose_type = {AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)}; auto transpose_shape = {AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index)}; AnfAlgo::SetOutputInferTypeAndShape(transpose_type, transpose_shape, transpose_op.get()); - AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op); + if (!is_fake) { + AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op); + } // 4.Set the input of used_node. MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope() << ", index: " << used_node_index; diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc index f594320d9..85a0aada6 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc @@ -57,7 +57,7 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A if (item_idx == 0) { auto cast = GetRealNodeUsedList(graph, outlist->at(i).first); if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") { - return nullptr; + continue; } manager->Replace(utils::cast(cast->at(0).first), utils::cast(outlist->at(i).first)); outputs_type.push_back(kNumberTypeFloat16); -- GitLab