diff --git a/mindspore/ccsrc/debug/anf_ir_dump.cc b/mindspore/ccsrc/debug/anf_ir_dump.cc index e977084ab80d18da8c9c0a67e56965ed6eb09760..1fd3096e7c59ce7fdecab9d26e4fea8ab3a71fba 100644 --- a/mindspore/ccsrc/debug/anf_ir_dump.cc +++ b/mindspore/ccsrc/debug/anf_ir_dump.cc @@ -91,6 +91,14 @@ void PrintNodeInputType(std::ostringstream &buffer, const AnfNodePtr &nd) { } } +void PrintInputAndOutputInferType(std::ostringstream &buffer, const AnfNodePtr &nd) { + buffer << " : ("; + PrintNodeInputType(buffer, nd); + buffer << ") -> ("; + PrintNodeOutputType(buffer, nd); + buffer << ")"; +} + struct SubGraphIRInfo { int32_t local_var; std::ostringstream buffer; diff --git a/mindspore/ccsrc/debug/anf_ir_dump.h b/mindspore/ccsrc/debug/anf_ir_dump.h index a53888348d01853123a0567cbf724100a3ac0228..9fa447046f3a96270084020d19a6daa2e4c82bb3 100644 --- a/mindspore/ccsrc/debug/anf_ir_dump.h +++ b/mindspore/ccsrc/debug/anf_ir_dump.h @@ -18,12 +18,14 @@ #include #include +#include "ir/dtype/type.h" #include "ir/anf.h" namespace mindspore { constexpr char PARALLEL_STRATEGY[] = "strategy"; void DumpIR(const std::string &filename, const FuncGraphPtr &func_graph, bool dump_full_name = false); - +void PrintInputAndOutputInferType(std::ostringstream &buffer, const AnfNodePtr &nd); +const std::string ToShortString(const TypeId &typeId); } // namespace mindspore #endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_DUMP_H_ diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc index 36c622cbc5218a3958f3c880fce2285e450c404a..549b97b61bed7ad15b784a7c0174a647c4342f16 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc @@ -18,14 +18,15 @@ #include #include #include -#include -#include +#include +#include #include "kernel/oplib/oplib.h" #include "kernel/kernel_query.h" #include "session/anf_runtime_algorithm.h" #include "kernel/kernel_build_info.h" #include "utils/context/ms_context.h" #include "operator/ops.h" +#include "debug/anf_ir_dump.h" namespace mindspore { namespace device { @@ -180,6 +181,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co } void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector *support_index) { + MS_EXCEPTION_IF_NULL(support_index); int index = kUnSupportMixedDataTypeIndex; switch (data_type) { case kNumberTypeFloat16: @@ -197,6 +199,7 @@ void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector *s void AddKernelInputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t input_index, std::vector *support_datatype_index, std::vector *support_datatype) { + MS_EXCEPTION_IF_NULL(support_datatype); auto data_type = kernel_build_info.GetInputDeviceType(input_index); support_datatype->push_back(data_type); AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index); @@ -204,6 +207,7 @@ void AddKernelInputSupportDataType(const kernel::KernelBuildInfo &kernel_build_i void AddKernelOutputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t output_index, std::vector *support_datatype_index, std::vector *support_datatype) { + MS_EXCEPTION_IF_NULL(support_datatype); auto data_type = kernel_build_info.GetOutputDeviceType(output_index); support_datatype->push_back(data_type); AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index); @@ -238,8 +242,8 @@ void AddNodeOutputDataType(const CNodePtr &kernel_node, size_t output_index, void CheckDataTypeInputs(const std::vector &node_mix_precision_datatype_index, const std::vector &node_mix_precision_datatype, - const std::unordered_map> &kernel_support_datatypes, - std::unordered_map> *kernel_match_datatype_idx) { + const std::map> &kernel_support_datatypes, + std::map> *kernel_match_datatype_idx) { if (node_mix_precision_datatype_index.size() != node_mix_precision_datatype.size()) { MS_LOG(EXCEPTION) << "node datatype index size " << node_mix_precision_datatype_index.size() << " != datatype size " << node_mix_precision_datatype.size(); @@ -251,10 +255,11 @@ void CheckDataTypeInputs(const std::vector &node_mix_precision_datatype_ind } } -int RaiseDataTypePrecisionSelect(const std::vector &node_mix_precision_datatype_index, - const std::vector &node_mix_precision_datatype, - const std::unordered_map> &kernel_support_datatypes, - std::unordered_map> *kernel_match_datatype_idx) { +bool RaiseDataTypePrecisionSelect(const std::vector &node_mix_precision_datatype_index, + const std::vector &node_mix_precision_datatype, + const std::map> &kernel_support_datatypes, + std::map> *kernel_match_datatype_idx) { + MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes, kernel_match_datatype_idx); for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) { @@ -289,40 +294,16 @@ int RaiseDataTypePrecisionSelect(const std::vector &node_mix_precision_data } } } - - if (kernel_match_datatype_idx->size() >= 1) { - return SizeToInt(kernel_match_datatype_idx->begin()->first); - } - return -1; -} - -int GetMinReducePrecisionCountIndex(std::unordered_map> *kernel_match_datatype_idx, - const std::unordered_map &precision_reduce_count) { - int selected_index = -1; - size_t min_reduce_precision_count = kMaxCount; - auto iter = kernel_match_datatype_idx->begin(); - while (iter != kernel_match_datatype_idx->end()) { - auto find_iter = precision_reduce_count.find(iter->first); - if (find_iter == precision_reduce_count.end()) { - continue; - } - if (min_reduce_precision_count > find_iter->second) { - selected_index = SizeToInt(iter->first); - min_reduce_precision_count = find_iter->second; - } - ++iter; - } - return selected_index; + return !kernel_match_datatype_idx->empty(); } -int RaiseOrReduceDataTypePrecisionSelect( - const std::vector &node_mix_precision_datatype_index, const std::vector &node_mix_precision_datatype, - const std::unordered_map> &kernel_support_datatypes, - std::unordered_map> *kernel_match_datatype_idx) { +bool RaiseOrReduceDataTypePrecisionSelect(const std::vector &node_mix_precision_datatype_index, + const std::vector &node_mix_precision_datatype, + const std::map> &kernel_support_datatypes, + std::map> *kernel_match_datatype_idx) { + MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes, kernel_match_datatype_idx); - // reduce / raise - std::unordered_map precision_reduce_count; for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) { if (node_mix_precision_datatype[i] == kTypeUnknown) { continue; @@ -351,26 +332,18 @@ int RaiseOrReduceDataTypePrecisionSelect( if (datatype_indexes[i] == kUnSupportMixedDataTypeIndex) { iter = kernel_match_datatype_idx->erase(iter); } else { - if (datatype_indexes[i] < node_mix_precision_datatype_index[i]) { - auto count_iter = precision_reduce_count.find(iter->first); - if (count_iter != precision_reduce_count.end()) { - count_iter->second++; - } else { - precision_reduce_count[iter->first] = 1; - } - } ++iter; } } } - - return GetMinReducePrecisionCountIndex(kernel_match_datatype_idx, precision_reduce_count); + return !kernel_match_datatype_idx->empty(); } void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelBuildInfo &kernel_build_info, std::vector *support_indexes, std::vector *node_mix_precision_datatype, std::vector *support_datatypes, std::vector *node_mix_precision_datatype_index) { + MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); bool add_node_datatype_flag = false; if (node_mix_precision_datatype->size() == 0) { add_node_datatype_flag = true; @@ -390,104 +363,58 @@ void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelB } } -int PrecisionReduce(const std::vector &node_mix_precision_datatype_index, - const std::vector &node_mix_precision_datatype, - const std::unordered_map> &kernel_support_datatype, - std::unordered_map> *kernel_match_datatype_idx, bool *precision_reduce) { +void PrecisionReduce(const std::vector &node_mix_precision_datatype_index, + const std::vector &node_mix_precision_datatype, + const std::map> &kernel_support_datatype, + std::map> *kernel_match_datatype_idx, bool *precision_reduce) { + MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(precision_reduce); - std::unordered_map> kernel_match_datatype_idx_copy = *kernel_match_datatype_idx; + std::map> kernel_match_datatype_idx_copy = *kernel_match_datatype_idx; // raise precision - int selected_index = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, - kernel_support_datatype, kernel_match_datatype_idx); - if (selected_index != -1) { - int max_match = 0; - auto iter = kernel_match_datatype_idx->begin(); - int match_count = 0; - while (iter != kernel_match_datatype_idx->end()) { - auto kernel_datatypes = kernel_support_datatype.find(iter->first); - if (kernel_datatypes == kernel_support_datatype.end()) { - MS_LOG(EXCEPTION) << "Can not find kernel index" << iter->first << "'s datatype."; - } - if (kernel_datatypes->second.size() < node_mix_precision_datatype.size()) { - MS_LOG(EXCEPTION) << "Kernel datatype size is not equal to node datatype size!"; - } - for (size_t i = 0; i < node_mix_precision_datatype.size(); ++i) { - if (node_mix_precision_datatype[i] == kernel_datatypes->second[i]) { - ++match_count; - } - } - if (match_count > max_match) { - selected_index = SizeToInt(iter->first); - } - ++iter; - } + bool selected_ret = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, + kernel_support_datatype, kernel_match_datatype_idx); + if (selected_ret) { + return; } - if (selected_index == -1 && context_ptr->enable_reduce_precision()) { - selected_index = - RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, - kernel_support_datatype, &kernel_match_datatype_idx_copy); - if (selected_index != -1) { - *precision_reduce = true; - } + if (context_ptr->enable_reduce_precision()) { + selected_ret = RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, + kernel_support_datatype, &kernel_match_datatype_idx_copy); + } + if (selected_ret) { + *precision_reduce = true; + *kernel_match_datatype_idx = kernel_match_datatype_idx_copy; } - return selected_index; } -void SelectKernel(const CNodePtr &kernel_node, bool precision_reduce, const std::vector &node_datatype, - const std::shared_ptr &selected_kernel_info_ptr) { - MS_EXCEPTION_IF_NULL(selected_kernel_info_ptr); +void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode, + const std::shared_ptr &selected_kernel_build_info, + bool precision_reduce) { + MS_EXCEPTION_IF_NULL(selected_kernel_build_info); + MS_EXCEPTION_IF_NULL(cnode); + std::ostringstream buffer; + buffer << cnode->DebugString(); if (precision_reduce) { - std::ostringstream datatype; - size_t input_num = selected_kernel_info_ptr->GetInputNum(); - size_t i = 0; - datatype << "("; - for (; i < input_num && i < node_datatype.size(); ++i) { - datatype << static_cast(node_datatype[i]); - if (i < input_num - 1) { - datatype << ", "; - } - } - datatype << ") -> ("; - for (; i < node_datatype.size(); ++i) { - datatype << static_cast(node_datatype[i]); - if (i < node_datatype.size() - 1) { - datatype << ", "; - } - } - datatype << ")"; - MS_LOG(WARNING) << kernel_node->DebugString() << " reduce precision, node datatype: " << datatype.str() - << ", select kernel: %s" << selected_kernel_info_ptr->ToString(); + buffer << " reduce precision, node datatype: "; + } else { + buffer << " raise precision, node datatype: "; } - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, kernel_node.get()); - // Set format and data type for input tensor. - SetTensorDeviceInfo(*selected_kernel_info_ptr, kernel_node); + PrintInputAndOutputInferType(buffer, cnode); + buffer << ", select kernel:" << selected_kernel_build_info->ToString(); + MS_LOG(INFO) << buffer.str(); } -} // namespace -void SelectKernelInfo(const CNodePtr &kernel_node) { - std::vector> kernel_info_list; - MS_EXCEPTION_IF_NULL(kernel_node); - kernel::KernelQuery(kernel_node, &kernel_info_list); +std::shared_ptr ChooseMatchedKernelInfo( + const CNodePtr &kernel_node, const std::vector> &kernel_info_list) { + if (kernel_info_list.empty()) { + return nullptr; + } std::vector most_match_counts = {-1, -1, -1, -1}; - int selected_index = -1; - std::unordered_map> kernel_match_datatype_idx; - std::unordered_map> kernel_support_datatype; - std::vector node_mix_precision_datatype_index; - std::vector node_mix_precision_datatype; + size_t selected_index = 0; for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { std::vector cur_kernel_info_match_counts = {0, 0, 0, 0}; auto kernel_build_info = *(kernel_info_list[info_index]); - std::vector support_indexes; - std::vector support_datatypes; - AddNodeAndKernelDataType(kernel_node, kernel_build_info, &support_indexes, &node_mix_precision_datatype, - &support_datatypes, &node_mix_precision_datatype_index); - kernel_match_datatype_idx[info_index] = support_indexes; - kernel_support_datatype[info_index] = support_datatypes; - if (!MatchInferOutputDataType(kernel_node, kernel_build_info)) { - continue; - } std::shared_ptr kernel_info_ptr = kernel_info_list[info_index]; UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts); // Currently the selection policy is the match format count first, and then is datatype counts. @@ -495,22 +422,77 @@ void SelectKernelInfo(const CNodePtr &kernel_node) { selected_index = SizeToInt(info_index); } } + return kernel_info_list[selected_index]; +} - bool precision_reduce = false; - if (selected_index == -1) { - selected_index = PrecisionReduce(node_mix_precision_datatype_index, node_mix_precision_datatype, - kernel_support_datatype, &kernel_match_datatype_idx, &precision_reduce); +std::vector> GetAllMatchedFilteredKernelInfo( + const CNodePtr &cnode, const std::vector> &kernel_info_list) { + std::vector> result; + for (const auto &kernel_build_info : kernel_info_list) { + MS_EXCEPTION_IF_NULL(kernel_build_info); + if (!MatchInferOutputDataType(cnode, *kernel_build_info)) { + continue; + } + result.push_back(kernel_build_info); } - if (selected_index == -1) { - MS_LOG(EXCEPTION) << kernel_node->DebugString() << "Cannot find valid kernel Info !"; + return result; +} + +std::vector> FilterRaisedOrReducePrecisionMatchedKernelInfo( + const CNodePtr &cnode, const std::vector> &kernel_info_list, + bool *precision_reduce) { + std::vector> filtered_kernel_info_list; + std::map> kernel_match_datatype_idx; + std::map> kernel_support_datatype; + std::vector node_mix_precision_datatype_index; + std::vector node_mix_precision_datatype; + for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { + std::vector support_indexes; + std::vector support_datatypes; + MS_EXCEPTION_IF_NULL(kernel_info_list[info_index]); + AddNodeAndKernelDataType(cnode, *kernel_info_list[info_index], &support_indexes, &node_mix_precision_datatype, + &support_datatypes, &node_mix_precision_datatype_index); + kernel_match_datatype_idx[info_index] = support_indexes; + kernel_support_datatype[info_index] = support_datatypes; } - auto index = IntToSize(selected_index); - if (index >= kernel_info_list.size()) { - MS_LOG(EXCEPTION) << "index outof range"; + PrecisionReduce(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatype, + &kernel_match_datatype_idx, precision_reduce); + std::transform( + kernel_match_datatype_idx.begin(), kernel_match_datatype_idx.end(), std::back_inserter(filtered_kernel_info_list), + [&](const std::pair> &matched_idx) -> std::shared_ptr { + return kernel_info_list[matched_idx.first]; + }); + return filtered_kernel_info_list; +} +} // namespace + +void SelectKernelInfo(const CNodePtr &kernel_node) { + std::vector> kernel_info_list; + MS_EXCEPTION_IF_NULL(kernel_node); + bool precision_reduce = false; + std::shared_ptr selected_kernel_info = nullptr; + kernel::KernelQuery(kernel_node, &kernel_info_list); + // filter kernel info matched with me infered type + auto filtered_kernel_info_list = GetAllMatchedFilteredKernelInfo(kernel_node, kernel_info_list); + if (!filtered_kernel_info_list.empty()) { + selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); + } else { + // selected kernel info using raised precision or reduce precision + filtered_kernel_info_list = + FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce); + selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); + if (selected_kernel_info == nullptr) { + std::ostringstream buffer; + PrintInputAndOutputInferType(buffer, kernel_node); + MS_LOG(EXCEPTION) << "The node [" << kernel_node->DebugString() + << "] cannot find valid kernel info, not supported the type" << buffer.str(); + } else { + PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce); + } } - std::shared_ptr selected_kernel_info_ptr = kernel_info_list[index]; - MS_EXCEPTION_IF_NULL(selected_kernel_info_ptr); - SelectKernel(kernel_node, precision_reduce, node_mix_precision_datatype, selected_kernel_info_ptr); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get()); + // Set format and data type for input tensor. + SetTensorDeviceInfo(*selected_kernel_info, kernel_node); } bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, diff --git a/mindspore/ccsrc/kernel/kernel_build_info.cc b/mindspore/ccsrc/kernel/kernel_build_info.cc index 038c06d8edcbb12b274cc8421ad36067cee5d733..279a62bad6d6d811059ea68a45024872b364cd65 100644 --- a/mindspore/ccsrc/kernel/kernel_build_info.cc +++ b/mindspore/ccsrc/kernel/kernel_build_info.cc @@ -17,6 +17,7 @@ #include "kernel/kernel_build_info.h" #include #include "utils/log_adapter.h" +#include "debug/anf_ir_dump.h" namespace mindspore { namespace kernel { std::string KernelBuildInfo::GetInputFormat(size_t input_index) const { @@ -82,14 +83,14 @@ std::string KernelBuildInfo::ToString() const { if (index != 0) { output_buffer << ", "; } - output_buffer << "<" << static_cast(GetInputDeviceType(index)) << "x" << GetInputFormat(index) << ">"; + output_buffer << "<" << ToShortString(GetInputDeviceType(index)) << "x" << GetInputFormat(index) << ">"; } output_buffer << ") -> ("; for (size_t index = 0; index < GetOutputNum(); ++index) { if (index != 0) { output_buffer << ", "; } - output_buffer << "<" << static_cast(GetOutputDeviceType(index)) << "x" << GetOutputFormat(index) << ">"; + output_buffer << "<" << ToShortString(GetOutputDeviceType(index)) << "x" << GetOutputFormat(index) << ">"; } output_buffer << ")"; return output_buffer.str();