diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc index 922f62329d45617b1dfbbe9ca53400d8a1b25b45..42b5824ed53c79d025bd33f4f54ebea0c681b59b 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc @@ -562,10 +562,17 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kern MS_LOG(WARNING) << "kernel [" << (kernel_info_list.size() + index) << "] :" << aicpu_kernel_info_list[index]->ToString(); } - MS_LOG(WARNING) << " <<<"; - MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString() - << "] cannot find valid kernel info, not supported the type:" << buffer.str() - << ", please refer to the supported dtypes in candidates kernel info list"; + if (IsPrimitiveCNode(kernel_node, prim::kPrimLabelSwitch)) { + auto selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, kernel_info_list); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get()); + // Set format and data type for input tensor. + SetTensorDeviceInfo(*selected_kernel_info, kernel_node); + } else { + MS_LOG(WARNING) << " <<<"; + MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString() + << "] cannot find valid kernel info, not supported the type:" << buffer.str() + << ", please refer to the supported dtypes in candidates kernel info list"; + } } return select_status; } diff --git a/mindspore/ccsrc/kernel/rts/label_switch.cc b/mindspore/ccsrc/kernel/rts/label_switch.cc index d84407a930b228f4790d05007453946b77dd5ea8..fb1ad1601a4757240c7aa03cd1efe2453eb2f8d4 100644 --- a/mindspore/ccsrc/kernel/rts/label_switch.cc +++ b/mindspore/ccsrc/kernel/rts/label_switch.cc @@ -75,8 +75,8 @@ std::vector LabelSwitchKernel::GenTask(const std::vector> LabelSwitchDesc::GetKernelInfo() { std::vector> label_switch_build_info{}; - vector input_format{kOpFormat_DEFAULT, kOpFormat_DEFAULT}; - vector input_type{kNumberTypeUInt32, kNumberTypeBool}; + vector input_format{kOpFormat_DEFAULT}; + vector input_type{kNumberTypeInt32}; if (input_format.size() != input_type.size()) { MS_LOG(EXCEPTION) << "Invalid param num, input_format size " << input_format.size() << " input_type size " << input_type.size();