提交 ee79023e 编写于 作者: W WilliamLian

clean pclint warning

上级 e9670f3c
...@@ -37,14 +37,12 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con ...@@ -37,14 +37,12 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
std::unordered_set<AnfNodePtr> record{cnode}; std::unordered_set<AnfNodePtr> record{cnode};
auto write_input = cnode->input(1); auto write_input = cnode->input(1);
if (CheckEltWiseNode(manager.get(), write_input)) { if (CheckEltWiseNode(manager.get(), write_input)) {
(void)record.insert(write_input); (void)record.insert(write_input);
auto input_cnode = write_input->cast<CNodePtr>(); auto input_cnode = write_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode); MS_EXCEPTION_IF_NULL(input_cnode);
write_input = input_cnode->input(1); write_input = input_cnode->input(1);
} }
MS_EXCEPTION_IF_NULL(write_input); MS_EXCEPTION_IF_NULL(write_input);
if (!write_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(write_input) || if (!write_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(write_input) ||
fusion_id_allocator->HasFusionIdAttr(write_input)) { fusion_id_allocator->HasFusionIdAttr(write_input)) {
...@@ -63,7 +61,6 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con ...@@ -63,7 +61,6 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
fusion_id_allocator->HasFusionIdAttr(conv_input)) { fusion_id_allocator->HasFusionIdAttr(conv_input)) {
return; return;
} }
if (AnfAlgo::GetCNodeName(conv_input) == kStridedReadOpName) { if (AnfAlgo::GetCNodeName(conv_input) == kStridedReadOpName) {
(void)record.insert(conv_input); (void)record.insert(conv_input);
candidate_fusion->push_back(record); candidate_fusion->push_back(record);
......
...@@ -44,18 +44,7 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con ...@@ -44,18 +44,7 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode) { if (ms_context->execution_mode() == kPynativeMode) {
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutDoMask->name()) { return RectifyKernelInfoInPynativeProcess(node);
return nullptr;
}
auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0);
if (do_mask_input_format != kOpFormat_DEFAULT) {
auto builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
builder->SetInputFormat(kOpFormat_DEFAULT, 0);
builder->SetOutputFormat(kOpFormat_DEFAULT, 0);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
}
return nullptr;
} }
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) { if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) {
return nullptr; return nullptr;
...@@ -139,6 +128,7 @@ std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string ...@@ -139,6 +128,7 @@ std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string
} }
return convert_format; return convert_format;
} }
void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list,
const std::string &format) const { const std::string &format) const {
for (const auto &do_mask : do_mask_node_list) { for (const auto &do_mask : do_mask_node_list) {
...@@ -150,5 +140,24 @@ void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<C ...@@ -150,5 +140,24 @@ void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<C
} }
} }
AnfNodePtr RectifyDoMaskKernelInfo::RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
return nullptr;
}
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutDoMask->name()) {
return nullptr;
}
auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0);
if (do_mask_input_format != kOpFormat_DEFAULT) {
auto builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
builder->SetInputFormat(kOpFormat_DEFAULT, 0);
builder->SetOutputFormat(kOpFormat_DEFAULT, 0);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
}
return nullptr;
}
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
...@@ -33,6 +33,7 @@ class RectifyDoMaskKernelInfo : public PatternProcessPass { ...@@ -33,6 +33,7 @@ class RectifyDoMaskKernelInfo : public PatternProcessPass {
private: private:
void RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list) const; void RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list) const;
AnfNodePtr RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const;
std::string GetConvertFormat(const std::map<std::string, size_t> &format_counter) const; std::string GetConvertFormat(const std::map<std::string, size_t> &format_counter) const;
void RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, const std::string &format) const; void RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, const std::string &format) const;
}; };
......
...@@ -112,32 +112,13 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con ...@@ -112,32 +112,13 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
} }
auto input_num = AnfAlgo::GetInputTensorNum(depend_cnode); auto input_num = AnfAlgo::GetInputTensorNum(depend_cnode);
while (index < input_num) { while (index < input_num) {
auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index); auto replace_node = GetConvertNode(func_graph, node, index);
++index; MS_EXCEPTION_IF_NULL(replace_node);
MS_EXCEPTION_IF_NULL(replacing_node);
if (!replacing_node->isa<CNode>()) {
new_depend_inputs.push_back(replacing_node);
continue;
}
auto replacing_cnode = replacing_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(replacing_cnode);
// Deal with the make_tuple with TransData or Cast inputs.
auto make_tuple_replace_node = ReplaceMakeTuple(func_graph, replacing_cnode);
if (make_tuple_replace_node != nullptr) {
new_depend_inputs.push_back(make_tuple_replace_node);
continue;
}
AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
if (replace_node == nullptr) {
new_depend_inputs.push_back(replacing_node);
MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: "
<< node->DebugString();
continue;
}
new_depend_inputs.push_back(replace_node); new_depend_inputs.push_back(replace_node);
++index;
} }
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>(); auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
CNodePtr new_depend; CNodePtr new_depend = nullptr;
if (kernel_graph == nullptr) { if (kernel_graph == nullptr) {
new_depend = func_graph->NewCNode(new_depend_inputs); new_depend = func_graph->NewCNode(new_depend_inputs);
MS_EXCEPTION_IF_NULL(new_depend); MS_EXCEPTION_IF_NULL(new_depend);
...@@ -150,5 +131,31 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con ...@@ -150,5 +131,31 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
} }
return new_depend; return new_depend;
} }
const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node,
const size_t index) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto depend_cnode = node->cast<CNodePtr>();
auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index);
MS_EXCEPTION_IF_NULL(replacing_node);
if (!replacing_node->isa<CNode>()) {
return replacing_node;
}
auto replacing_cnode = replacing_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(replacing_cnode);
// Deal with the make_tuple with TransData or Cast inputs.
auto make_tuple_replace_node = ReplaceMakeTuple(graph, replacing_cnode);
if (make_tuple_replace_node != nullptr) {
return make_tuple_replace_node;
}
AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
if (replace_node == nullptr) {
MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString();
return replacing_node;
}
return replace_node;
}
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
...@@ -27,6 +27,7 @@ class OptimizeDependence : public PatternProcessPass { ...@@ -27,6 +27,7 @@ class OptimizeDependence : public PatternProcessPass {
~OptimizeDependence() override = default; ~OptimizeDependence() override = default;
const BaseRef DefinePattern() const override; const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
const AnfNodePtr GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const size_t index) const;
}; };
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册