提交 f92c4a53 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2862 reselect the kernel after rectify the build info of domask

Merge pull request !2862 from lianliguang/reselect-the-kernel-info-after-rectify-domask-kernel
...@@ -119,6 +119,8 @@ bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_ ...@@ -119,6 +119,8 @@ bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_
bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); } bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); }
bool KernelBuildInfo::operator!=(const KernelBuildInfo &other) const { return !((*this) == other); }
void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) { void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) {
MS_EXCEPTION_IF_NULL(kernel_build_info_); MS_EXCEPTION_IF_NULL(kernel_build_info_);
kernel_build_info_->kernel_type_ = kernel_type; kernel_build_info_->kernel_type_ = kernel_type;
......
...@@ -85,6 +85,8 @@ class KernelBuildInfo { ...@@ -85,6 +85,8 @@ class KernelBuildInfo {
bool operator==(const KernelBuildInfo &other) const; bool operator==(const KernelBuildInfo &other) const;
bool operator!=(const KernelBuildInfo &other) const;
public: public:
static auto constexpr kInvalidFormat = "InvalidFormat"; static auto constexpr kInvalidFormat = "InvalidFormat";
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "utils/utils.h" #include "utils/utils.h"
#include "kernel/common_utils.h" #include "kernel/common_utils.h"
#include "utils/context/ms_context.h" #include "utils/context/ms_context.h"
#include "pre_activate/common/helper.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
...@@ -50,16 +51,11 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con ...@@ -50,16 +51,11 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con
return nullptr; return nullptr;
} }
std::vector<CNodePtr> do_mask_node_list; std::vector<CNodePtr> do_mask_node_list;
auto manager = graph->manager(); auto gen_mask_output_nodes = GetRealNodeUsedList(graph, cnode);
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(gen_mask_output_nodes);
auto node_map = manager->node_users(); for (const auto &output_node : *gen_mask_output_nodes) {
auto iter = node_map.find(node);
if (iter == node_map.end()) {
MS_LOG(EXCEPTION) << "Cannot find the node " << node->DebugString() << " in the graph manager!";
}
auto gen_mask_output_nodes = iter->second;
for (const auto &output_node : gen_mask_output_nodes) {
if (AnfAlgo::GetCNodeName(output_node.first) == prim::kPrimDropoutDoMask->name()) { if (AnfAlgo::GetCNodeName(output_node.first) == prim::kPrimDropoutDoMask->name()) {
MS_EXCEPTION_IF_NULL(output_node.first);
auto output_cnode = output_node.first->cast<CNodePtr>(); auto output_cnode = output_node.first->cast<CNodePtr>();
do_mask_node_list.push_back(output_cnode); do_mask_node_list.push_back(output_cnode);
} }
...@@ -76,11 +72,12 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con ...@@ -76,11 +72,12 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con
<< " GenMask " << node->DebugString(); << " GenMask " << node->DebugString();
} }
} }
RectifyKernelInfo(do_mask_node_list); RectifyKernelInfo(do_mask_node_list, graph);
return nullptr; return nullptr;
} }
void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list) const { void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list,
const FuncGraphPtr &graph) const {
std::map<std::string, size_t> format_counter; std::map<std::string, size_t> format_counter;
std::string special_format; std::string special_format;
std::string convert_format; std::string convert_format;
...@@ -94,17 +91,6 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_ ...@@ -94,17 +91,6 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_
} else { } else {
format_counter[do_mask_data_format] = format_counter[do_mask_data_format] + 1; format_counter[do_mask_data_format] = format_counter[do_mask_data_format] + 1;
} }
// if has two or more special format we need change all domask's format to default that can avoid insert more
// transdata
if (format_counter.size() > 2) {
convert_format = kOpFormat_DEFAULT;
break;
}
if (kHWSpecialFormatSet.find(do_mask_data_format) != kHWSpecialFormatSet.end() &&
special_format != do_mask_data_format) {
convert_format = kOpFormat_DEFAULT;
break;
}
} }
if (format_counter.size() == 1) { if (format_counter.size() == 1) {
return; return;
...@@ -112,17 +98,23 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_ ...@@ -112,17 +98,23 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_
if (convert_format.empty()) { if (convert_format.empty()) {
convert_format = GetConvertFormat(format_counter); convert_format = GetConvertFormat(format_counter);
} }
RectifyDropOutDoMaskKernelInfo(do_mask_node_list, convert_format); RectifyDropOutDoMaskKernelInfo(do_mask_node_list, convert_format, graph);
} }
std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string, size_t> &format_counter) const { std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string, size_t> &format_counter) const {
std::string convert_format; std::string convert_format = kOpFormat_DEFAULT;
const size_t counter = 0; size_t counter = 0;
if (format_counter.size() > 2) {
return kOpFormat_DEFAULT;
}
if (format_counter.size() == 2 && format_counter.find(kOpFormat_DEFAULT) == format_counter.end()) {
return kOpFormat_DEFAULT;
}
for (const auto &iter : format_counter) { for (const auto &iter : format_counter) {
if (counter < iter.second) { if (counter < iter.second) {
convert_format = iter.first; convert_format = iter.first;
} counter = iter.second;
if (counter == iter.second && kHWSpecialFormatSet.find(convert_format) == kHWSpecialFormatSet.end()) { } else if (counter == iter.second && kHWSpecialFormatSet.find(iter.first) != kHWSpecialFormatSet.end()) {
convert_format = iter.first; convert_format = iter.first;
} }
} }
...@@ -130,13 +122,17 @@ std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string ...@@ -130,13 +122,17 @@ std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string
} }
void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list,
const std::string &format) const { const std::string &format,
const FuncGraphPtr &graph) const {
for (const auto &do_mask : do_mask_node_list) { for (const auto &do_mask : do_mask_node_list) {
auto builder = if (AnfAlgo::GetInputFormat(do_mask, 0) != format) {
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(do_mask)); auto builder =
builder->SetInputFormat(format, 0); std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(do_mask));
builder->SetOutputFormat(format, 0); builder->SetInputFormat(format, 0);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), do_mask.get()); builder->SetOutputFormat(format, 0);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), do_mask.get());
ReSelecChildNodeKernelInfo(do_mask, graph);
}
} }
} }
...@@ -159,5 +155,30 @@ AnfNodePtr RectifyDoMaskKernelInfo::RectifyKernelInfoInPynativeProcess(const Anf ...@@ -159,5 +155,30 @@ AnfNodePtr RectifyDoMaskKernelInfo::RectifyKernelInfoInPynativeProcess(const Anf
} }
return nullptr; return nullptr;
} }
void RectifyDoMaskKernelInfo::ReSelecChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(cnode);
auto output_node_list = GetRealNodeUsedList(graph, cnode);
MS_EXCEPTION_IF_NULL(output_node_list);
for (const auto &out_node_info : *output_node_list) {
MS_EXCEPTION_IF_NULL(out_node_info.first);
auto out_node = out_node_info.first->cast<CNodePtr>();
if (AnfAlgo::IsRealKernel(out_node_info.first)) {
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(out_node);
kernel_selecter->SelectKernel(out_node);
auto new_build_info = AnfAlgo::GetSelectKernelBuildInfo(out_node);
MS_EXCEPTION_IF_NULL(new_build_info);
MS_EXCEPTION_IF_NULL(ori_build_info);
if ((*new_build_info) != (*ori_build_info)) {
ReSelecChildNodeKernelInfo(out_node, graph);
}
} else if (AnfAlgo::GetCNodeName(out_node) == prim::kPrimTupleGetItem->name() ||
AnfAlgo::GetCNodeName(out_node) == prim::kPrimDepend->name()) {
ReSelecChildNodeKernelInfo(out_node, graph);
} else {
MS_LOG(INFO) << "Reselected the node " << cnode->DebugString() << " failed";
}
}
}
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
...@@ -19,23 +19,28 @@ ...@@ -19,23 +19,28 @@
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory>
#include "pre_activate/common/optimizer.h" #include "pre_activate/common/optimizer.h"
#include "pre_activate/ascend/ascend_helper.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class RectifyDoMaskKernelInfo : public PatternProcessPass { class RectifyDoMaskKernelInfo : public PatternProcessPass {
public: public:
explicit RectifyDoMaskKernelInfo(bool multigraph = true) explicit RectifyDoMaskKernelInfo(bool multigraph = true)
: PatternProcessPass("batch_norm_bert_fission", multigraph) {} : PatternProcessPass("batch_norm_bert_fission", multigraph), kernel_selecter(std::make_shared<KernelSelect>()) {}
~RectifyDoMaskKernelInfo() override = default; ~RectifyDoMaskKernelInfo() override = default;
const BaseRef DefinePattern() const override; const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private: private:
void RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list) const; void RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, const FuncGraphPtr &graph) const;
AnfNodePtr RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const; AnfNodePtr RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const;
std::string GetConvertFormat(const std::map<std::string, size_t> &format_counter) const; std::string GetConvertFormat(const std::map<std::string, size_t> &format_counter) const;
void RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, const std::string &format) const; void RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, const std::string &format,
const FuncGraphPtr &graph) const;
void ReSelecChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const;
KernelSelectPtr kernel_selecter;
}; };
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册