diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc index 765b0d8c6649e46e0ba3d57c0f27078db47f01e7..5bc0fdced7e96eb296909a3763c6d0860ffe3947 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc @@ -37,14 +37,12 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con MS_EXCEPTION_IF_NULL(manager); std::unordered_set record{cnode}; auto write_input = cnode->input(1); - if (CheckEltWiseNode(manager.get(), write_input)) { (void)record.insert(write_input); auto input_cnode = write_input->cast(); MS_EXCEPTION_IF_NULL(input_cnode); write_input = input_cnode->input(1); } - MS_EXCEPTION_IF_NULL(write_input); if (!write_input->isa() || !AnfAlgo::IsRealCNodeKernel(write_input) || fusion_id_allocator->HasFusionIdAttr(write_input)) { @@ -63,7 +61,6 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con fusion_id_allocator->HasFusionIdAttr(conv_input)) { return; } - if (AnfAlgo::GetCNodeName(conv_input) == kStridedReadOpName) { (void)record.insert(conv_input); candidate_fusion->push_back(record); diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc index 32e4987f5a29493bb3d362b472ab13e7cdd19afd..d81a8c90cea827e51442b538c18402131c577bd3 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc @@ -44,18 +44,7 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); if (ms_context->execution_mode() == kPynativeMode) { - if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutDoMask->name()) { - return nullptr; - } - auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0); - if (do_mask_input_format != kOpFormat_DEFAULT) { - auto builder = - std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(node)); - builder->SetInputFormat(kOpFormat_DEFAULT, 0); - builder->SetOutputFormat(kOpFormat_DEFAULT, 0); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); - } - return nullptr; + return RectifyKernelInfoInPynativeProcess(node); } if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) { return nullptr; @@ -139,6 +128,7 @@ std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map &do_mask_node_list, const std::string &format) const { for (const auto &do_mask : do_mask_node_list) { @@ -150,5 +140,24 @@ void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vectorcast(); + if (cnode == nullptr) { + return nullptr; + } + if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutDoMask->name()) { + return nullptr; + } + auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0); + if (do_mask_input_format != kOpFormat_DEFAULT) { + auto builder = + std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(node)); + builder->SetInputFormat(kOpFormat_DEFAULT, 0); + builder->SetOutputFormat(kOpFormat_DEFAULT, 0); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); + } + return nullptr; +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h index 83f7e397bd4025309fdcc3e88568cfc301f6cc2e..81bad4d8f892a4a71e54e4b064a53715d0d28acd 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h @@ -33,6 +33,7 @@ class RectifyDoMaskKernelInfo : public PatternProcessPass { private: void RectifyKernelInfo(const std::vector &do_mask_node_list) const; + AnfNodePtr RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const; std::string GetConvertFormat(const std::map &format_counter) const; void RectifyDropOutDoMaskKernelInfo(const std::vector &do_mask_node_list, const std::string &format) const; }; diff --git a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc b/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc index ee480b9c8667d3fea3c258339f79bf67f4db1ad1..1d5f909e7d28c1c28a7039089eef7dc15f523ce8 100644 --- a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc +++ b/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc @@ -112,32 +112,13 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con } auto input_num = AnfAlgo::GetInputTensorNum(depend_cnode); while (index < input_num) { - auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index); - ++index; - MS_EXCEPTION_IF_NULL(replacing_node); - if (!replacing_node->isa()) { - new_depend_inputs.push_back(replacing_node); - continue; - } - auto replacing_cnode = replacing_node->cast(); - MS_EXCEPTION_IF_NULL(replacing_cnode); - // Deal with the make_tuple with TransData or Cast inputs. - auto make_tuple_replace_node = ReplaceMakeTuple(func_graph, replacing_cnode); - if (make_tuple_replace_node != nullptr) { - new_depend_inputs.push_back(make_tuple_replace_node); - continue; - } - AnfNodePtr replace_node = GetReplaceNode(replacing_cnode); - if (replace_node == nullptr) { - new_depend_inputs.push_back(replacing_node); - MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " - << node->DebugString(); - continue; - } + auto replace_node = GetConvertNode(func_graph, node, index); + MS_EXCEPTION_IF_NULL(replace_node); new_depend_inputs.push_back(replace_node); + ++index; } auto kernel_graph = func_graph->cast>(); - CNodePtr new_depend; + CNodePtr new_depend = nullptr; if (kernel_graph == nullptr) { new_depend = func_graph->NewCNode(new_depend_inputs); MS_EXCEPTION_IF_NULL(new_depend); @@ -150,5 +131,31 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con } return new_depend; } + +const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node, + const size_t index) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto depend_cnode = node->cast(); + auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index); + MS_EXCEPTION_IF_NULL(replacing_node); + if (!replacing_node->isa()) { + return replacing_node; + } + auto replacing_cnode = replacing_node->cast(); + MS_EXCEPTION_IF_NULL(replacing_cnode); + // Deal with the make_tuple with TransData or Cast inputs. + auto make_tuple_replace_node = ReplaceMakeTuple(graph, replacing_cnode); + if (make_tuple_replace_node != nullptr) { + return make_tuple_replace_node; + } + AnfNodePtr replace_node = GetReplaceNode(replacing_cnode); + if (replace_node == nullptr) { + MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString(); + return replacing_node; + } + return replace_node; +} + } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.h b/mindspore/ccsrc/pre_activate/pass/optimize_dependence.h index d2995cdd302b5f9a87550bb2cfbdd204aa7d6bb8..30027b790aabdf55c7ed3f2c502acbb59f30fed3 100644 --- a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.h +++ b/mindspore/ccsrc/pre_activate/pass/optimize_dependence.h @@ -27,6 +27,7 @@ class OptimizeDependence : public PatternProcessPass { ~OptimizeDependence() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + const AnfNodePtr GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const size_t index) const; }; } // namespace opt } // namespace mindspore