diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.cc b/mindspore/ccsrc/backend/session/ascend_control_parser.cc index 274b355679a4ce2ea7d7d036774d2806e63982de..4c6c7ab9cf41121b6c8a5d18d549dca443a00e65 100644 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.cc @@ -18,9 +18,12 @@ #include #include #include +#include #include "backend/session/anf_runtime_algorithm.h" #include "utils/union_find_set.h" #include "runtime/device/ascend/ascend_label_assign.h" +#include "utils/context/ms_context.h" +#include "debug/anf_ir_dump.h" static constexpr size_t kCNodePrim = 0; static constexpr size_t kCNodeCallArg = 1; @@ -248,10 +251,14 @@ void AscendControlParser::EraseParameter(NotNull root_graph, } MS_LOG(INFO) << "Erase " << assign_node->DebugString(5); EraseNodeFromExecOrder(assign_node, NOT_NULL(&exec_order)); - - auto source = AnfAlgo::VisitKernelWithReturnType(assign_node->input(kCNodeAssignSource), 0).first; - parameter_count.AddReadCount(source, -1); + auto source = assign_node->input(kCNodeAssignSource); + MS_EXCEPTION_IF_NULL(source); + auto visit_source = AnfAlgo::VisitKernelWithReturnType(source, 0).first; parameter_count.AddWriteCount(para, -1); + parameter_count.AddReadCount(para, -1); + if (visit_source->isa()) { + parameter_count.AddReadCount(visit_source, read - 1); + } for (auto &node : all_nodes) { for (size_t i = 0; i < node->size(); ++i) { if (node->input(i) == para) { @@ -260,8 +267,6 @@ void AscendControlParser::EraseParameter(NotNull root_graph, } } } - parameter_count.AddReadCount(source, 1); - parameter_count.AddReadCount(para, -1); } root_graph->set_execution_order(exec_order); } @@ -318,6 +323,17 @@ void AscendControlParser::ExecutorValidate(NotNull root_graph) { (void)RecurseGraph(root_graph, NOT_NULL(&memo)); EraseParameter(root_graph, memo); EraseLabel(root_graph); + + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (context_ptr->save_graphs_flag()) { + std::string file_path = save_graphs_path + "/after_erase_label_and_parameter.ir"; + DumpIR(file_path, root_graph.get()); + } } std::vector>> AscendControlParser::ParseCallNode(