Commit 921ffccd authored by liubuyu

add aicpu kernel info select

Parent af7c54b1
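
In outline, this commit adds an AICPU fallback to Ascend kernel info selection: a node's kernel info is first queried from TBE, and only if no candidate matches is AICPU kernel info queried and matched with the same logic. A condensed sketch of the new selection flow, pieced together from the hunks below (precision raise/reduce bookkeeping and error reporting trimmed):

int status = kStatusAllMatched;
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
kernel::KernelQuery(kernel_node, &kernel_info_list);  // TBE candidates first
auto selected = CanHitKernelInfo(&status, kernel_node, kernel_info_list);
if (selected == nullptr) {  // no TBE match: fall back to AICPU candidates
  kernel::AicpuQuery(kernel_node, &kernel_info_list);
  selected = CanHitKernelInfo(&status, kernel_node, kernel_info_list);
}
// if selected is still nullptr, SelectKernelInfo raises a TypeError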
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
@@ -95,33 +94,53 @@ enum DataTypeTransMode {
FROM_FLOAT_TO_INT32,
FROM_FLOAT16_TO_FLOAT,
FROM_FLOAT16_TO_INT32,
FROM_FLOAT16_TO_UINT8,
FROM_INT32_TO_FLOAT,
FROM_INT32_TO_FLOAT16,
FROM_INT32_TO_UINT8,
FROM_INT32_TO_INT8,
FROM_INT32_TO_BOOL,
FROM_UINT8_TO_FLOAT,
FROM_UINT8_TO_INT32,
FROM_UINT8_TO_FLOAT16,
FROM_INT8_TO_FLOAT,
FROM_INT8_TO_FLOAT16,
FROM_INT8_TO_INT32,
FROM_INT64_TO_INT32,
FROM_UINT16_TO_INT32,
FROM_BOOL_TO_FLOAT,
FROM_BOOL_TO_INT32,
FROM_BOOL_TO_UINT8,
FROM_BOOL_TO_FLOAT16,
FROM_FLOAT64_TO_FLOAT32,
FROM_FLOAT32_TO_FLOAT64
};
const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{
{std::pair<TypeId, TypeId>(kNumberTypeFloat64, kNumberTypeFloat32), FROM_FLOAT64_TO_FLOAT32},
{std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat64), FROM_FLOAT32_TO_FLOAT64},
{std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat16), FROM_FLOAT_TO_FLOAT16},
{std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeInt32), FROM_FLOAT_TO_INT32},
{std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeFloat32), FROM_FLOAT16_TO_FLOAT},
{std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeInt32), FROM_FLOAT16_TO_INT32},
{std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeUInt8), FROM_FLOAT16_TO_UINT8},
{std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat32), FROM_INT32_TO_FLOAT},
{std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat16), FROM_INT32_TO_FLOAT16},
{std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeUInt8), FROM_INT32_TO_UINT8},
{std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt8), FROM_INT32_TO_INT8},
{std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeBool), FROM_INT32_TO_BOOL},
{std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat32), FROM_UINT8_TO_FLOAT},
{std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeInt32), FROM_UINT8_TO_INT32},
{std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat16), FROM_UINT8_TO_FLOAT16},
{std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat32), FROM_INT8_TO_FLOAT},
{std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat16), FROM_INT8_TO_FLOAT16},
{std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeInt32), FROM_INT8_TO_INT32},
{std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt32), FROM_INT64_TO_INT32},
{std::pair<TypeId, TypeId>(kNumberTypeUInt16, kNumberTypeInt32), FROM_UINT16_TO_INT32}};
{std::pair<TypeId, TypeId>(kNumberTypeUInt16, kNumberTypeInt32), FROM_UINT16_TO_INT32},
{std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeInt32), FROM_BOOL_TO_INT32},
{std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat), FROM_BOOL_TO_FLOAT},
{std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeUInt8), FROM_BOOL_TO_UINT8},
{std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat16), FROM_BOOL_TO_FLOAT16}};
void CheckMemSize(const TypeIdArgs &args) {
auto src_type_size = TypeIdSize(args.host_data_type);
@@ -154,54 +173,46 @@ void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size
}
bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const DataTypeTransMode mode) {
switch (mode) {
case FROM_FLOAT_TO_FLOAT16:
device::FloatToHalf(dst, args.data, data_size);
break;
case FROM_INT32_TO_FLOAT16:
TransDataSrc2Fp16<int32_t>(args, dst, data_size);
break;
case FROM_FLOAT16_TO_FLOAT:
device::HalfToFloat(dst, args.data, data_size);
break;
case FROM_FLOAT_TO_INT32:
TransDataSrc2Dst<float, int32_t>(args, dst, data_size);
break;
case FROM_FLOAT16_TO_INT32:
TransDataSrc2Dst<float16, int32_t>(args, dst, data_size);
break;
case FROM_INT32_TO_FLOAT:
TransDataSrc2Dst<int32_t, float>(args, dst, data_size);
break;
case FROM_INT32_TO_INT8:
TransDataSrc2Dst<int32_t, int8_t>(args, dst, data_size);
break;
case FROM_INT32_TO_UINT8:
TransDataSrc2Dst<int32_t, uint8_t>(args, dst, data_size);
break;
case FROM_UINT8_TO_INT32:
TransDataSrc2Dst<uint8_t, int32_t>(args, dst, data_size);
break;
case FROM_UINT8_TO_FLOAT:
TransDataSrc2Dst<uint8_t, float>(args, dst, data_size);
break;
case FROM_INT8_TO_FLOAT:
TransDataSrc2Dst<int8_t, float>(args, dst, data_size);
break;
case FROM_INT8_TO_INT32:
TransDataSrc2Dst<int8_t, int32_t>(args, dst, data_size);
break;
case FROM_INT64_TO_INT32:
TransDataSrc2Dst<int64_t, int32_t>(args, dst, data_size);
break;
case FROM_UINT16_TO_INT32:
TransDataSrc2Dst<uint16_t, int32_t>(args, dst, data_size);
break;
default:
MS_LOG(ERROR) << "Unsupported datatype trans";
return false;
using DtypeKernel = std::function<void(const TypeIdArgs &, void *, const size_t)>;
const std::map<DataTypeTransMode, DtypeKernel> cast_kernel_map{
{FROM_FLOAT_TO_INT32, TransDataSrc2Dst<float, int32_t>},
{FROM_FLOAT64_TO_FLOAT32, TransDataSrc2Dst<double, float>},
{FROM_FLOAT32_TO_FLOAT64, TransDataSrc2Dst<float, double>},
{FROM_FLOAT16_TO_INT32, TransDataSrc2Dst<float16, int32_t>},
{FROM_FLOAT16_TO_UINT8, TransDataSrc2Dst<float16, uint8_t>},
{FROM_INT32_TO_FLOAT, TransDataSrc2Dst<int32_t, float>},
{FROM_INT32_TO_INT8, TransDataSrc2Dst<int32_t, int8_t>},
{FROM_INT32_TO_UINT8, TransDataSrc2Dst<int32_t, uint8_t>},
{FROM_INT32_TO_BOOL, TransDataSrc2Dst<int32_t, int8_t>},
{FROM_INT32_TO_FLOAT16, TransDataSrc2Fp16<int32_t>},
{FROM_UINT8_TO_FLOAT, TransDataSrc2Dst<uint8_t, float>},
{FROM_UINT8_TO_INT32, TransDataSrc2Dst<uint8_t, int32_t>},
{FROM_UINT8_TO_FLOAT16, TransDataSrc2Fp16<uint8_t>},
{FROM_INT8_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
{FROM_INT8_TO_FLOAT16, TransDataSrc2Fp16<int8_t>},
{FROM_INT8_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
{FROM_INT64_TO_INT32, TransDataSrc2Dst<int64_t, int32_t>},
{FROM_UINT16_TO_INT32, TransDataSrc2Dst<uint16_t, int32_t>},
{FROM_BOOL_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
{FROM_BOOL_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
{FROM_BOOL_TO_UINT8, TransDataSrc2Dst<int8_t, uint8_t>},
{FROM_BOOL_TO_FLOAT16, TransDataSrc2Fp16<int8_t>}};
if (mode == FROM_FLOAT_TO_FLOAT16) {
device::FloatToHalf(dst, args.data, data_size);
return true;
} else if (mode == FROM_FLOAT16_TO_FLOAT) {
device::HalfToFloat(dst, args.data, data_size);
return true;
}
auto iter = cast_kernel_map.find(mode);
if (iter != cast_kernel_map.end()) {
iter->second(args, dst, data_size);
return true;
} else {
MS_LOG(ERROR) << "Unsupported datatype trans";
return false;
}
return true;
}
size_t CubeSizeByType(const TypeId data_type) {
......
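
For context, a minimal hypothetical driver for the refactored CastKernel above; mode_map, TypeIdArgs and CastKernel come from this diff, while the helper name and the host/device TypeId and element-count parameters are assumptions about the caller:

// Hypothetical helper, not part of the commit: resolve the trans mode from
// mode_map, then delegate the element-wise cast to CastKernel.
bool CastHostData(const TypeIdArgs &args, void *dst, size_t elem_count, TypeId host_type, TypeId device_type) {
  auto iter = mode_map.find(std::pair<TypeId, TypeId>(host_type, device_type));
  if (iter == mode_map.end()) {
    return false;  // no conversion registered for this (src, dst) pair
  }
  return CastKernel(args, dst, elem_count, iter->second);
}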
@@ -464,14 +464,12 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecisionMatchedKernelInfo(
}
} // namespace
int SelectKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
int status = kStatusAllMatched;
std::shared_ptr<kernel::KernelBuildInfo> CanHitKernelInfo(
int *status, const CNodePtr &kernel_node,
const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node);
bool precision_reduce = false;
std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr;
kernel::KernelQuery(kernel_node, &kernel_info_list);
// filter kernel info matched with the ME inferred type
auto filtered_kernel_info_list = GetAllMatchedFilteredKernelInfo(kernel_node, kernel_info_list);
if (!filtered_kernel_info_list.empty()) {
selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
@@ -481,15 +479,34 @@ int SelectKernelInfo(const CNodePtr &kernel_node) {
FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce);
selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
if (selected_kernel_info == nullptr) {
std::ostringstream buffer;
PrintInputAndOutputInferType(buffer, kernel_node);
MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString()
<< "] cannot find valid kernel info, not supported the type" << buffer.str();
return nullptr;
} else {
PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce);
status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
*status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
}
}
return selected_kernel_info;
}
int SelectKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
int status = kStatusAllMatched;
MS_EXCEPTION_IF_NULL(kernel_node);
kernel::KernelQuery(kernel_node, &kernel_info_list);
// filter kernel info matched with the ME inferred type
auto selected_kernel_info = CanHitKernelInfo(&status, kernel_node, kernel_info_list);
if (selected_kernel_info == nullptr) {
MS_LOG(WARNING) << "The node [" << kernel_node->DebugString()
<< "] cannot find valid TBE kernel info, try to get aicpu kernel info";
kernel::AicpuQuery(kernel_node, &kernel_info_list);
selected_kernel_info = CanHitKernelInfo(&status, kernel_node, kernel_info_list);
}
if (selected_kernel_info == nullptr) {
std::ostringstream buffer;
PrintInputAndOutputInferType(buffer, kernel_node);
MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString()
<< "] cannot find valid kernel info, not supported the type " << buffer.str();
}
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
// Set format and data type for input tensor.
SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
......
@@ -67,5 +67,13 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
}
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
}
void AicpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
kernel_info_list->clear();
AicpuMetadataInfo(kernel_node, kernel_info_list);
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
}
} // namespace kernel
} // namespace mindspore
@@ -26,6 +26,7 @@
namespace mindspore {
namespace kernel {
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
void AicpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_
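
Since AicpuQuery is now declared in kernel_query.h, callers other than SelectKernelInfo can also probe AICPU support directly; a hypothetical check (the helper name is an assumption, the AicpuQuery signature is taken from this header):

bool HasAicpuKernel(const CNodePtr &node) {
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> candidates;
  kernel::AicpuQuery(node, &candidates);
  return !candidates.empty();
}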