提交 a04e8486 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!650 Match format when kernel selecting using raise or reduce precision

Merge pull request !650 from liubuyu/r0.2
...@@ -91,6 +91,14 @@ void PrintNodeInputType(std::ostringstream &buffer, const AnfNodePtr &nd) { ...@@ -91,6 +91,14 @@ void PrintNodeInputType(std::ostringstream &buffer, const AnfNodePtr &nd) {
} }
} }
// Appends a type signature of the form " : (<inputs>) -> (<outputs>)" to `buffer` for node `nd`.
// The text inside the first pair of parentheses is whatever PrintNodeInputType writes;
// the text inside the second pair is whatever PrintNodeOutputType writes.
// NOTE(review): both helpers are defined outside this excerpt; presumably they emit the
// node's inferred input/output types -- confirm against their definitions.
void PrintInputAndOutputInferType(std::ostringstream &buffer, const AnfNodePtr &nd) {
buffer << " : (";
PrintNodeInputType(buffer, nd);
buffer << ") -> (";
PrintNodeOutputType(buffer, nd);
buffer << ")";
}
struct SubGraphIRInfo { struct SubGraphIRInfo {
int32_t local_var; int32_t local_var;
std::ostringstream buffer; std::ostringstream buffer;
......
...@@ -18,12 +18,14 @@ ...@@ -18,12 +18,14 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "ir/dtype/type.h"
#include "ir/anf.h" #include "ir/anf.h"
namespace mindspore { namespace mindspore {
constexpr char PARALLEL_STRATEGY[] = "strategy"; constexpr char PARALLEL_STRATEGY[] = "strategy";
void DumpIR(const std::string &filename, const FuncGraphPtr &func_graph, bool dump_full_name = false); void DumpIR(const std::string &filename, const FuncGraphPtr &func_graph, bool dump_full_name = false);
void PrintInputAndOutputInferType(std::ostringstream &buffer, const AnfNodePtr &nd);
const std::string ToShortString(const TypeId &typeId);
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_DUMP_H_ #endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_DUMP_H_
...@@ -18,14 +18,15 @@ ...@@ -18,14 +18,15 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <set> #include <utility>
#include <unordered_map> #include <map>
#include "kernel/oplib/oplib.h" #include "kernel/oplib/oplib.h"
#include "kernel/kernel_query.h" #include "kernel/kernel_query.h"
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
#include "kernel/kernel_build_info.h" #include "kernel/kernel_build_info.h"
#include "utils/context/ms_context.h" #include "utils/context/ms_context.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "debug/anf_ir_dump.h"
namespace mindspore { namespace mindspore {
namespace device { namespace device {
...@@ -180,6 +181,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co ...@@ -180,6 +181,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
} }
void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *support_index) { void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *support_index) {
MS_EXCEPTION_IF_NULL(support_index);
int index = kUnSupportMixedDataTypeIndex; int index = kUnSupportMixedDataTypeIndex;
switch (data_type) { switch (data_type) {
case kNumberTypeFloat16: case kNumberTypeFloat16:
...@@ -197,6 +199,7 @@ void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *s ...@@ -197,6 +199,7 @@ void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *s
void AddKernelInputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t input_index, void AddKernelInputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t input_index,
std::vector<int> *support_datatype_index, std::vector<TypeId> *support_datatype) { std::vector<int> *support_datatype_index, std::vector<TypeId> *support_datatype) {
MS_EXCEPTION_IF_NULL(support_datatype);
auto data_type = kernel_build_info.GetInputDeviceType(input_index); auto data_type = kernel_build_info.GetInputDeviceType(input_index);
support_datatype->push_back(data_type); support_datatype->push_back(data_type);
AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index); AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index);
...@@ -204,6 +207,7 @@ void AddKernelInputSupportDataType(const kernel::KernelBuildInfo &kernel_build_i ...@@ -204,6 +207,7 @@ void AddKernelInputSupportDataType(const kernel::KernelBuildInfo &kernel_build_i
void AddKernelOutputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t output_index, void AddKernelOutputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t output_index,
std::vector<int> *support_datatype_index, std::vector<TypeId> *support_datatype) { std::vector<int> *support_datatype_index, std::vector<TypeId> *support_datatype) {
MS_EXCEPTION_IF_NULL(support_datatype);
auto data_type = kernel_build_info.GetOutputDeviceType(output_index); auto data_type = kernel_build_info.GetOutputDeviceType(output_index);
support_datatype->push_back(data_type); support_datatype->push_back(data_type);
AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index); AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index);
...@@ -238,8 +242,8 @@ void AddNodeOutputDataType(const CNodePtr &kernel_node, size_t output_index, ...@@ -238,8 +242,8 @@ void AddNodeOutputDataType(const CNodePtr &kernel_node, size_t output_index,
void CheckDataTypeInputs(const std::vector<int> &node_mix_precision_datatype_index, void CheckDataTypeInputs(const std::vector<int> &node_mix_precision_datatype_index,
const std::vector<TypeId> &node_mix_precision_datatype, const std::vector<TypeId> &node_mix_precision_datatype,
const std::unordered_map<size_t, std::vector<TypeId>> &kernel_support_datatypes, const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
std::unordered_map<size_t, std::vector<int>> *kernel_match_datatype_idx) { std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) {
if (node_mix_precision_datatype_index.size() != node_mix_precision_datatype.size()) { if (node_mix_precision_datatype_index.size() != node_mix_precision_datatype.size()) {
MS_LOG(EXCEPTION) << "node datatype index size " << node_mix_precision_datatype_index.size() << " != datatype size " MS_LOG(EXCEPTION) << "node datatype index size " << node_mix_precision_datatype_index.size() << " != datatype size "
<< node_mix_precision_datatype.size(); << node_mix_precision_datatype.size();
...@@ -251,10 +255,11 @@ void CheckDataTypeInputs(const std::vector<int> &node_mix_precision_datatype_ind ...@@ -251,10 +255,11 @@ void CheckDataTypeInputs(const std::vector<int> &node_mix_precision_datatype_ind
} }
} }
int RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index, bool RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index,
const std::vector<TypeId> &node_mix_precision_datatype, const std::vector<TypeId> &node_mix_precision_datatype,
const std::unordered_map<size_t, std::vector<TypeId>> &kernel_support_datatypes, const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
std::unordered_map<size_t, std::vector<int>> *kernel_match_datatype_idx) { std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) {
MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes, CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes,
kernel_match_datatype_idx); kernel_match_datatype_idx);
for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) { for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) {
...@@ -289,40 +294,16 @@ int RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_data ...@@ -289,40 +294,16 @@ int RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_data
} }
} }
} }
return !kernel_match_datatype_idx->empty();
if (kernel_match_datatype_idx->size() >= 1) {
return SizeToInt(kernel_match_datatype_idx->begin()->first);
}
return -1;
}
int GetMinReducePrecisionCountIndex(std::unordered_map<size_t, std::vector<int>> *kernel_match_datatype_idx,
const std::unordered_map<size_t, size_t> &precision_reduce_count) {
int selected_index = -1;
size_t min_reduce_precision_count = kMaxCount;
auto iter = kernel_match_datatype_idx->begin();
while (iter != kernel_match_datatype_idx->end()) {
auto find_iter = precision_reduce_count.find(iter->first);
if (find_iter == precision_reduce_count.end()) {
continue;
}
if (min_reduce_precision_count > find_iter->second) {
selected_index = SizeToInt(iter->first);
min_reduce_precision_count = find_iter->second;
}
++iter;
}
return selected_index;
} }
int RaiseOrReduceDataTypePrecisionSelect( bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index,
const std::vector<int> &node_mix_precision_datatype_index, const std::vector<TypeId> &node_mix_precision_datatype, const std::vector<TypeId> &node_mix_precision_datatype,
const std::unordered_map<size_t, std::vector<TypeId>> &kernel_support_datatypes, const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
std::unordered_map<size_t, std::vector<int>> *kernel_match_datatype_idx) { std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) {
MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes, CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes,
kernel_match_datatype_idx); kernel_match_datatype_idx);
// reduce / raise
std::unordered_map<size_t, size_t> precision_reduce_count;
for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) { for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) {
if (node_mix_precision_datatype[i] == kTypeUnknown) { if (node_mix_precision_datatype[i] == kTypeUnknown) {
continue; continue;
...@@ -351,26 +332,18 @@ int RaiseOrReduceDataTypePrecisionSelect( ...@@ -351,26 +332,18 @@ int RaiseOrReduceDataTypePrecisionSelect(
if (datatype_indexes[i] == kUnSupportMixedDataTypeIndex) { if (datatype_indexes[i] == kUnSupportMixedDataTypeIndex) {
iter = kernel_match_datatype_idx->erase(iter); iter = kernel_match_datatype_idx->erase(iter);
} else { } else {
if (datatype_indexes[i] < node_mix_precision_datatype_index[i]) {
auto count_iter = precision_reduce_count.find(iter->first);
if (count_iter != precision_reduce_count.end()) {
count_iter->second++;
} else {
precision_reduce_count[iter->first] = 1;
}
}
++iter; ++iter;
} }
} }
} }
return !kernel_match_datatype_idx->empty();
return GetMinReducePrecisionCountIndex(kernel_match_datatype_idx, precision_reduce_count);
} }
void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelBuildInfo &kernel_build_info, void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelBuildInfo &kernel_build_info,
std::vector<int> *support_indexes, std::vector<TypeId> *node_mix_precision_datatype, std::vector<int> *support_indexes, std::vector<TypeId> *node_mix_precision_datatype,
std::vector<TypeId> *support_datatypes, std::vector<TypeId> *support_datatypes,
std::vector<int> *node_mix_precision_datatype_index) { std::vector<int> *node_mix_precision_datatype_index) {
MS_EXCEPTION_IF_NULL(node_mix_precision_datatype);
bool add_node_datatype_flag = false; bool add_node_datatype_flag = false;
if (node_mix_precision_datatype->size() == 0) { if (node_mix_precision_datatype->size() == 0) {
add_node_datatype_flag = true; add_node_datatype_flag = true;
...@@ -390,104 +363,58 @@ void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelB ...@@ -390,104 +363,58 @@ void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelB
} }
} }
int PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index, void PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index,
const std::vector<TypeId> &node_mix_precision_datatype, const std::vector<TypeId> &node_mix_precision_datatype,
const std::unordered_map<size_t, std::vector<TypeId>> &kernel_support_datatype, const std::map<size_t, std::vector<TypeId>> &kernel_support_datatype,
std::unordered_map<size_t, std::vector<int>> *kernel_match_datatype_idx, bool *precision_reduce) { std::map<size_t, std::vector<int>> *kernel_match_datatype_idx, bool *precision_reduce) {
MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
MS_EXCEPTION_IF_NULL(precision_reduce); MS_EXCEPTION_IF_NULL(precision_reduce);
std::unordered_map<size_t, std::vector<int>> kernel_match_datatype_idx_copy = *kernel_match_datatype_idx; std::map<size_t, std::vector<int>> kernel_match_datatype_idx_copy = *kernel_match_datatype_idx;
// raise precision // raise precision
int selected_index = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, bool selected_ret = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
kernel_support_datatype, kernel_match_datatype_idx); kernel_support_datatype, kernel_match_datatype_idx);
if (selected_index != -1) { if (selected_ret) {
int max_match = 0; return;
auto iter = kernel_match_datatype_idx->begin();
int match_count = 0;
while (iter != kernel_match_datatype_idx->end()) {
auto kernel_datatypes = kernel_support_datatype.find(iter->first);
if (kernel_datatypes == kernel_support_datatype.end()) {
MS_LOG(EXCEPTION) << "Can not find kernel index" << iter->first << "'s datatype.";
}
if (kernel_datatypes->second.size() < node_mix_precision_datatype.size()) {
MS_LOG(EXCEPTION) << "Kernel datatype size is not equal to node datatype size!";
}
for (size_t i = 0; i < node_mix_precision_datatype.size(); ++i) {
if (node_mix_precision_datatype[i] == kernel_datatypes->second[i]) {
++match_count;
}
}
if (match_count > max_match) {
selected_index = SizeToInt(iter->first);
}
++iter;
}
} }
if (selected_index == -1 && context_ptr->enable_reduce_precision()) { if (context_ptr->enable_reduce_precision()) {
selected_index = selected_ret = RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatype, &kernel_match_datatype_idx_copy);
kernel_support_datatype, &kernel_match_datatype_idx_copy); }
if (selected_index != -1) { if (selected_ret) {
*precision_reduce = true; *precision_reduce = true;
} *kernel_match_datatype_idx = kernel_match_datatype_idx_copy;
} }
return selected_index;
} }
void SelectKernel(const CNodePtr &kernel_node, bool precision_reduce, const std::vector<TypeId> &node_datatype, void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode,
const std::shared_ptr<kernel::KernelBuildInfo> &selected_kernel_info_ptr) { const std::shared_ptr<kernel::KernelBuildInfo> &selected_kernel_build_info,
MS_EXCEPTION_IF_NULL(selected_kernel_info_ptr); bool precision_reduce) {
MS_EXCEPTION_IF_NULL(selected_kernel_build_info);
MS_EXCEPTION_IF_NULL(cnode);
std::ostringstream buffer;
buffer << cnode->DebugString();
if (precision_reduce) { if (precision_reduce) {
std::ostringstream datatype; buffer << " reduce precision, node datatype: ";
size_t input_num = selected_kernel_info_ptr->GetInputNum(); } else {
size_t i = 0; buffer << " raise precision, node datatype: ";
datatype << "(";
for (; i < input_num && i < node_datatype.size(); ++i) {
datatype << static_cast<int>(node_datatype[i]);
if (i < input_num - 1) {
datatype << ", ";
}
}
datatype << ") -> (";
for (; i < node_datatype.size(); ++i) {
datatype << static_cast<int>(node_datatype[i]);
if (i < node_datatype.size() - 1) {
datatype << ", ";
}
}
datatype << ")";
MS_LOG(WARNING) << kernel_node->DebugString() << " reduce precision, node datatype: " << datatype.str()
<< ", select kernel: %s" << selected_kernel_info_ptr->ToString();
} }
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, kernel_node.get()); PrintInputAndOutputInferType(buffer, cnode);
// Set format and data type for input tensor. buffer << ", select kernel:" << selected_kernel_build_info->ToString();
SetTensorDeviceInfo(*selected_kernel_info_ptr, kernel_node); MS_LOG(INFO) << buffer.str();
} }
} // namespace
void SelectKernelInfo(const CNodePtr &kernel_node) { std::shared_ptr<kernel::KernelBuildInfo> ChooseMatchedKernelInfo(
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; const CNodePtr &kernel_node, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node); if (kernel_info_list.empty()) {
kernel::KernelQuery(kernel_node, &kernel_info_list); return nullptr;
}
std::vector<int> most_match_counts = {-1, -1, -1, -1}; std::vector<int> most_match_counts = {-1, -1, -1, -1};
int selected_index = -1; size_t selected_index = 0;
std::unordered_map<size_t, std::vector<int>> kernel_match_datatype_idx;
std::unordered_map<size_t, std::vector<TypeId>> kernel_support_datatype;
std::vector<int> node_mix_precision_datatype_index;
std::vector<TypeId> node_mix_precision_datatype;
for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0}; std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0};
auto kernel_build_info = *(kernel_info_list[info_index]); auto kernel_build_info = *(kernel_info_list[info_index]);
std::vector<int> support_indexes;
std::vector<TypeId> support_datatypes;
AddNodeAndKernelDataType(kernel_node, kernel_build_info, &support_indexes, &node_mix_precision_datatype,
&support_datatypes, &node_mix_precision_datatype_index);
kernel_match_datatype_idx[info_index] = support_indexes;
kernel_support_datatype[info_index] = support_datatypes;
if (!MatchInferOutputDataType(kernel_node, kernel_build_info)) {
continue;
}
std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index]; std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index];
UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts); UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts);
// Currently the selection policy is the match format count first, and then is datatype counts. // Currently the selection policy is the match format count first, and then is datatype counts.
...@@ -495,22 +422,77 @@ void SelectKernelInfo(const CNodePtr &kernel_node) { ...@@ -495,22 +422,77 @@ void SelectKernelInfo(const CNodePtr &kernel_node) {
selected_index = SizeToInt(info_index); selected_index = SizeToInt(info_index);
} }
} }
return kernel_info_list[selected_index];
}
bool precision_reduce = false; std::vector<std::shared_ptr<kernel::KernelBuildInfo>> GetAllMatchedFilteredKernelInfo(
if (selected_index == -1) { const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
selected_index = PrecisionReduce(node_mix_precision_datatype_index, node_mix_precision_datatype, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> result;
kernel_support_datatype, &kernel_match_datatype_idx, &precision_reduce); for (const auto &kernel_build_info : kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_build_info);
if (!MatchInferOutputDataType(cnode, *kernel_build_info)) {
continue;
}
result.push_back(kernel_build_info);
} }
if (selected_index == -1) { return result;
MS_LOG(EXCEPTION) << kernel_node->DebugString() << "Cannot find valid kernel Info !"; }
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecisionMatchedKernelInfo(
const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list,
bool *precision_reduce) {
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_kernel_info_list;
std::map<size_t, std::vector<int>> kernel_match_datatype_idx;
std::map<size_t, std::vector<TypeId>> kernel_support_datatype;
std::vector<int> node_mix_precision_datatype_index;
std::vector<TypeId> node_mix_precision_datatype;
for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
std::vector<int> support_indexes;
std::vector<TypeId> support_datatypes;
MS_EXCEPTION_IF_NULL(kernel_info_list[info_index]);
AddNodeAndKernelDataType(cnode, *kernel_info_list[info_index], &support_indexes, &node_mix_precision_datatype,
&support_datatypes, &node_mix_precision_datatype_index);
kernel_match_datatype_idx[info_index] = support_indexes;
kernel_support_datatype[info_index] = support_datatypes;
} }
auto index = IntToSize(selected_index); PrecisionReduce(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatype,
if (index >= kernel_info_list.size()) { &kernel_match_datatype_idx, precision_reduce);
MS_LOG(EXCEPTION) << "index outof range"; std::transform(
kernel_match_datatype_idx.begin(), kernel_match_datatype_idx.end(), std::back_inserter(filtered_kernel_info_list),
[&](const std::pair<size_t, std::vector<int>> &matched_idx) -> std::shared_ptr<kernel::KernelBuildInfo> {
return kernel_info_list[matched_idx.first];
});
return filtered_kernel_info_list;
}
} // namespace
void SelectKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
MS_EXCEPTION_IF_NULL(kernel_node);
bool precision_reduce = false;
std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr;
kernel::KernelQuery(kernel_node, &kernel_info_list);
// filter kernel info matched with the inferred type
auto filtered_kernel_info_list = GetAllMatchedFilteredKernelInfo(kernel_node, kernel_info_list);
if (!filtered_kernel_info_list.empty()) {
selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
} else {
// select kernel info using raised or reduced precision
filtered_kernel_info_list =
FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce);
selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
if (selected_kernel_info == nullptr) {
std::ostringstream buffer;
PrintInputAndOutputInferType(buffer, kernel_node);
MS_LOG(EXCEPTION) << "The node [" << kernel_node->DebugString()
<< "] cannot find valid kernel info, not supported the type" << buffer.str();
} else {
PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce);
}
} }
std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info_ptr = kernel_info_list[index]; AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
MS_EXCEPTION_IF_NULL(selected_kernel_info_ptr); // Set format and data type for input tensor.
SelectKernel(kernel_node, precision_reduce, node_mix_precision_datatype, selected_kernel_info_ptr); SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
} }
bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, bool CheckKernelAccuracySupported(const CNodePtr &kernel_node,
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "kernel/kernel_build_info.h" #include "kernel/kernel_build_info.h"
#include <algorithm> #include <algorithm>
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "debug/anf_ir_dump.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
std::string KernelBuildInfo::GetInputFormat(size_t input_index) const { std::string KernelBuildInfo::GetInputFormat(size_t input_index) const {
...@@ -82,14 +83,14 @@ std::string KernelBuildInfo::ToString() const { ...@@ -82,14 +83,14 @@ std::string KernelBuildInfo::ToString() const {
if (index != 0) { if (index != 0) {
output_buffer << ", "; output_buffer << ", ";
} }
output_buffer << "<" << static_cast<int>(GetInputDeviceType(index)) << "x" << GetInputFormat(index) << ">"; output_buffer << "<" << ToShortString(GetInputDeviceType(index)) << "x" << GetInputFormat(index) << ">";
} }
output_buffer << ") -> ("; output_buffer << ") -> (";
for (size_t index = 0; index < GetOutputNum(); ++index) { for (size_t index = 0; index < GetOutputNum(); ++index) {
if (index != 0) { if (index != 0) {
output_buffer << ", "; output_buffer << ", ";
} }
output_buffer << "<" << static_cast<int>(GetOutputDeviceType(index)) << "x" << GetOutputFormat(index) << ">"; output_buffer << "<" << ToShortString(GetOutputDeviceType(index)) << "x" << GetOutputFormat(index) << ">";
} }
output_buffer << ")"; output_buffer << ")";
return output_buffer.str(); return output_buffer.str();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册