diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc index 5b0f02bed13826e6247d4368f472ad642608d14b..6e6e7419fd08cec450cfeccac193eda6d301381d 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc @@ -162,8 +162,18 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co } std::shared_ptr builder = std::make_shared(); + bool is_ref = false; + auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE); + if (op_info != nullptr) { + is_ref = op_info->is_ref(); + } + MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); + if (MsContext::GetInstance()->execution_mode() == kPynativeMode && + AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) { + continue; + } // we set special device info of a input tensor. - if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown) { + if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { std::vector output_format = {selected_kernel_info.GetInputFormat(input_index)}; builder->SetOutputsFormat(output_format); std::vector output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)};