diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index ad018c35355ef7791c2a48467f3288daba3d2630..225ec05196fc32e2cbd8ae6401ef2c6d8bae9a64 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -1,4 +1,3 @@ - /** * Copyright 2020 Huawei Technologies Co., Ltd * @@ -95,33 +94,53 @@ enum DataTypeTransMode { FROM_FLOAT_TO_INT32, FROM_FLOAT16_TO_FLOAT, FROM_FLOAT16_TO_INT32, + FROM_FLOAT16_TO_UINT8, FROM_INT32_TO_FLOAT, FROM_INT32_TO_FLOAT16, FROM_INT32_TO_UINT8, FROM_INT32_TO_INT8, + FROM_INT32_TO_BOOL, FROM_UINT8_TO_FLOAT, FROM_UINT8_TO_INT32, + FROM_UINT8_TO_FLOAT16, FROM_INT8_TO_FLOAT, + FROM_INT8_TO_FLOAT16, FROM_INT8_TO_INT32, FROM_INT64_TO_INT32, FROM_UINT16_TO_INT32, + FROM_BOOL_TO_FLOAT, + FROM_BOOL_TO_INT32, + FROM_BOOL_TO_UINT8, + FROM_BOOL_TO_FLOAT16, + FROM_FLOAT64_TO_FLOAT32, + FROM_FLOAT32_TO_FLOAT64 }; const std::map, DataTypeTransMode> mode_map{ + {std::pair(kNumberTypeFloat64, kNumberTypeFloat32), FROM_FLOAT64_TO_FLOAT32}, + {std::pair(kNumberTypeFloat32, kNumberTypeFloat64), FROM_FLOAT32_TO_FLOAT64}, {std::pair(kNumberTypeFloat32, kNumberTypeFloat16), FROM_FLOAT_TO_FLOAT16}, {std::pair(kNumberTypeFloat32, kNumberTypeInt32), FROM_FLOAT_TO_INT32}, {std::pair(kNumberTypeFloat16, kNumberTypeFloat32), FROM_FLOAT16_TO_FLOAT}, {std::pair(kNumberTypeFloat16, kNumberTypeInt32), FROM_FLOAT16_TO_INT32}, + {std::pair(kNumberTypeFloat16, kNumberTypeUInt8), FROM_FLOAT16_TO_UINT8}, {std::pair(kNumberTypeInt32, kNumberTypeFloat32), FROM_INT32_TO_FLOAT}, {std::pair(kNumberTypeInt32, kNumberTypeFloat16), FROM_INT32_TO_FLOAT16}, {std::pair(kNumberTypeInt32, kNumberTypeUInt8), FROM_INT32_TO_UINT8}, {std::pair(kNumberTypeInt32, kNumberTypeInt8), FROM_INT32_TO_INT8}, + {std::pair(kNumberTypeInt32, kNumberTypeBool), FROM_INT32_TO_BOOL}, {std::pair(kNumberTypeUInt8, kNumberTypeFloat32), FROM_UINT8_TO_FLOAT}, {std::pair(kNumberTypeUInt8, kNumberTypeInt32), FROM_UINT8_TO_INT32}, + {std::pair(kNumberTypeUInt8, kNumberTypeFloat16), FROM_UINT8_TO_FLOAT16}, {std::pair(kNumberTypeInt8, kNumberTypeFloat32), FROM_INT8_TO_FLOAT}, + {std::pair(kNumberTypeInt8, kNumberTypeFloat16), FROM_INT8_TO_FLOAT16}, {std::pair(kNumberTypeInt8, kNumberTypeInt32), FROM_INT8_TO_INT32}, {std::pair(kNumberTypeInt64, kNumberTypeInt32), FROM_INT64_TO_INT32}, - {std::pair(kNumberTypeUInt16, kNumberTypeInt32), FROM_UINT16_TO_INT32}}; + {std::pair(kNumberTypeUInt16, kNumberTypeInt32), FROM_UINT16_TO_INT32}, + {std::pair(kNumberTypeBool, kNumberTypeInt32), FROM_BOOL_TO_INT32}, + {std::pair(kNumberTypeBool, kNumberTypeFloat), FROM_BOOL_TO_FLOAT}, + {std::pair(kNumberTypeBool, kNumberTypeUInt8), FROM_BOOL_TO_UINT8}, + {std::pair(kNumberTypeBool, kNumberTypeFloat16), FROM_BOOL_TO_FLOAT16}}; void CheckMemSize(const TypeIdArgs &args) { auto src_type_size = TypeIdSize(args.host_data_type); @@ -154,54 +173,46 @@ void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size } bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const DataTypeTransMode mode) { - switch (mode) { - case FROM_FLOAT_TO_FLOAT16: - device::FloatToHalf(dst, args.data, data_size); - break; - case FROM_INT32_TO_FLOAT16: - TransDataSrc2Fp16(args, dst, data_size); - break; - case FROM_FLOAT16_TO_FLOAT: - device::HalfToFloat(dst, args.data, data_size); - break; - case FROM_FLOAT_TO_INT32: - TransDataSrc2Dst(args, dst, data_size); - break; - case FROM_FLOAT16_TO_INT32: - TransDataSrc2Dst(args, dst, data_size); - break; - case FROM_INT32_TO_FLOAT: - TransDataSrc2Dst(args, dst, data_size); - break; - case FROM_INT32_TO_INT8: - TransDataSrc2Dst(args, dst, data_size); - break; - case FROM_INT32_TO_UINT8: - TransDataSrc2Dst(args, dst, data_size); - break; - case FROM_UINT8_TO_INT32: - TransDataSrc2Dst(args, dst, data_size); - break; - case FROM_UINT8_TO_FLOAT: - TransDataSrc2Dst(args, dst, data_size); - break; - case FROM_INT8_TO_FLOAT: - TransDataSrc2Dst(args, dst, data_size); - break; - case FROM_INT8_TO_INT32: - TransDataSrc2Dst(args, dst, data_size); - break; - case FROM_INT64_TO_INT32: - TransDataSrc2Dst(args, dst, data_size); - break; - case FROM_UINT16_TO_INT32: - TransDataSrc2Dst(args, dst, data_size); - break; - default: - MS_LOG(ERROR) << "Unsupported datatype trans"; - return false; + using DtypeKernel = std::function; + const std::map cast_kernel_map{ + {FROM_FLOAT_TO_INT32, TransDataSrc2Dst}, + {FROM_FLOAT64_TO_FLOAT32, TransDataSrc2Dst}, + {FROM_FLOAT32_TO_FLOAT64, TransDataSrc2Dst}, + {FROM_FLOAT16_TO_INT32, TransDataSrc2Dst}, + {FROM_FLOAT16_TO_UINT8, TransDataSrc2Dst}, + {FROM_INT32_TO_FLOAT, TransDataSrc2Dst}, + {FROM_INT32_TO_INT8, TransDataSrc2Dst}, + {FROM_INT32_TO_UINT8, TransDataSrc2Dst}, + {FROM_INT32_TO_BOOL, TransDataSrc2Dst}, + {FROM_INT32_TO_FLOAT16, TransDataSrc2Fp16}, + {FROM_UINT8_TO_FLOAT, TransDataSrc2Dst}, + {FROM_UINT8_TO_INT32, TransDataSrc2Dst}, + {FROM_UINT8_TO_FLOAT16, TransDataSrc2Fp16}, + {FROM_INT8_TO_FLOAT, TransDataSrc2Dst}, + {FROM_INT8_TO_FLOAT16, TransDataSrc2Fp16}, + {FROM_INT8_TO_INT32, TransDataSrc2Dst}, + {FROM_INT64_TO_INT32, TransDataSrc2Dst}, + {FROM_UINT16_TO_INT32, TransDataSrc2Dst}, + {FROM_BOOL_TO_INT32, TransDataSrc2Dst}, + {FROM_BOOL_TO_FLOAT, TransDataSrc2Dst}, + {FROM_BOOL_TO_UINT8, TransDataSrc2Dst}, + {FROM_BOOL_TO_FLOAT16, TransDataSrc2Fp16}}; + + if (mode == FROM_FLOAT_TO_FLOAT16) { + device::FloatToHalf(dst, args.data, data_size); + return true; + } else if (mode == FROM_FLOAT16_TO_FLOAT) { + device::HalfToFloat(dst, args.data, data_size); + return true; + } + auto iter = cast_kernel_map.find(mode); + if (iter != cast_kernel_map.end()) { + iter->second(args, dst, data_size); + return true; + } else { + MS_LOG(ERROR) << "Unsupported datatype trans"; + return false; } - return true; } size_t CubeSizeByType(const TypeId data_type) { diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc index 9e54adc635d9ff73138d1bda921f3871c8d15354..d615d261c741c10625704d11f9d48f5dd2a660f2 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc @@ -464,14 +464,12 @@ std::vector> FilterRaisedOrReducePrecis } } // namespace -int SelectKernelInfo(const CNodePtr &kernel_node) { - std::vector> kernel_info_list; - int status = kStatusAllMatched; +std::shared_ptr CanHitKernelInfo( + int *status, const CNodePtr &kernel_node, + const std::vector> &kernel_info_list) { MS_EXCEPTION_IF_NULL(kernel_node); bool precision_reduce = false; std::shared_ptr selected_kernel_info = nullptr; - kernel::KernelQuery(kernel_node, &kernel_info_list); - // filter kernel info matched with me infered type auto filtered_kernel_info_list = GetAllMatchedFilteredKernelInfo(kernel_node, kernel_info_list); if (!filtered_kernel_info_list.empty()) { selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); @@ -481,15 +479,34 @@ int SelectKernelInfo(const CNodePtr &kernel_node) { FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce); selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); if (selected_kernel_info == nullptr) { - std::ostringstream buffer; - PrintInputAndOutputInferType(buffer, kernel_node); - MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString() - << "] cannot find valid kernel info, not supported the type" << buffer.str(); + return nullptr; } else { PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce); - status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision; + *status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision; } } + return selected_kernel_info; +} + +int SelectKernelInfo(const CNodePtr &kernel_node) { + std::vector> kernel_info_list; + int status = kStatusAllMatched; + MS_EXCEPTION_IF_NULL(kernel_node); + kernel::KernelQuery(kernel_node, &kernel_info_list); + // filter kernel info matched with me infered type + auto selected_kernel_info = CanHitKernelInfo(&status, kernel_node, kernel_info_list); + if (selected_kernel_info == nullptr) { + MS_LOG(WARNING) << "The node [" << kernel_node->DebugString() + << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; + kernel::AicpuQuery(kernel_node, &kernel_info_list); + selected_kernel_info = CanHitKernelInfo(&status, kernel_node, kernel_info_list); + } + if (selected_kernel_info == nullptr) { + std::ostringstream buffer; + PrintInputAndOutputInferType(buffer, kernel_node); + MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString() + << "] cannot find valid kernel info, not supported the type " << buffer.str(); + } AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get()); // Set format and data type for input tensor. SetTensorDeviceInfo(*selected_kernel_info, kernel_node); diff --git a/mindspore/ccsrc/kernel/kernel_query.cc b/mindspore/ccsrc/kernel/kernel_query.cc index e4a1af7f50dc21456fb8eae1de1714376a7c2f92..a2a5958a3f8b30f1d8f65f1b60fb6f0cbf4ffedf 100755 --- a/mindspore/ccsrc/kernel/kernel_query.cc +++ b/mindspore/ccsrc/kernel/kernel_query.cc @@ -67,5 +67,13 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_info_list); + kernel_info_list->clear(); + AicpuMetadataInfo(kernel_node, kernel_info_list); + FilterInvalidKernelInfo(kernel_node, kernel_info_list); +} } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/kernel_query.h b/mindspore/ccsrc/kernel/kernel_query.h index 72acbc0952541777810e23a63ee6c0d5771917a0..3e16b6b6129be0eb2cf2e27accaf44d756e7c804 100644 --- a/mindspore/ccsrc/kernel/kernel_query.h +++ b/mindspore/ccsrc/kernel/kernel_query.h @@ -26,6 +26,7 @@ namespace mindspore { namespace kernel { void KernelQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list); +void AicpuQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list); } // namespace kernel } // namespace mindspore #endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_