diff --git a/mindspore/ccsrc/kernel/kernel_build_info.cc b/mindspore/ccsrc/kernel/kernel_build_info.cc index c912a0c199b30d5b5090e78630c947637a0b9331..bb7ce75ac46f78f4aa52477ff6e8532c29ff43d7 100644 --- a/mindspore/ccsrc/kernel/kernel_build_info.cc +++ b/mindspore/ccsrc/kernel/kernel_build_info.cc @@ -119,6 +119,8 @@ bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_ bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); } +bool KernelBuildInfo::operator!=(const KernelBuildInfo &other) const { return !((*this) == other); } + void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) { MS_EXCEPTION_IF_NULL(kernel_build_info_); kernel_build_info_->kernel_type_ = kernel_type; diff --git a/mindspore/ccsrc/kernel/kernel_build_info.h b/mindspore/ccsrc/kernel/kernel_build_info.h index ca1083fd688f3db9fb37dcab653fd8b32d2f05a8..45ac45f98f0ca677308b1ca0f3b535c64f97e644 100644 --- a/mindspore/ccsrc/kernel/kernel_build_info.h +++ b/mindspore/ccsrc/kernel/kernel_build_info.h @@ -85,6 +85,8 @@ class KernelBuildInfo { bool operator==(const KernelBuildInfo &other) const; + bool operator!=(const KernelBuildInfo &other) const; + public: static auto constexpr kInvalidFormat = "InvalidFormat"; diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc index d81a8c90cea827e51442b538c18402131c577bd3..571e70dca5a0607299daac49df470abbd5ed3955 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc @@ -26,6 +26,7 @@ #include "utils/utils.h" #include "kernel/common_utils.h" #include "utils/context/ms_context.h" +#include "pre_activate/common/helper.h" namespace mindspore { namespace opt { @@ -50,16 +51,11 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con return nullptr; } std::vector do_mask_node_list; - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto node_map = manager->node_users(); - auto iter = node_map.find(node); - if (iter == node_map.end()) { - MS_LOG(EXCEPTION) << "Cannot find the node " << node->DebugString() << " in the graph manager!"; - } - auto gen_mask_output_nodes = iter->second; - for (const auto &output_node : gen_mask_output_nodes) { + auto gen_mask_output_nodes = GetRealNodeUsedList(graph, cnode); + MS_EXCEPTION_IF_NULL(gen_mask_output_nodes); + for (const auto &output_node : *gen_mask_output_nodes) { if (AnfAlgo::GetCNodeName(output_node.first) == prim::kPrimDropoutDoMask->name()) { + MS_EXCEPTION_IF_NULL(output_node.first); auto output_cnode = output_node.first->cast(); do_mask_node_list.push_back(output_cnode); } @@ -76,11 +72,12 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con << " GenMask " << node->DebugString(); } } - RectifyKernelInfo(do_mask_node_list); + RectifyKernelInfo(do_mask_node_list, graph); return nullptr; } -void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector &do_mask_node_list) const { +void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector &do_mask_node_list, + const FuncGraphPtr &graph) const { std::map format_counter; std::string special_format; std::string convert_format; @@ -94,17 +91,6 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector &do_ } else { format_counter[do_mask_data_format] = format_counter[do_mask_data_format] + 1; } - // if has two or more special format we need change all domask's format to default that can avoid insert more - // transdata - if (format_counter.size() > 2) { - convert_format = kOpFormat_DEFAULT; - break; - } - if (kHWSpecialFormatSet.find(do_mask_data_format) != kHWSpecialFormatSet.end() && - special_format != do_mask_data_format) { - convert_format = kOpFormat_DEFAULT; - break; - } } if (format_counter.size() == 1) { return; @@ -112,17 +98,23 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector &do_ if (convert_format.empty()) { convert_format = GetConvertFormat(format_counter); } - RectifyDropOutDoMaskKernelInfo(do_mask_node_list, convert_format); + RectifyDropOutDoMaskKernelInfo(do_mask_node_list, convert_format, graph); } std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map &format_counter) const { - std::string convert_format; - const size_t counter = 0; + std::string convert_format = kOpFormat_DEFAULT; + size_t counter = 0; + if (format_counter.size() > 2) { + return kOpFormat_DEFAULT; + } + if (format_counter.size() == 2 && format_counter.find(kOpFormat_DEFAULT) == format_counter.end()) { + return kOpFormat_DEFAULT; + } for (const auto &iter : format_counter) { if (counter < iter.second) { convert_format = iter.first; - } - if (counter == iter.second && kHWSpecialFormatSet.find(convert_format) == kHWSpecialFormatSet.end()) { + counter = iter.second; + } else if (counter == iter.second && kHWSpecialFormatSet.find(iter.first) != kHWSpecialFormatSet.end()) { convert_format = iter.first; } } @@ -130,13 +122,17 @@ std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map &do_mask_node_list, - const std::string &format) const { + const std::string &format, + const FuncGraphPtr &graph) const { for (const auto &do_mask : do_mask_node_list) { - auto builder = - std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(do_mask)); - builder->SetInputFormat(format, 0); - builder->SetOutputFormat(format, 0); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), do_mask.get()); + if (AnfAlgo::GetInputFormat(do_mask, 0) != format) { + auto builder = + std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(do_mask)); + builder->SetInputFormat(format, 0); + builder->SetOutputFormat(format, 0); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), do_mask.get()); + ReSelecChildNodeKernelInfo(do_mask, graph); + } } } @@ -159,5 +155,30 @@ AnfNodePtr RectifyDoMaskKernelInfo::RectifyKernelInfoInPynativeProcess(const Anf } return nullptr; } + +void RectifyDoMaskKernelInfo::ReSelecChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const { + MS_EXCEPTION_IF_NULL(cnode); + auto output_node_list = GetRealNodeUsedList(graph, cnode); + MS_EXCEPTION_IF_NULL(output_node_list); + for (const auto &out_node_info : *output_node_list) { + MS_EXCEPTION_IF_NULL(out_node_info.first); + auto out_node = out_node_info.first->cast(); + if (AnfAlgo::IsRealKernel(out_node_info.first)) { + auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(out_node); + kernel_selecter->SelectKernel(out_node); + auto new_build_info = AnfAlgo::GetSelectKernelBuildInfo(out_node); + MS_EXCEPTION_IF_NULL(new_build_info); + MS_EXCEPTION_IF_NULL(ori_build_info); + if ((*new_build_info) != (*ori_build_info)) { + ReSelecChildNodeKernelInfo(out_node, graph); + } + } else if (AnfAlgo::GetCNodeName(out_node) == prim::kPrimTupleGetItem->name() || + AnfAlgo::GetCNodeName(out_node) == prim::kPrimDepend->name()) { + ReSelecChildNodeKernelInfo(out_node, graph); + } else { + MS_LOG(INFO) << "Reselected the node " << cnode->DebugString() << " failed"; + } + } +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h index 81bad4d8f892a4a71e54e4b064a53715d0d28acd..b03937db477d932cb579434ad50b36c9d9307c2f 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h @@ -19,23 +19,28 @@ #include #include #include +#include #include "pre_activate/common/optimizer.h" +#include "pre_activate/ascend/ascend_helper.h" namespace mindspore { namespace opt { class RectifyDoMaskKernelInfo : public PatternProcessPass { public: explicit RectifyDoMaskKernelInfo(bool multigraph = true) - : PatternProcessPass("batch_norm_bert_fission", multigraph) {} + : PatternProcessPass("batch_norm_bert_fission", multigraph), kernel_selecter(std::make_shared()) {} ~RectifyDoMaskKernelInfo() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; private: - void RectifyKernelInfo(const std::vector &do_mask_node_list) const; + void RectifyKernelInfo(const std::vector &do_mask_node_list, const FuncGraphPtr &graph) const; AnfNodePtr RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const; std::string GetConvertFormat(const std::map &format_counter) const; - void RectifyDropOutDoMaskKernelInfo(const std::vector &do_mask_node_list, const std::string &format) const; + void RectifyDropOutDoMaskKernelInfo(const std::vector &do_mask_node_list, const std::string &format, + const FuncGraphPtr &graph) const; + void ReSelecChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const; + KernelSelectPtr kernel_selecter; }; } // namespace opt } // namespace mindspore