diff --git a/mindspore/ccsrc/device/ascend/ascend_label_assign.cc b/mindspore/ccsrc/device/ascend/ascend_label_assign.cc index 39e9feb73d7450ca3c228b88cf2fb7a513a1f8ad..c985020091f4dbdac31a24e23707f6de2b5acf43 100644 --- a/mindspore/ccsrc/device/ascend/ascend_label_assign.cc +++ b/mindspore/ccsrc/device/ascend/ascend_label_assign.cc @@ -148,8 +148,8 @@ uint32_t AscendLabelAssign::GetLabelNum(NotNull gr std::lock_guard lock(label_num_mutex_); auto iter = label_num_.find(graph.get()); if (iter == label_num_.end()) { - MS_LOG(DEBUG) << "Graph " << graph->ToString() << " has not assigned label, defalut is 1."; - return 1; + MS_LOG(DEBUG) << "Graph " << graph->ToString() << " has not assigned label, defalut is 0."; + return 0; } return iter->second; } diff --git a/mindspore/ccsrc/session/ascend_control_parser.cc b/mindspore/ccsrc/session/ascend_control_parser.cc index b0dc1cc523c6afef5b50b83252f34f99d0d06028..02217291f3b89467b9ca709fdcea58dc04c61a32 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/session/ascend_control_parser.cc @@ -40,7 +40,7 @@ static void InitUnionFindSet(NotNull kg, const NotNullinsert(kg.get()); - const std::map> &real_inputs = kg->real_inputs(); + const std::map> &real_inputs = kg->real_inputs(); for (auto &iter : real_inputs) { auto ¶ = iter.first; if (para->isa()) { @@ -65,7 +65,7 @@ static void UnionParentParameter(NotNull kg, const NotNullinsert(kg.get()); - const std::map> &real_inputs = kg->real_inputs(); + const std::map> &real_inputs = kg->real_inputs(); for (auto &iter : real_inputs) { auto ¶ = iter.first; for (auto &arg : iter.second) { @@ -174,10 +174,14 @@ void AscendControlParser::ChildGraphDataAssign(const std::mapreal_inputs(); - for (auto &it : real_inputs) { - auto ¶meter = it.first; - auto &args = it.second; + const std::map> &real_inputs = kg->real_inputs(); + for (auto &in : kg->inputs()) { + auto it = real_inputs.find(in); + if (it == real_inputs.end()) { + continue; + } + auto ¶meter = it->first; + auto &args = it->second; for (auto &arg : args) { MS_EXCEPTION_IF_NULL(arg); if (arg->isa()) { diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index 9adf3ca97b980819dde8661f791fd847c56d788f..9f3b5bbac48e0a2ed7d81bf105bb282aee2e2cb5 100644 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -677,13 +677,13 @@ void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &ar MS_EXCEPTION_IF_NULL(parameter); MS_EXCEPTION_IF_NULL(arg); if (real_inputs_.find(parameter) == real_inputs_.end()) { - real_inputs_[parameter] = std::set(); + real_inputs_[parameter] = std::vector(); } auto &args = real_inputs_[parameter]; - (void)args.insert(arg); + (void)args.push_back(arg); } -std::set KernelGraph::GetRealInput(const AnfNodePtr ¶meter) { +std::vector KernelGraph::GetRealInput(const AnfNodePtr ¶meter) { MS_EXCEPTION_IF_NULL(parameter); auto iter = real_inputs_.find(parameter); if (iter != real_inputs_.end()) { @@ -694,7 +694,7 @@ std::set KernelGraph::GetRealInput(const AnfNodePtr ¶meter) { void KernelGraph::UpdateCallRealInput() { MS_LOG(INFO) << "Update graph id: " << graph_id_; - std::map> real_inputs_map; + std::map> real_inputs_map; for (auto &it : real_inputs_) { auto parameter = it.first; MS_EXCEPTION_IF_NULL(parameter); @@ -713,12 +713,18 @@ void KernelGraph::UpdateCallRealInput() { } for (auto &erase_node : erase_real_inputs) { MS_LOG(INFO) << "paramter: " << parameter->DebugString() << " erase real input:" << erase_node->DebugString(); - (void)real_inputs.erase(erase_node); + for (auto iter = real_inputs.begin(); iter != real_inputs.end();) { + if (*iter == erase_node) { + iter = real_inputs.erase(iter); + } else { + ++iter; + } + } } for (auto &new_real_input : new_real_inputs) { MS_LOG(INFO) << "paramter: " << parameter->DebugString() << " insert real input:" << new_real_input->DebugString(); - (void)real_inputs.insert(new_real_input); + (void)real_inputs.push_back(new_real_input); } real_inputs_map[parameter] = real_inputs; } @@ -730,18 +736,28 @@ void KernelGraph::PrintGraphExecuteOrder() const { for (size_t i = 0; i < execution_order_.size(); i++) { CNodePtr cur_cnode_ptr = execution_order_[i]; MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kSendOpName || AnfAlgo::GetCNodeName(cur_cnode_ptr) == kRecvOpName) { - auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); - MS_LOG(INFO) << "index[" << i << "], node name[" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "], logic id[" - << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id[" - << AnfAlgo::GetStreamId(cur_cnode_ptr) << "], event_id[" - << GetValue(primitive->GetAttr(kAttrEventId)) << "], node info[" - << cur_cnode_ptr->DebugString() << "]"; - } else { - MS_LOG(INFO) << "index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id[" - << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id[" - << AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]"; + std::string event_str; + std::string label_str; + if (AnfAlgo::HasNodeAttr(kAttrEventId, cur_cnode_ptr)) { + event_str = ", event_id[" + std::to_string(AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrEventId)) + "]"; + } + + if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_cnode_ptr)) { + label_str = ", label_id[" + std::to_string(AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrLabelIndex)) + "]"; } + + if (AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cur_cnode_ptr)) { + auto label_list = AnfAlgo::GetNodeAttr>(cur_cnode_ptr, kAttrLabelSwitchList); + label_str = ", label_id["; + for (size_t j = 0; j < label_list.size(); ++j) { + label_str += std::to_string(label_list[j]) + (j + 1 < label_list.size() ? ", " : "]"); + } + } + + MS_LOG(INFO) << "index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id[" + << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id[" + << AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]" + << event_str << label_str; } } diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h index 2cd7a2340a3d12b5f101fa2f19678c24c64b7697..9c52020898ac818b58741abf0ca14e7fe9e10ce3 100644 --- a/mindspore/ccsrc/session/kernel_graph.h +++ b/mindspore/ccsrc/session/kernel_graph.h @@ -127,8 +127,8 @@ class KernelGraph : public FuncGraph { // find anf node in graph std::vector FindNodeByPrimitive(const PrimitivePtr &primitive) const; // get real inputs - const std::map> &real_inputs() const { return real_inputs_; } - std::set GetRealInput(const AnfNodePtr ¶meter); + const std::map> &real_inputs() const { return real_inputs_; } + std::vector GetRealInput(const AnfNodePtr ¶meter); void SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg); // used to dump ir std::string ToString() const override; @@ -194,7 +194,7 @@ class KernelGraph : public FuncGraph { // parameter graph std::shared_ptr parent_graph_; // record real parameters,inputs_ is the formal parameters - std::map> real_inputs_; + std::map> real_inputs_; CNodePtr start_label_; CNodePtr end_goto_;