diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc index 4fadb4f37b98c901ad7ac5bc3093ef3b84810f5a..3951e1a13270c2f251d1adee2383a77d99620062 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc @@ -70,7 +70,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) { for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) { auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index); if (AnfAlgo::IsFeatureMapInput(cnode, index) && - kNeedTransFormatSet.find(pre_output_format) != kNeedTransFormatSet.end()) { + kHWSpecialFormatSet.find(pre_output_format) != kHWSpecialFormatSet.end()) { priority_matched_format = !is_init ? pre_output_format : priority_matched_format; is_init = true; } diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc index 982538c417837b2c6e87dfaba4529bbbf1b3d888..4d18e3b28a7434610386dea65de7f7471bd9aa2e 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc @@ -31,6 +31,7 @@ namespace mindspore { namespace opt { using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; namespace { +const std::set kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW}; AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const KernelSelectPtr &kernel_select, const std::vector &dst_shape) { std::vector trans_inputs; @@ -110,13 +111,9 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr & MS_EXCEPTION_IF_NULL(input_node); AnfAlgo::SetNodeInput(node, input_node, index); } - if (AnfAlgo::GetInputFormat(node, index) == kOpFormat_NC1KHKWHWC0) { - MS_LOG(EXCEPTION) << "got the format " << AnfAlgo::GetInputFormat(node, index) - << "when inserting the transdata node " << node->DebugString(); - } std::vector origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index); std::string dest_format = AnfAlgo::GetInputFormat(node, index); - if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) { + if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index) << " To DefaultFormat , index: " << index; return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true); @@ -133,7 +130,7 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node " << node->DebugString(); } - if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) { + if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0"; return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false); } @@ -154,7 +151,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const } auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx); std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); - if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) { + if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false)); } else { // No need insert trans op. diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc index 39378c16522b949785b6df685565fc6fd9f966c7..32e4987f5a29493bb3d362b472ab13e7cdd19afd 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc @@ -97,7 +97,7 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector &do_ std::string convert_format; for (const auto &do_mask : do_mask_node_list) { auto do_mask_data_format = AnfAlgo::GetInputFormat(do_mask, 0); - if (special_format.empty() && kNeedTransFormatSet.find(do_mask_data_format) != kNeedTransFormatSet.end()) { + if (special_format.empty() && kHWSpecialFormatSet.find(do_mask_data_format) != kHWSpecialFormatSet.end()) { special_format = do_mask_data_format; } if (format_counter.find(do_mask_data_format) == format_counter.end()) { @@ -111,7 +111,7 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector &do_ convert_format = kOpFormat_DEFAULT; break; } - if (kNeedTransFormatSet.find(do_mask_data_format) != kNeedTransFormatSet.end() && + if (kHWSpecialFormatSet.find(do_mask_data_format) != kHWSpecialFormatSet.end() && special_format != do_mask_data_format) { convert_format = kOpFormat_DEFAULT; break; @@ -133,7 +133,7 @@ std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map kOptOperatorSet = { kApplyRMSPropOpName, }; -const std::set kNeedTransFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, +const std::set kHWSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; diff --git a/mindspore/ops/_op_impl/tbe/trans_data.py b/mindspore/ops/_op_impl/tbe/trans_data.py index f961491b3723f8f2b78c232ad716a4b46b3d40bc..64cca3dc44c90149b26745a56354b54c235f3ed4 100644 --- a/mindspore/ops/_op_impl/tbe/trans_data.py +++ b/mindspore/ops/_op_impl/tbe/trans_data.py @@ -58,6 +58,8 @@ trans_data_op_info = TBERegOp("TransData") \ .dtype_format(DataType.F32_HWCN, DataType.F32_FracZ) \ .dtype_format(DataType.F32_HWCN, DataType.F32_C1HWNCoC0) \ .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_HWCN) \ + .dtype_format(DataType.F32_Default, DataType.F32_NCHW) \ + .dtype_format(DataType.F32_HWCN, DataType.F32_Default) \ .get_op_info()