diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.cc b/mindspore/ccsrc/backend/session/ascend_control_parser.cc index 08b23f3d673dcac5020e05ce4b518a05773ac172..ac2b09bcfb793c8b7d0b2996c5b64f22ae2cfe4d 100644 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.cc @@ -261,17 +261,16 @@ void AscendControlParser::EraseParameter(NotNull root_graph, } } - EraseAssign(all_nodes, para_to_written_node, root_graph); - root_graph->set_execution_order(exec_order); + EraseAssign(std::make_shared(parameter_count), all_nodes, para_to_written_node, root_graph); } -void AscendControlParser::EraseAssign(const std::set &all_nodes, +void AscendControlParser::EraseAssign(std::shared_ptr parameter_count, + const std::set &all_nodes, const std::map ¶_to_written_node, NotNull root_graph) { std::vector exec_order = root_graph->execution_order(); - ReferenceCounter parameter_count([](int32_t read, int32_t write) -> bool { return write == 1; }); - while (parameter_count.HasValidElem()) { - auto [para, read, written] = parameter_count.GetOneValidElem(); + while (parameter_count->HasValidElem()) { + auto [para, read, written] = parameter_count->GetOneValidElem(); MS_LOG(INFO) << para->DebugString() << " was read " << read << " times, written " << written << " times."; auto assign_iter = para_to_written_node.find(para); if (assign_iter == para_to_written_node.end()) { @@ -280,7 +279,7 @@ void AscendControlParser::EraseAssign(const std::set &all_nodes, auto &assign_node = assign_iter->second; MS_EXCEPTION_IF_NULL(assign_node); if (!IsPrimitiveCNode(assign_node, prim::kPrimAssign)) { - parameter_count.EraseElem(para); + parameter_count->EraseElem(para); continue; } MS_LOG(INFO) << "Erase " << assign_node->DebugString(5); @@ -288,10 +287,10 @@ void AscendControlParser::EraseAssign(const std::set &all_nodes, auto source = assign_node->input(kCNodeAssignSource); MS_EXCEPTION_IF_NULL(source); auto visit_source = AnfAlgo::VisitKernelWithReturnType(source, 0).first; - parameter_count.AddWriteCount(para, -1); - parameter_count.AddReadCount(para, -1); + parameter_count->AddWriteCount(para, -1); + parameter_count->AddReadCount(para, -1); if (visit_source->isa()) { - parameter_count.AddReadCount(visit_source, read - 1); + parameter_count->AddReadCount(visit_source, read - 1); } for (auto &node : all_nodes) { for (size_t i = 0; i < node->size(); ++i) { @@ -302,6 +301,7 @@ void AscendControlParser::EraseAssign(const std::set &all_nodes, } } } + root_graph->set_execution_order(exec_order); } void AscendControlParser::EraseLabel(NotNull root_graph) { diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.h b/mindspore/ccsrc/backend/session/ascend_control_parser.h index 555de4162297eecd4f901537e6b94b700dbe62b2..4a7a18817b4a862d1e8d0a885180184da428bb00 100644 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.h +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.h @@ -22,6 +22,7 @@ #include #include #include +#include #include "backend/session/kernel_graph.h" #include "base/base_ref.h" #include "utils/contract.h" @@ -44,7 +45,7 @@ class AscendControlParser { class ReferenceCounter; static void EraseParameter(NotNull root_graph, const std::set &graph_list); - static void EraseAssign(const std::set &all_nodes, + static void EraseAssign(std::shared_ptr parameter_count, const std::set &all_nodes, const std::map ¶_to_written_node, NotNull root_graph); static void EraseLabel(NotNull root_graph);