提交 ee79023e 编写于 作者: W WilliamLian

clean pclint warning

上级 e9670f3c
......@@ -37,14 +37,12 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
MS_EXCEPTION_IF_NULL(manager);
std::unordered_set<AnfNodePtr> record{cnode};
auto write_input = cnode->input(1);
if (CheckEltWiseNode(manager.get(), write_input)) {
(void)record.insert(write_input);
auto input_cnode = write_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
write_input = input_cnode->input(1);
}
MS_EXCEPTION_IF_NULL(write_input);
if (!write_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(write_input) ||
fusion_id_allocator->HasFusionIdAttr(write_input)) {
......@@ -63,7 +61,6 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
fusion_id_allocator->HasFusionIdAttr(conv_input)) {
return;
}
if (AnfAlgo::GetCNodeName(conv_input) == kStridedReadOpName) {
(void)record.insert(conv_input);
candidate_fusion->push_back(record);
......
......@@ -44,18 +44,7 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode) {
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutDoMask->name()) {
return nullptr;
}
auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0);
if (do_mask_input_format != kOpFormat_DEFAULT) {
auto builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
builder->SetInputFormat(kOpFormat_DEFAULT, 0);
builder->SetOutputFormat(kOpFormat_DEFAULT, 0);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
}
return nullptr;
return RectifyKernelInfoInPynativeProcess(node);
}
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) {
return nullptr;
......@@ -139,6 +128,7 @@ std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string
}
return convert_format;
}
void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list,
const std::string &format) const {
for (const auto &do_mask : do_mask_node_list) {
......@@ -150,5 +140,24 @@ void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<C
}
}
AnfNodePtr RectifyDoMaskKernelInfo::RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
return nullptr;
}
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutDoMask->name()) {
return nullptr;
}
auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0);
if (do_mask_input_format != kOpFormat_DEFAULT) {
auto builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
builder->SetInputFormat(kOpFormat_DEFAULT, 0);
builder->SetOutputFormat(kOpFormat_DEFAULT, 0);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
}
return nullptr;
}
} // namespace opt
} // namespace mindspore
......@@ -33,6 +33,7 @@ class RectifyDoMaskKernelInfo : public PatternProcessPass {
private:
void RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list) const;
AnfNodePtr RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const;
std::string GetConvertFormat(const std::map<std::string, size_t> &format_counter) const;
void RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, const std::string &format) const;
};
......
......@@ -112,32 +112,13 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
}
auto input_num = AnfAlgo::GetInputTensorNum(depend_cnode);
while (index < input_num) {
auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index);
++index;
MS_EXCEPTION_IF_NULL(replacing_node);
if (!replacing_node->isa<CNode>()) {
new_depend_inputs.push_back(replacing_node);
continue;
}
auto replacing_cnode = replacing_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(replacing_cnode);
// Deal with the make_tuple with TransData or Cast inputs.
auto make_tuple_replace_node = ReplaceMakeTuple(func_graph, replacing_cnode);
if (make_tuple_replace_node != nullptr) {
new_depend_inputs.push_back(make_tuple_replace_node);
continue;
}
AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
if (replace_node == nullptr) {
new_depend_inputs.push_back(replacing_node);
MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: "
<< node->DebugString();
continue;
}
auto replace_node = GetConvertNode(func_graph, node, index);
MS_EXCEPTION_IF_NULL(replace_node);
new_depend_inputs.push_back(replace_node);
++index;
}
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
CNodePtr new_depend;
CNodePtr new_depend = nullptr;
if (kernel_graph == nullptr) {
new_depend = func_graph->NewCNode(new_depend_inputs);
MS_EXCEPTION_IF_NULL(new_depend);
......@@ -150,5 +131,31 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
}
return new_depend;
}
const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node,
const size_t index) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto depend_cnode = node->cast<CNodePtr>();
auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index);
MS_EXCEPTION_IF_NULL(replacing_node);
if (!replacing_node->isa<CNode>()) {
return replacing_node;
}
auto replacing_cnode = replacing_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(replacing_cnode);
// Deal with the make_tuple with TransData or Cast inputs.
auto make_tuple_replace_node = ReplaceMakeTuple(graph, replacing_cnode);
if (make_tuple_replace_node != nullptr) {
return make_tuple_replace_node;
}
AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
if (replace_node == nullptr) {
MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString();
return replacing_node;
}
return replace_node;
}
} // namespace opt
} // namespace mindspore
......@@ -27,6 +27,7 @@ class OptimizeDependence : public PatternProcessPass {
~OptimizeDependence() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
const AnfNodePtr GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const size_t index) const;
};
} // namespace opt
} // namespace mindspore
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册