diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc index a314668c950a458de7f0cb68df3642ac21af7ba3..704763404d8ca3ea78736207a41fbf7486c86a6e 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc @@ -503,6 +503,7 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node, KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) { std::vector> kernel_info_list; + std::vector> aicpu_kernel_info_list; MS_EXCEPTION_IF_NULL(kernel_node); kernel::KernelQuery(kernel_node, &kernel_info_list); auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); @@ -510,7 +511,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) { if (select_status == kNoMatched) { MS_LOG(WARNING) << "The node [" << kernel_node->DebugString() << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; - kernel::AICPUQuery(kernel_node, &kernel_info_list); + kernel::AICPUQuery(kernel_node, &aicpu_kernel_info_list); - select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); + select_status = SetMatchedKernelInfo(kernel_node, aicpu_kernel_info_list); AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node); } @@ -518,6 +519,15 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) { if (select_status == kNoMatched) { std::ostringstream buffer; PrintInputAndOutputInferType(buffer, kernel_node); + MS_LOG(WARNING) << "=========================kernel info list====================================="; + for (size_t index = 0; index < kernel_info_list.size(); ++index) { + MS_LOG(WARNING) << "kernel [" << index << "] :" << kernel_info_list[index]->ToString(); + } + for (size_t index = 0; index < aicpu_kernel_info_list.size(); ++index) { + MS_LOG(WARNING) << "kernel [" << (kernel_info_list.size() + index) + << "] :" << aicpu_kernel_info_list[index]->ToString(); + } + MS_LOG(WARNING) << "========================= end 
===================================="; MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString() << "] cannot find valid kernel info, not supported the type " << buffer.str(); } diff --git a/mindspore/ccsrc/kernel/kernel_build_info.cc b/mindspore/ccsrc/kernel/kernel_build_info.cc index 9c0272dd7aef0e6218a8169187bd97e12d6e5a45..ce7164a0d167ced50fd56d047562b7b1ad92090c 100644 --- a/mindspore/ccsrc/kernel/kernel_build_info.cc +++ b/mindspore/ccsrc/kernel/kernel_build_info.cc @@ -110,9 +110,9 @@ bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const { return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_); } -bool KernelBuildInfo::IsInputDefaultPadding() const { return output_reshape_type_.empty(); } +bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_.empty(); } -bool KernelBuildInfo::IsOutputDefaultPadding() const { return input_reshape_type_.empty(); } +bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); } void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) { MS_EXCEPTION_IF_NULL(kernel_build_info_); diff --git a/mindspore/ccsrc/kernel/kernel_query.cc b/mindspore/ccsrc/kernel/kernel_query.cc index 6afd2bef6844488f1603c5125b14f56c1697257d..f8523e94e8d0e70414581720425389814b566700 100755 --- a/mindspore/ccsrc/kernel/kernel_query.cc +++ b/mindspore/ccsrc/kernel/kernel_query.cc @@ -56,6 +56,11 @@ void KernelQuery(const CNodePtr &kernel_node, std::vectorempty()) { AicpuMetadataInfo(kernel_node, kernel_info_list); + if (!kernel_info_list->empty()) { + MS_LOG(WARNING) << "The node [" << kernel_node->DebugString() + << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; + AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node); + } } if (kernel_info_list->empty()) { diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc 
b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc index 05d36fd4f22e32b9ccddfad7537dde2f3c7ce12b..f95882994e740e36e5a57ecf830bfa54e30aa577 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc @@ -31,54 +31,6 @@ namespace mindspore { namespace opt { using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; namespace { -kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, - const AnfNodePtr &node, const TypeId device_type, - const kernel::KernelBuildInfo &ori_build_info) { - KernelBuildInfoBuilder builder; - builder.SetInputsFormat({input_format}); - builder.SetOutputsFormat({output_format}); - builder.SetInputsDeviceType({device_type}); - builder.SetOutputsDeviceType({device_type}); - builder.SetKernelType(ori_build_info.kernel_type()); - builder.SetFusionType(ori_build_info.fusion_type()); - builder.SetProcessor(ori_build_info.processor()); - return builder.Build(); -} - -CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, - const bool need_padding, const std::string &op_name) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(input); - std::vector trans_inputs; - auto prim = std::make_shared(op_name); - trans_inputs.push_back(NewValueNode(prim)); - trans_inputs.push_back(input); - CNodePtr trans_node = func_graph->NewCNode(trans_inputs); - MS_EXCEPTION_IF_NULL(trans_node); - std::vector padding_axis; - padding_axis = AnfAlgo::GetOutputReshapeType(input, 0); - if (need_padding) { - // if need padding we should set the transdata node's shape to the padding shape - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, - {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)}, - trans_node.get()); - } else { - 
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, - {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get()); - } - // special handle for ut - if (trans_node->kernel_info() == nullptr) { - auto kernel_info = std::make_shared(); - trans_node->set_kernel_info(kernel_info); - } - MS_EXCEPTION_IF_NULL(kernel_select); - kernel_select->SelectKernel(trans_node); - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node); - MS_EXCEPTION_IF_NULL(trans_node); - trans_node->set_scope(input->scope()); - return trans_node; -} - AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const KernelSelectPtr &kernel_select, const std::vector &dst_shape) { std::vector trans_inputs; @@ -94,6 +46,58 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i return reshape; } +AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) { + AnfNodePtr trans_node = nullptr; + AnfNodePtr input_node = node; + CNodePtr trans_data = nullptr; + std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0); + std::string dst_format = is_insert_input ? 
AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT; + TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0); + std::vector padding_axis = AnfAlgo::GetOutputReshapeType(node, 0); + MS_EXCEPTION_IF_NULL(node); + // if insert transdata for input we need to change the input + if (is_insert_input) { + if (!node->isa()) { + MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode"; + } + auto cnode = node->cast(); + dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index); + dst_format = AnfAlgo::GetInputFormat(cnode, insert_index); + input_node = AnfAlgo::GetInputNode(cnode, insert_index); + padding_axis = AnfAlgo::GetInputReshapeType(node, 0); + } + bool need_padding = false; + if (is_insert_input) { + need_padding = (trans::IsNeedPadding(dst_format, AnfAlgo::GetOutputInferShape(input_node, 0).size())); + } else { + need_padding = (trans::IsNeedPadding(input_format, AnfAlgo::GetOutputInferShape(input_node, 0).size())); + } + if (!need_padding) { + // don't need padding insert transdata only + trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); + trans_node = trans_data; + } else if (is_insert_input) { + // if need padding & is input need insert a transdata + // reshape[padding shape] -> transdata[padding shape] -> node + auto padding_shape = + trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0)); + auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape); + trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name()); + trans_node = trans_data; + } else { + // if need padding & is output need insert a transdata + // node -> transdata[padding shape] -> reshape[ori_shape] + trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); + auto reshape_node = + 
CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0)); + trans_node = reshape_node; + } + // refresh the transdata's format to ori format & dst format + RefreshKernelBuildInfo(input_format, dst_format, dtype, trans_data, padding_axis); + return trans_node; +} + AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index, const KernelSelectPtr &kernel_select) { MS_EXCEPTION_IF_NULL(node); @@ -111,13 +115,11 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr & << "when inserting the transdata node " << node->DebugString(); } std::vector origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index); - std::string origin_format = kOpFormat_DEFAULT; std::string dest_format = AnfAlgo::GetInputFormat(node, index); if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) { MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index) << " To DefaultFormat , index: " << index; - return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, origin_format, dest_format, kTransDataOpName, - true); + return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true); } return input_node; } @@ -131,12 +133,9 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node " << node->DebugString(); } - std::string origin_format = output_format; - std::string dest_format = kOpFormat_DEFAULT; if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) { MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0"; - return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, origin_format, dest_format, kTransDataOpName, - false); + return AddTransOpNodeToGraph(func_graph, node, kernel_select, 
0, false); } return node; } @@ -155,10 +154,8 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const } auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx); std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); - std::string dest_format = kOpFormat_DEFAULT; if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) { - make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, output_format, - dest_format, kTransDataOpName, false)); + make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false)); } else { // No need insert trans op. make_tuple_inputs.push_back(tuple_getitem); @@ -168,62 +165,54 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const return make_tuple; } } // namespace -AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select, size_t insert_index, - const std::string &origin_format, const std::string &dest_format, - const std::string &op_name, bool is_insert_input) { - AnfNodePtr trans_node = nullptr; - AnfNodePtr input_node = node; - AnfNodePtr trans_data = nullptr; - TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0); - MS_EXCEPTION_IF_NULL(node); - if (origin_format.empty() || dest_format.empty()) { - MS_LOG(EXCEPTION) << "trans op format is error, origin = " << origin_format << ", dest " << origin_format; - } - // if insert transdata for input we need to change the input - if (is_insert_input) { - if (!node->isa()) { - MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode"; - } - auto cnode = node->cast(); - dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index); - MS_EXCEPTION_IF_NULL(cnode); - input_node = AnfAlgo::GetInputNode(cnode, insert_index); - } - bool need_padding = false; - if 
(is_insert_input) { - need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) && - op_name == kTransDataOpName); +void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type, + const AnfNodePtr &trans_data, const std::vector &reshape_type) { + MS_EXCEPTION_IF_NULL(trans_data); + MS_EXCEPTION_IF_NULL(trans_data->kernel_info()); + auto ori_build_info = trans_data->kernel_info()->select_kernel_build_info(); + KernelBuildInfoBuilder builder; + builder.SetInputsFormat({input_format}); + builder.SetInputReshapeType({reshape_type}); + builder.SetOutputReshapeType({reshape_type}); + builder.SetOutputsFormat({output_format}); + builder.SetInputsDeviceType({device_type}); + builder.SetOutputsDeviceType({device_type}); + builder.SetKernelType(ori_build_info->kernel_type()); + builder.SetFusionType(ori_build_info->fusion_type()); + builder.SetProcessor(ori_build_info->processor()); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), trans_data.get()); +} + +CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, + const bool need_padding, const std::string &op_name) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(input); + std::vector trans_inputs; + auto prim = std::make_shared(op_name); + trans_inputs.push_back(NewValueNode(prim)); + trans_inputs.push_back(input); + CNodePtr trans_node = func_graph->NewCNode(trans_inputs); + MS_EXCEPTION_IF_NULL(trans_node); + auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0); + if (need_padding) { + // if need padding we should set the transdata node's shape to the padding shape + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, + {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)}, + trans_node.get()); } else { - need_padding = (trans::IsNeedPadding(origin_format, 
AnfAlgo::GetOutputInferShape(input_node, 0).size()) && - op_name == kTransDataOpName); + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, + {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get()); } - if (!need_padding) { - // don't need padding insert transdata only - trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name); - trans_node = trans_data; - } else if (is_insert_input) { - // if need padding & is input need insert a transdata - // reshape[padding shape] -> transdata[padding shape] -> node - auto padding_shape = - trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0)); - auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape); - trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, op_name); - trans_node = trans_data; - } else { - // if need padding & is output need insert a transdata - // node -> transdata[padding shape] -> reshape[ori_shape] - trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name); - auto reshape_node = - CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0)); - trans_node = reshape_node; + // special handle for ut + if (trans_node->kernel_info() == nullptr) { + auto kernel_info = std::make_shared(); + trans_node->set_kernel_info(kernel_info); } - // refresh the transdata's format to ori format & dst format - MS_EXCEPTION_IF_NULL(trans_data); - MS_EXCEPTION_IF_NULL(trans_data->kernel_info()); - auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info(); - auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, dtype, *trans_ori_build_info); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get()); + MS_EXCEPTION_IF_NULL(kernel_select); + kernel_select->SelectKernel(trans_node); + 
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node); + MS_EXCEPTION_IF_NULL(trans_node); + trans_node->set_scope(input->scope()); return trans_node; } diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h index a5463131b4709871bfe5c680957e58f07cf461bb..66e3f2ad330fb69d8be86426581afa50440ad96f 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h @@ -58,11 +58,11 @@ class KernelQuery { } }; using KernelQueryPtr = std::shared_ptr; +void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type, + const AnfNodePtr &trans_data, const std::vector &reshape_type = {}); -AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select, size_t insert_index, - const std::string &origin_format, const std::string &dest_format, - const std::string &op_name, bool is_insert_input); +CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, + const bool need_padding, const std::string &op_name); AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, const TypeId &input_type, const TypeId &output_type, diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc index a9196c5c428659d2e4ccf7599e976da2dd4054b2..43857dddfd8a6bdfc9acb3ea93ba0b06ffbdde5d 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc @@ -105,8 +105,8 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP // insert trans if (origin_format != cur_format && cur_shape.size() > 1) { auto 
kernel_select = std::make_shared(); - final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, cur_format, origin_format, - kTransDataOpName, false); + final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); + RefreshKernelBuildInfo(cur_format, origin_format, origin_type, final_node); final_index = 0; MS_EXCEPTION_IF_NULL(final_node); MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString(); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc index 2c77794b145a9f75f4a133e4876ae00b5a6b4227..035455db5e5803f3a3768d48c5e2da5132d700bb 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc @@ -67,22 +67,30 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n // if output_format=default transdata need split transdata->transpose else transpose->transdata if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) { // trans input_format to hwcn - new_transdata_node = - AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN, kTransDataOpName, true); + new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast(), 0), kernel_select_, + false, prim::KPrimTransData->name()); + RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0), + new_transdata_node); // trans hwcn to default_format - new_transpose_node = AddTransOpNodeToGraph(func_graph, new_transdata_node, kernel_select_, 0, kOpFormat_HWCN, - output_format, prim::kPrimTranspose->name(), false); + new_transpose_node = + NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name()); + RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, 
AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0), + new_transpose_node); AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector{3, 2, 0, 1}), new_transpose_node); new_replace_node = new_transpose_node; } else { // trans default to hwcn - new_transpose_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN, - prim::kPrimTranspose->name(), true); + new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast(), 0), kernel_select_, + false, prim::kPrimTranspose->name()); AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector{2, 3, 1, 0}), new_transpose_node); + RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0), + new_transpose_node); // trans hwcn to output_format - new_transdata_node = AddTransOpNodeToGraph(func_graph, new_transpose_node, kernel_select_, 0, kOpFormat_HWCN, - output_format, kTransDataOpName, false); + new_transdata_node = + NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name()); + RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0), + new_transdata_node); new_replace_node = new_transdata_node; } FuncGraphManagerPtr manager = func_graph->manager(); diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index c6b84d57ad6fd96d277d2a942f65abf865b73265..6e2f4a8a2ed34ae99bea34e95b927be87cdb57c0 100644 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -196,10 +196,10 @@ CNodePtr KernelGraph::NewCNode(const std::vector &inputs) { } if (inputs.size() == 1 || !feature_map_input_indexs.empty()) { kernel_info->SetFeatureMapFlag(true); - AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(true), cnode); + } + if (AnfAlgo::IsRealCNodeKernel(cnode)) { + AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode); 
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode); - } else { - AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(false), cnode); } cnode->set_kernel_info(kernel_info); AnfAlgo::SetGraphId(graph_id_, cnode.get()); diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 8eacd10be0e591b3c9c2c13c1de21641009dc2bf..e829a49d793ceff19827279031fe78a779a9538d 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -151,7 +151,7 @@ constexpr auto kSquareSumAllOpName = "SquareSumAll"; // attr key name constexpr auto kAttrInputNames = "input_names"; -constexpr auto kAttrIsAICPUKernel = "is_ai_cpu_kernel"; +constexpr auto kAttrIsAICPUKernel = "is_AICPU_kernel"; constexpr auto kIsBackendCast = "is_backed_cast"; constexpr auto kAttrOutputNames = "output_names"; constexpr auto kAttrVisited = "visited";