diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc index 0744a0c2a305a2472c8526d57d997a5b90d026ae..99be7bd3a2168a7b7e7a45cdf1d3f27126a0c58d 100644 --- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc @@ -426,23 +426,25 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { return true; } AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); - AscendStreamMng &stream_manager = AscendStreamMng::GetInstance(); + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); AscendLabelAssign &label_assign_instance = AscendLabelAssign::GetInstance(); // the streams' flag not HEAD_STREAM std::vector wait_active_stream_list; assign_instance.GetWaitStreams(&wait_active_stream_list); std::vector force_copy_stream_list; assign_instance.GetHcomStreams(&force_copy_stream_list); - MS_LOG(INFO) << "call DavinciModel total stream num:" << stream_manager.GetCurAllocStreamNum() - << ", total event num:" << assign_instance.total_event_num() + + MS_LOG(INFO) << "call DavinciModel total stream num:" << resource_manager.get_cur_stream_num() + << ", total event num:" << resource_manager.get_cur_event_num() << ", total label num:" << label_assign_instance.GetLabelNum(NOT_NULL(graph)) << ", wait_active_stream_list size:" << wait_active_stream_list.size() << ", force_copy_stream_list size:" << force_copy_stream_list.size(); std::vector> empty_list; std::shared_ptr model = std::make_shared( task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, - 0, 0, 0, 0, 0, stream_manager.GetCurAllocStreamNum(), label_assign_instance.GetLabelNum(NOT_NULL(graph)), - assign_instance.total_event_num(), 0); + 0, 0, 0, 0, 0, resource_manager.get_cur_stream_num(), label_assign_instance.GetLabelNum(NOT_NULL(graph)), + resource_manager.get_cur_event_num(), 0); + auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); if (!ret.second) { MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session."; diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc index f0bad6b492e4208d07c199978c5ea1d9bb92d542..8f8f022bdba198ac82b630a1e20f66e46c2c86b9 100644 --- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc @@ -34,110 +34,136 @@ namespace ascend { const uint32_t kHcomMaxTask = 5; const uint32_t kCommonMaxTask = 350; -void AscendStreamAssign::AssignStream(const shared_ptr &graph_ptr) { +void AscendStreamAssign::AssignStream(const NotNull &graph_ptr) { if (IsTaskSink()) { Reset(); ReorderIndependentOrders(graph_ptr); AssignAllNodesStream(graph_ptr); UpdateAtomicAddrCleanStreamId(graph_ptr); - FindHcomParallelStreams(graph_ptr); InsertStreamActive(graph_ptr); - InsertSendRecvForHcomParallel(graph_ptr); - InsertSendRecvForIndependent(graph_ptr); - UpdateEventId(graph_ptr); + InsertEventForHcomParallel(graph_ptr); + InsertEventForIndependentParallel(graph_ptr); GetNeedActiveStreams(graph_ptr); graph_ptr->PrintGraphExecuteOrder(); - CheckStreamAssign(graph_ptr); + CheckResourceAssign(graph_ptr); MS_LOG(INFO) << "after finish stream assign"; // Get info for D Model - AscendStreamMng &stream_manager = AscendStreamMng::GetInstance(); - generator::IRModelUtil::GetInstance().set_event_num(total_event_num()); - generator::IRModelUtil::GetInstance().set_stream_num(stream_manager.GetCurAllocStreamNum()); + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + generator::IRModelUtil::GetInstance().set_event_num(resource_manager.get_cur_event_num()); + generator::IRModelUtil::GetInstance().set_stream_num(resource_manager.get_cur_stream_num()); // Init to 1,temporarily generator::IRModelUtil::GetInstance().set_batch_num(1); } } -// section 0 -void AscendStreamAssign::CheckStreamAssign(const shared_ptr &graph_ptr) { - MS_EXCEPTION_IF_NULL(graph_ptr); - std::set streams; - uint32_t max_stream = 0; - uint32_t min_stream = kInvalidStreamId; - const std::vector &cnode_ptr_list = graph_ptr->execution_order(); +// section 1 +void AscendStreamAssign::ReorderIndependentOrders(const NotNull &graph_ptr) { + std::vector exe_orders; + std::vector independents; + std::vector others; + + auto cnode_ptr_list = graph_ptr->execution_order(); + MS_LOG(INFO) << "before reorder, graph orders size:" << cnode_ptr_list.size(); for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; + auto cur_cnode_ptr = cnode_ptr_list[i]; MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - uint32_t stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); - if (stream_id == kInvalidStreamId) { - MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "] had not been assigned streams"; + if (IsIndependentNode(cur_cnode_ptr)) { + independents.emplace_back(cur_cnode_ptr); + } else { + others.emplace_back(cur_cnode_ptr); } + } - streams.emplace(stream_id); - if (stream_id > max_stream) { - max_stream = stream_id; - } - if (stream_id < min_stream) { - min_stream = stream_id; - } + if (others.empty() || independents.empty()) { + MS_LOG(INFO) << "independent or others is empty, no need reorder"; + return; } - if (!streams.empty()) { - if (min_stream != 0) { - MS_LOG(EXCEPTION) << "before stream assign, assigned stream should start from 0, now is from " << min_stream; + std::set processed; + for (size_t i = 0; i < others.size(); i++) { + auto begin = others.begin() + i; + auto end = begin + 1; + bool flag = false; + for (size_t j = 0; j < independents.size(); j++) { + auto cur_independent = independents[j]; + auto it = std::find(processed.begin(), processed.end(), cur_independent.get()); + if (it != processed.end()) { + continue; + } + + auto res = FindTargetOp(begin, end, cur_independent); + if (res != end) { + flag = true; + exe_orders.emplace_back(cur_independent); + exe_orders.emplace_back(*begin); + processed.emplace(cur_independent.get()); + break; + } } - if (max_stream != (streams.size() - 1)) { - MS_LOG(EXCEPTION) << "before stream assign, assigned stream should be consecutive"; + + if (!flag) { + exe_orders.emplace_back(*begin); } } + + MS_LOG(INFO) << "after reorder, graph orders size:" << exe_orders.size(); + if (processed.size() != independents.size()) { + MS_LOG(WARNING) << "processed independent nodes size is not equal to exiting independent nodes size"; + return; + } + + graph_ptr->set_execution_order(exe_orders); } -// section 1 -void AscendStreamAssign::AssignAllNodesStream(const shared_ptr &graph_ptr) { - MS_EXCEPTION_IF_NULL(graph_ptr); +// section 2 +void AscendStreamAssign::AssignAllNodesStream(const NotNull &graph_ptr) { auto cnode_ptr_list = graph_ptr->execution_order(); - CNodePtr pre_cnode_ptr = nullptr; - uint32_t cur_index = 0; - uint32_t cur_stream_id = 0; - bool exit_independent = false; - AscendStreamMng &stream_manager = AscendStreamMng::GetInstance(); + bool exit_hcom = false; + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + // node has been assigned stream before if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) { continue; } + + if (IsHcom(cur_cnode_ptr)) { + exit_hcom = true; + continue; + } + if (IsIndependentNode(cur_cnode_ptr)) { exit_independent = true; continue; } - // first common node, only exe one time - if (pre_cnode_ptr == nullptr) { - uint32_t cur_stream_num = stream_manager.GetCurAllocStreamNum(); - if (cur_stream_num == 0) { - cur_stream_id = stream_manager.ApplyNewStream(); - } else { - cur_stream_id = stream_manager.GetCurAllocStream(); + + AssignCommonStreamId(cur_cnode_ptr); + } + MS_LOG(INFO) << "common start from 0, common stream nums:" << resource_manager.get_cur_stream_num(); + + if (exit_hcom) { + uint32_t first_hcom_stream_id = resource_manager.ApplyNewStream(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; + // node has been assigned stream before + if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) { + continue; } - ++cur_index; - pre_cnode_ptr = cur_cnode_ptr; - AnfAlgo::SetStreamId(cur_stream_id, cur_cnode_ptr.get()); + if (IsHcom(cur_cnode_ptr)) { - hcom_stream_list_.emplace(cur_stream_id); + AssignHcomStreamId(cur_cnode_ptr); } - continue; } - - AssignCommonStreamId(cur_cnode_ptr, &pre_cnode_ptr, &cur_index, &cur_stream_id); + MS_LOG(INFO) << "hcom start from :" << first_hcom_stream_id << ", hcom stream nums:" << hcom_stream_map_.size(); } if (exit_independent) { - uint32_t first_independent_stream_id = stream_manager.ApplyNewStream(); + uint32_t first_independ = resource_manager.ApplyNewStream(); for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) { continue; } @@ -145,28 +171,75 @@ void AscendStreamAssign::AssignAllNodesStream(const shared_ptrsecond < kCommonMaxTask) { + AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get()); + it->second++; + } else { + cur_common_stream_id = resource_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(cur_common_stream_id, cur_cnode_ptr.get()); + common_stream_map_.insert(std::make_pair(cur_common_stream_id, 1)); + } } +} - MS_LOG(INFO) << "total stream nums:" << stream_manager.GetCurAllocStreamNum(); +void AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr) { + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + uint32_t cur_hcom_stream_id = resource_manager.GetCurAllocStreamId(); + auto it = hcom_stream_map_.find(cur_hcom_stream_id); + if (it == hcom_stream_map_.end()) { + AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get()); + hcom_stream_map_.insert(std::make_pair(cur_hcom_stream_id, 1)); + } else { + if (it->second < kHcomMaxTask) { + AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get()); + it->second++; + } else { + cur_hcom_stream_id = resource_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get()); + hcom_stream_map_.insert(std::make_pair(cur_hcom_stream_id, 1)); + } + } } void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr) { MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - AscendStreamMng &stream_manager = AscendStreamMng::GetInstance(); - uint32_t cur_independent_id = stream_manager.GetCurAllocStream(); + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + uint32_t cur_independent_id = resource_manager.GetCurAllocStreamId(); auto it = independent_stream_map_.find(cur_independent_id); if (it == independent_stream_map_.end()) { AnfAlgo::SetStreamId(cur_independent_id, cur_cnode_ptr.get()); - independent_stream_map_.emplace(cur_independent_id, 1); + independent_stream_map_.insert(std::make_pair(cur_independent_id, 1)); } else { if (it->second < kCommonMaxTask) { AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get()); it->second++; } else { - cur_independent_id = stream_manager.ApplyNewStream(); + cur_independent_id = resource_manager.ApplyNewStream(); AnfAlgo::SetStreamId(cur_independent_id, cur_cnode_ptr.get()); - independent_stream_map_.emplace(cur_independent_id, 1); + independent_stream_map_.insert(std::make_pair(cur_independent_id, 1)); } } } @@ -188,7 +261,7 @@ bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) { return true; } - const std::vector &inputs = node_ptr->inputs(); + auto inputs = node_ptr->inputs(); for (size_t i = 1; i < inputs.size(); i++) { if (!inputs[i]->isa()) { return false; @@ -198,86 +271,105 @@ bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) { return true; } -void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, - uint32_t *cur_index, uint32_t *cur_stream_id) { - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - MS_EXCEPTION_IF_NULL(pre_cnode_ptr); - MS_EXCEPTION_IF_NULL(*pre_cnode_ptr); - AscendStreamMng &stream_manager = AscendStreamMng::GetInstance(); - bool over_max_hcom_task = (IsHcom(cur_cnode_ptr) && (*cur_index) % kHcomMaxTask == 0); - bool over_max_common_task = (!IsHcom(cur_cnode_ptr) && (*cur_index) % kCommonMaxTask == 0); - bool pre_common_cur_hcom = (IsHcom(cur_cnode_ptr) && !IsHcom(*pre_cnode_ptr)); - bool pre_hcom_cur_common = (!IsHcom(cur_cnode_ptr) && IsHcom(*pre_cnode_ptr)); - if (over_max_hcom_task || over_max_common_task || pre_common_cur_hcom || pre_hcom_cur_common) { - *cur_index = 0; - *cur_stream_id = stream_manager.ApplyNewStream(); - } - - ++(*cur_index); - AnfAlgo::SetStreamId(*cur_stream_id, cur_cnode_ptr.get()); - *pre_cnode_ptr = cur_cnode_ptr; - - // record ll hcom streams as hcom stream has different stream flag - if (IsHcom(cur_cnode_ptr)) { - auto it = std::find(hcom_stream_list_.begin(), hcom_stream_list_.end(), *cur_stream_id); - if (it == hcom_stream_list_.end()) { - MS_LOG(INFO) << "hcom stream id:" << *cur_stream_id; - hcom_stream_list_.emplace(*cur_stream_id); - } - } -} - -// section 2: -void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const shared_ptr &graph_ptr) { +// section 3: +void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const NotNull &graph_ptr) { MS_LOG(INFO) << "start"; - MS_EXCEPTION_IF_NULL(graph_ptr); - const std::vector &cnode_ptr_list = graph_ptr->execution_order(); + auto cnode_ptr_list = graph_ptr->execution_order(); for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; MS_EXCEPTION_IF_NULL(cur_cnode_ptr); // update AtomicAddrClean stream same witch the next node if (i > 0 && AnfAlgo::GetCNodeName(cnode_ptr_list[i - 1]) == kAtomicAddrCleanOpName) { - MS_LOG(INFO) << "update AtomicAddrClean stream id from[" << AnfAlgo::GetStreamId(cnode_ptr_list[i - 1]) - << "] to [" << AnfAlgo::GetStreamId(cur_cnode_ptr) << "]"; AnfAlgo::SetStreamId(AnfAlgo::GetStreamId(cur_cnode_ptr), cnode_ptr_list[i - 1].get()); } } MS_LOG(INFO) << "end"; } -// section 3 -void AscendStreamAssign::FindHcomParallelStreams(const shared_ptr &graph_ptr) { - MS_EXCEPTION_IF_NULL(graph_ptr); +// section 4 +void AscendStreamAssign::InsertStreamActive(const NotNull &graph_ptr) { + MS_LOG(INFO) << "start"; + GetProcessedStream(graph_ptr); + std::vector update_cnode_list; CNodePtr cur_cnode_ptr = nullptr; CNodePtr pre_cnode_ptr = nullptr; uint32_t pre_stream_id = UINT32_MAX; + + bool independent_flag = !(independent_stream_map_.empty()); + bool hcom_flag = !(hcom_stream_map_.empty()); auto cnode_ptr_list = graph_ptr->execution_order(); - for (uint32_t i = 0; i < cnode_ptr_list.size(); ++i) { + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { cur_cnode_ptr = cnode_ptr_list[i]; MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); - if (i == 0) { - pre_cnode_ptr = cur_cnode_ptr; - pre_stream_id = cur_stream_id; + if (IsIndependentNode(cur_cnode_ptr)) { + update_cnode_list.emplace_back(cur_cnode_ptr); continue; } - bool pre_fusion_hcom = IsFusionHcom(pre_cnode_ptr); - bool diff_stream = (pre_stream_id != cur_stream_id); - if (diff_stream && pre_fusion_hcom) { - inner_parallel_streams_.emplace_back(std::vector{pre_stream_id, cur_stream_id}); + if (IsHcom(cur_cnode_ptr)) { + update_cnode_list.emplace_back(cur_cnode_ptr); + continue; + } + uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); + bool processed = IsProcessedStream(cur_stream_id); + // 1)inner stream assign, need insert active op + if (!processed) { + MS_LOG(INFO) << "common stream active info:" << pre_stream_id << "->active" << cur_stream_id; + CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); + // 1.set stream id + AnfAlgo::SetStreamId(pre_stream_id, active_ptr.get()); + // 2.set active stream ids + std::vector active_index_list{cur_stream_id}; + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_index_list), active_ptr); + update_cnode_list.emplace_back(active_ptr); } - pre_cnode_ptr = cur_cnode_ptr; + if ((independent_flag || hcom_flag) && (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName)) { + MS_LOG(INFO) << "Insert StreamActive op after FP StreamSwitch for stream parallel"; + UpdateStreamSwitch(graph_ptr, cur_cnode_ptr, &update_cnode_list); + } else { + update_cnode_list.emplace_back(cur_cnode_ptr); + } + + processed_streams_.emplace(cur_stream_id); pre_stream_id = cur_stream_id; + pre_cnode_ptr = cur_cnode_ptr; } + graph_ptr->set_execution_order(update_cnode_list); + MS_LOG(INFO) << "end"; } -// section 4 -void AscendStreamAssign::UpdateStreamSwitch(const std::shared_ptr &graph_ptr, - const CNodePtr &switch_ptr, const vector &independent_stream, +void AscendStreamAssign::GetProcessedStream(const NotNull &graph_ptr) { + // 0 stream is activated at first + processed_streams_.emplace(0); + auto cnode_ptr_list = graph_ptr->execution_order(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + auto cur_cnode_ptr = cnode_ptr_list[i]; + uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); + + if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { + auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); + MS_EXCEPTION_IF_NULL(primitive); + auto true_stream_id = GetValue(primitive->GetAttr(kAttrTrueBranchStream)); + processed_streams_.emplace(true_stream_id); + + auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst); + if (value_ptr == nullptr) { + continue; + } + auto need_active = GetValue(value_ptr); + if (need_active) { + processed_streams_.emplace(cur_stream_id); + } + } + } + for (const auto &item : processed_streams_) { + MS_LOG(INFO) << "before active:" << item << " is been processed"; + } +} + +void AscendStreamAssign::UpdateStreamSwitch(const NotNull &graph_ptr, const CNodePtr &switch_ptr, vector *orders) { - MS_EXCEPTION_IF_NULL(orders); orders->emplace_back(switch_ptr); auto primitive = AnfAlgo::GetCNodePrimitive(switch_ptr); MS_EXCEPTION_IF_NULL(primitive); @@ -291,203 +383,270 @@ void AscendStreamAssign::UpdateStreamSwitch(const std::shared_ptrDebugString() << "]"; MS_EXCEPTION_IF_NULL(switch_ptr); auto true_stream_id = GetValue(primitive->GetAttr(kAttrTrueBranchStream)); - MS_LOG(INFO) << "streamswtich stream id[" << AnfAlgo::GetStreamId(switch_ptr) << "], true_logic_id[" << true_stream_id - << "]"; + MS_LOG(INFO) << "streamswtich stream id:" << AnfAlgo::GetStreamId(switch_ptr) + << "; active stream id:" << true_stream_id; CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); - MS_LOG(INFO) << "start update StreamActive op[" << active_ptr->DebugString() << "]"; AnfAlgo::SetStreamId(true_stream_id, active_ptr.get()); - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(independent_stream), active_ptr); - independent_stream_activated_ = true; + vector active_ids; + // active indepdent stream + for (const auto &item : independent_stream_map_) { + active_ids.emplace_back(item.first); + } + // active hcom stream + for (const auto &item : hcom_stream_map_) { + active_ids.emplace_back(item.first); + } + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_ids), active_ptr); // update processed stream - for (auto &item : independent_stream) { - processed_streams_.emplace(item); + independent_stream_activated_ = true; + for (const auto &item : independent_stream_map_) { + processed_streams_.emplace(item.first); + } + + hcom_stream_activated_ = true; + for (const auto &item : hcom_stream_map_) { + processed_streams_.emplace(item.first); } orders->emplace_back(active_ptr); -} // namespace ascend +} -void AscendStreamAssign::InsertStreamActive(const std::shared_ptr &graph_ptr) { - MS_LOG(INFO) << "start"; - MS_EXCEPTION_IF_NULL(graph_ptr); - std::vector update_cnode_list; - CNodePtr cur_cnode_ptr = nullptr; - CNodePtr pre_cnode_ptr = nullptr; - uint32_t pre_stream_id = UINT32_MAX; - std::vector independent_stream; - MS_LOG(INFO) << "independent stream size:" << independent_stream_map_.size(); - for (auto item : independent_stream_map_) { - independent_stream.emplace_back(item.first); +bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) { + auto it = std::find(processed_streams_.begin(), processed_streams_.end(), stream_id); + if (it != processed_streams_.end()) { + return true; } + return false; +} + +// section5 +void AscendStreamAssign::InsertEventForHcomParallel(const NotNull &graph_ptr) { + MS_LOG(INFO) << "start"; + InsertEventCommonDependHcom(graph_ptr); + InsertEventHcomDependCommon(graph_ptr); + InsertEventHcomDependHcom(graph_ptr); + MS_LOG(INFO) << "end"; +} - bool independent_flag = !(independent_stream.empty()); +void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + auto cnode_ptr_list = graph_ptr->execution_order(); + vector cnodes = cnode_ptr_list; + uint32_t cur_event_id = resource_manager.ApplyNewEvent(); + auto it = cnodes.begin(); + while (it != cnodes.end() && (it + 1) != cnodes.end()) { + MS_EXCEPTION_IF_NULL(*it); + MS_EXCEPTION_IF_NULL(*(it + 1)); + if (IsHcom(*it) && !IsHcom(*(it + 1))) { + CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it)); + it = cnodes.insert(it + 1, send_cnode_ptr); + + auto target = FindTargetOp(it, cnodes.end(), *(it - 1)); + if (target == cnodes.end()) { + MS_LOG(WARNING) << "hcom node:" << (*(it - 1))->fullname_with_scope() + << ", can't find target for insert recv op, no insert send/recv"; + it = cnodes.erase(it); + continue; + } + + if (IsHcom(*target)) { + it = cnodes.erase(it); + continue; + } + + // deal recv op + uint32_t stream_id = AnfAlgo::GetStreamId(*target); + CNodePtr recv_cnode_ptr = CreateRecvApplyKernel(graph_ptr, cur_event_id, stream_id); + (void)cnodes.insert(target, recv_cnode_ptr); + cur_event_id = resource_manager.ApplyNewEvent(); + } + ++it; + } + // one event allocated additional, should delete + resource_manager.DeleteEvent(); + graph_ptr->set_execution_order(cnodes); + MS_LOG(INFO) << "after common depend hcom, total event nums:" << resource_manager.get_cur_event_num(); +} - const std::vector &cnode_ptr_list = graph_ptr->execution_order(); +void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + auto cnode_ptr_list = graph_ptr->execution_order(); + vector cnodes; + CNodePtr cur_cnode_ptr = nullptr; + uint32_t pre_stream_id = UINT32_MAX; for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); - if (IsIndependentNode(cur_cnode_ptr)) { - update_cnode_list.emplace_back(cur_cnode_ptr); + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + if (i == 0) { + cnodes.emplace_back(cur_cnode_ptr); + pre_stream_id = cur_stream_id; continue; } - bool inner_active = false; - if (pre_cnode_ptr != nullptr) { - inner_active = pre_stream_id != cur_stream_id && AnfAlgo::GetCNodeName(pre_cnode_ptr) != kStreamSwitchOpName && - AnfAlgo::GetCNodeName(pre_cnode_ptr) != kSendOpName; + if (!IsHcom(cur_cnode_ptr)) { + cnodes.emplace_back(cur_cnode_ptr); + pre_stream_id = cur_stream_id; + continue; } - bool processed = IsProcessedStream(cur_stream_id); - // 1)inner stream assign, need insert active op - if (inner_active && !processed) { - MS_LOG(INFO) << "Inner insert active op, self stream id[" << pre_stream_id << "]"; - CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); - // 1.set stream id - AnfAlgo::SetStreamId(pre_stream_id, active_ptr.get()); - // 2.set active stream ids - std::vector active_index_list; - GetParallelStream(cur_stream_id, pre_stream_id, &active_index_list); - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_index_list), active_ptr); - update_cnode_list.emplace_back(active_ptr); + if (cur_stream_id == pre_stream_id) { + cnodes.emplace_back(cur_cnode_ptr); + pre_stream_id = cur_stream_id; + continue; } - if (independent_flag && (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName)) { - MS_LOG(INFO) << "Insert StreamActive op after FP StreamSwitch for stream parallel"; - UpdateStreamSwitch(graph_ptr, cur_cnode_ptr, independent_stream, &update_cnode_list); - } else { - update_cnode_list.emplace_back(cur_cnode_ptr); + if (!IsHcom(cnode_ptr_list[i - 1])) { + uint32_t cur_event_id = resource_manager.ApplyNewEvent(); + auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, pre_stream_id); + cnodes.emplace_back(send); + auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_stream_id); + cnodes.emplace_back(recv); + cnodes.emplace_back(cur_cnode_ptr); } - - processed_streams_.emplace(cur_stream_id); pre_stream_id = cur_stream_id; - pre_cnode_ptr = cur_cnode_ptr; } - graph_ptr->set_execution_order(update_cnode_list); - MS_LOG(INFO) << "end"; -} -bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) { - auto it = std::find(processed_streams_.begin(), processed_streams_.end(), stream_id); - if (it != processed_streams_.end()) { - return true; - } - return false; -} - -void AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, - vector *parallel_streams) { - MS_EXCEPTION_IF_NULL(parallel_streams); - for (size_t i = 0; i < inner_parallel_streams_.size(); i++) { - const auto &cur_parallel_streams = inner_parallel_streams_[i]; - auto it = std::find(cur_parallel_streams.begin(), cur_parallel_streams.end(), cur_stream_id); - if (it != cur_parallel_streams.end()) { - MS_LOG(INFO) << "stream id:" << cur_stream_id << " is parallel stream"; - for (size_t j = 0; j < cur_parallel_streams.size(); j++) { - if (cur_parallel_streams[j] == stream_acitve_id) { - MS_LOG(INFO) << "one of parallel stream id" << cur_parallel_streams[j] - << "is same with streamacvite stream id" << stream_acitve_id; - continue; - } - (*parallel_streams).emplace_back(cur_parallel_streams[j]); - processed_streams_.emplace(cur_parallel_streams[j]); - } - return; - } - } - - processed_streams_.emplace(cur_stream_id); - (*parallel_streams).push_back(cur_stream_id); + graph_ptr->set_execution_order(cnodes); + MS_LOG(INFO) << "after hcom depend common, total event nums:" << resource_manager.get_cur_event_num(); } -// section5 -void AscendStreamAssign::InsertSendRecvForDiffHcom(const shared_ptr &graph_ptr) { - MS_LOG(INFO) << "start"; - MS_EXCEPTION_IF_NULL(graph_ptr); +void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); auto cnode_ptr_list = graph_ptr->execution_order(); - vector fusion_hcom_index; - vector orders; + uint32_t first_hcom_stream = kInvalidStreamId; + uint32_t last_hcom_stream = kInvalidStreamId; + // key: stream id, value:hcom index + std::map> hcom_index; for (size_t i = 0; i < cnode_ptr_list.size(); i++) { auto cur_cnode = cnode_ptr_list[i]; - if (IsFusionHcom(cur_cnode)) { - fusion_hcom_index.emplace_back(i); + if (!IsHcom(cur_cnode)) { + continue; + } + uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); + auto it = hcom_index.find(cur_stream_id); + if (it != hcom_index.end()) { + hcom_index[cur_stream_id].emplace_back(i); + } else { + hcom_index[cur_stream_id] = {i}; + } + + // record first hcom stream id + if (first_hcom_stream == kInvalidStreamId) { + first_hcom_stream = cur_stream_id; + } + + // record last hcom stream id + if (cur_stream_id != last_hcom_stream) { + last_hcom_stream = cur_stream_id; } } - if (fusion_hcom_index.size() < 2) { - MS_LOG(INFO) << "fusion hcom size is less than 2, no need insert event between them"; + + if (hcom_index.size() < 2) { + MS_LOG(INFO) << "different stream hcom size is less than 2, no need insert event between them"; return; } - uint32_t first_index = fusion_hcom_index[0]; - uint32_t last_index = fusion_hcom_index[fusion_hcom_index.size() - 1]; - uint32_t cur_event_id = total_event_num_; - uint32_t pre_hcom_stream_id = kInvalidStreamId; - std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_index, std::back_inserter(orders)); - for (size_t i = first_index; i <= last_index; i++) { + InsertEventBetweenHcom(graph_ptr, hcom_index, first_hcom_stream, last_hcom_stream); + MS_LOG(INFO) << "after hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num(); +} + +void AscendStreamAssign::InsertEventBetweenHcom(const NotNull &graph_ptr, + const map> &hcom_index, + uint32_t first_hcom_stream, uint32_t last_hcom_stream) { + vector orders; + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + auto cnode_ptr_list = graph_ptr->execution_order(); + uint32_t cur_event_id = resource_manager.ApplyNewEvent(); + size_t first_stream_last_index = hcom_index.at(first_hcom_stream).back(); + size_t last_stream_first_index = hcom_index.at(last_hcom_stream).front(); + std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_stream_last_index, std::back_inserter(orders)); + for (size_t i = first_stream_last_index; i <= last_stream_first_index; i++) { auto cur_cnode = cnode_ptr_list[i]; - auto it = std::find(fusion_hcom_index.begin(), fusion_hcom_index.end(), i); - if (it == fusion_hcom_index.end()) { + if (!IsSatisfiedHcom(hcom_index, cur_cnode, i)) { orders.emplace_back(cur_cnode); continue; } auto cur_hcom_stream_id = AnfAlgo::GetStreamId(cur_cnode); - if (cur_hcom_stream_id == pre_hcom_stream_id) { - orders.emplace_back(cur_cnode); - continue; - } - if (i == first_index) { + if (i == first_stream_last_index) { // first fusion hcom orders.emplace_back(cur_cnode); auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); orders.emplace_back(send); - } else if (i == last_index) { + } else if (i == last_stream_first_index) { // last fusion hcom auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); orders.emplace_back(recv); orders.emplace_back(cur_cnode); - cur_event_id++; } else { - auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); - orders.emplace_back(recv); - cur_event_id++; - orders.emplace_back(cur_cnode); - auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); - orders.emplace_back(send); + auto cur_stream_hcom_size = hcom_index.at(cur_hcom_stream_id).size(); + if (cur_stream_hcom_size == 1) { + auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(recv); + cur_event_id = resource_manager.ApplyNewEvent(); + orders.emplace_back(cur_cnode); + auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(send); + } else { + // current stream, first hcom:add recv op + if (i == hcom_index.at(cur_hcom_stream_id).front()) { + auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(recv); + cur_event_id = resource_manager.ApplyNewEvent(); + orders.emplace_back(cur_cnode); + } else if (i == hcom_index.at(cur_hcom_stream_id).back()) { + // current stream, last hcom:add send op + orders.emplace_back(cur_cnode); + auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(send); + } else { + // current stream, not first and last op + orders.emplace_back(cur_cnode); + } + } } - pre_hcom_stream_id = cur_hcom_stream_id; } - std::copy(cnode_ptr_list.begin() + last_index + 1, cnode_ptr_list.end(), std::back_inserter(orders)); + std::copy(cnode_ptr_list.begin() + last_stream_first_index + 1, cnode_ptr_list.end(), std::back_inserter(orders)); graph_ptr->set_execution_order(orders); - total_event_num_ = cur_event_id; - MS_LOG(INFO) << "after indsert between allreduce, total event nums[" << total_event_num_ << "]\n end"; } -void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr &graph_ptr) { +bool AscendStreamAssign::IsSatisfiedHcom(const std::map> &hcom_index, const CNodePtr &node_ptr, + size_t index) { + MS_EXCEPTION_IF_NULL(node_ptr); + auto cur_hcom_stream_id = AnfAlgo::GetStreamId(node_ptr); + auto it = hcom_index.find(cur_hcom_stream_id); + if (it == hcom_index.end()) { + return false; + } + auto iter = std::find(hcom_index.at(cur_hcom_stream_id).begin(), hcom_index.at(cur_hcom_stream_id).end(), index); + if (iter == hcom_index.at(cur_hcom_stream_id).end()) { + return false; + } + return true; +} + +// section6 +void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull &graph_ptr) { MS_LOG(INFO) << "start"; - MS_EXCEPTION_IF_NULL(graph_ptr); + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); auto cnode_ptr_list = graph_ptr->execution_order(); vector cnodes = cnode_ptr_list; - uint32_t cur_event_id = 0; + uint32_t cur_event_id = resource_manager.ApplyNewEvent(); auto it = cnodes.begin(); - while (it != cnodes.end() && (it + 1) != cnodes.end()) { + while (it != cnodes.end()) { MS_EXCEPTION_IF_NULL(*it); - MS_EXCEPTION_IF_NULL(*(it + 1)); - if (IsHcom(*it) && !IsHcom(*(it + 1))) { - bool is_fusion = IsFusionHcom(*it); - if (!is_fusion) { - ++it; - continue; - } + if (IsIndependentNode(*it)) { + MS_LOG(INFO) << "deal independent op[" << (*it)->DebugString() << "]"; CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it)); it = cnodes.insert(it + 1, send_cnode_ptr); auto target = FindTargetOp(it, cnodes.end(), *(it - 1)); if (target == cnodes.end()) { - MS_LOG(WARNING) << "hcom node[" << (*(it - 1))->fullname_with_scope() - << "] can't find target for insert recv op, no insert send/recv"; + MS_LOG(DEBUG) << "independ node[" << (*(it - 1))->fullname_with_scope() + << "] can't find target for insert recv op, no insert send/recv"; it = cnodes.erase(it); continue; } @@ -496,67 +655,31 @@ void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptrset_execution_order(cnodes); - total_event_num_ = cur_event_id; - MS_LOG(INFO) << "after insert send/recv for hcom parallel, total event nums[" << total_event_num_ << "]"; - - // Insert Send/Recv between Hcom(such as:AllReduce1 Send1 Common Recv1 AllReduce2) - InsertSendRecvForDiffHcom(graph_ptr); + MS_LOG(INFO) << "after independent parallel, total event nums:" << resource_manager.get_cur_event_num(); MS_LOG(INFO) << "end"; } -void AscendStreamAssign::UpdateEventId(const shared_ptr &graph_ptr) { - MS_LOG(INFO) << "start"; - MS_EXCEPTION_IF_NULL(graph_ptr); +// section7 +void AscendStreamAssign::GetNeedActiveStreams(const NotNull &graph_ptr) { CNodePtr cur_cnode_ptr = nullptr; - // key:virutal event id, value:real event id - std::unordered_map event_id_map; - uint32_t event_id; auto cnode_ptr_list = graph_ptr->execution_order(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kSendOpName || AnfAlgo::GetCNodeName(cur_cnode_ptr) == kRecvOpName) { - auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); - MS_EXCEPTION_IF_NULL(primitive); - event_id = GetValue(primitive->GetAttr(kAttrEventId)); - // before stream assign, send/recv event_id assign from kFirstEventId - if (event_id < kFirstEventId) { - continue; - } - auto it = event_id_map.find(event_id); - if (it == event_id_map.end()) { - event_id_map.insert(std::make_pair(event_id, total_event_num_)); - AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(total_event_num_), cur_cnode_ptr); - total_event_num_++; - } else { - AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(it->second), cur_cnode_ptr); - } - } - } -} + // 1)first stream 0 should be actived first; + need_first_active_streams_.emplace_back(0); -void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr &graph_ptr) { - MS_EXCEPTION_IF_NULL(graph_ptr); - CNodePtr cur_cnode_ptr = nullptr; - auto cnode_ptr_list = graph_ptr->execution_order(); - // 1)stream witch kStreamNeedActivedFirst attr should be actived; + // 2)stream witch kStreamNeedActivedFirst attr should be actived; for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { cur_cnode_ptr = cnode_ptr_list[i]; MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - ValuePtr value_ptr = nullptr; auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); - if (primitive != nullptr) { - value_ptr = primitive->GetAttr(kStreamNeedActivedFirst); - } else { - auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cur_cnode_ptr); - MS_EXCEPTION_IF_NULL(func_graph); - value_ptr = func_graph->get_attr(kStreamNeedActivedFirst); - } + MS_EXCEPTION_IF_NULL(primitive); + auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst); if (value_ptr == nullptr) { continue; } @@ -569,20 +692,115 @@ void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr &graph_ptr) { + CheckStreamAssign(graph_ptr); + CheckEventAssign(graph_ptr); +} + +void AscendStreamAssign::CheckStreamAssign(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + std::set streams; + uint32_t max_stream = 0; + uint32_t min_stream = kInvalidStreamId; + auto cnode_ptr_list = graph_ptr->execution_order(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + uint32_t stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); + if (stream_id == kInvalidStreamId) { + MS_LOG(EXCEPTION) << "node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "had not been assigned stream"; + } + + (void)streams.emplace(stream_id); + if (stream_id > max_stream) { + max_stream = stream_id; + } + if (stream_id < min_stream) { + min_stream = stream_id; + } + } + + // check stream assign + if (!streams.empty()) { + if (min_stream != 0) { + MS_LOG(EXCEPTION) << "stream should start from 0, now is from " << min_stream; + } + uint32_t assigned_stream_num = resource_manager.get_cur_stream_num(); + if ((max_stream != assigned_stream_num - 1) || (streams.size() != assigned_stream_num)) { + MS_LOG(EXCEPTION) << "stream should be consecutive, max stream id:" << max_stream + << "; alloc stream nums:" << assigned_stream_num << "; streams size:" << streams.size(); + } + } } -CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr &graph_ptr, - uint32_t event_id, uint32_t stream_id) { - MS_EXCEPTION_IF_NULL(graph_ptr); +void AscendStreamAssign::CheckEventAssign(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + std::map> event_map; + uint32_t max_event_id = 0; + uint32_t min_event_id = kInvalidEventId; + auto cnode_ptr_list = graph_ptr->execution_order(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + auto name = AnfAlgo::GetCNodeName(cur_cnode_ptr); + if (name == kSendOpName || name == kRecvOpName) { + uint32_t event_id = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrEventId); + if (event_id > max_event_id) { + max_event_id = event_id; + } + + if (event_id < min_event_id) { + min_event_id = event_id; + } + auto it = event_map.find(event_id); + if (it == event_map.end()) { + event_map[event_id] = {cur_cnode_ptr}; + } else { + event_map[event_id].emplace_back(cur_cnode_ptr); + } + } + } + // check event assign + if (!event_map.empty()) { + if (min_event_id != 0) { + MS_LOG(EXCEPTION) << "event should start from 0, now is from " << min_event_id; + } + uint32_t assigned_event_num = resource_manager.get_cur_event_num(); + if ((max_event_id != assigned_event_num - 1) || (event_map.size() != assigned_event_num)) { + MS_LOG(EXCEPTION) << "event should be consecutive"; + } + for (const auto &item : event_map) { + if (item.second.size() != 2) { + MS_LOG(EXCEPTION) << "send/recv should be in pair and share one event id"; + } + auto first_name = AnfAlgo::GetCNodeName(item.second[0]); + auto second_name = AnfAlgo::GetCNodeName(item.second[1]); + if (!(first_name == kSendOpName && second_name == kRecvOpName)) { + MS_LOG(EXCEPTION) << "send should be before recv"; + } + } + } +} + +// section9 +CNodePtr AscendStreamAssign::CreateSendApplyKernel(const NotNull &graph_ptr, uint32_t event_id, + uint32_t stream_id) { auto send_op = std::make_shared(kSendOpName); MS_EXCEPTION_IF_NULL(send_op); auto send_apply = std::make_shared(send_op); @@ -601,9 +819,8 @@ CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr &graph_ptr, - uint32_t event_id, uint32_t stream_id) { - MS_EXCEPTION_IF_NULL(graph_ptr); +CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const NotNull &graph_ptr, uint32_t event_id, + uint32_t stream_id) { auto recv_op = std::make_shared(kRecvOpName); MS_EXCEPTION_IF_NULL(recv_op); auto recv_apply = std::make_shared(recv_op); @@ -649,42 +866,6 @@ vector::iterator AscendStreamAssign::FindTargetOp(vector::it ++begin; } return end; -} // namespace ascend - -void AscendStreamAssign::InsertSendRecvForIndependent(const shared_ptr &graph_ptr) { - MS_LOG(INFO) << "start"; - MS_EXCEPTION_IF_NULL(graph_ptr); - auto cnode_ptr_list = graph_ptr->execution_order(); - vector cnodes = cnode_ptr_list; - uint32_t cur_event_id = total_event_num_; - auto it = cnodes.begin(); - while (it != cnodes.end()) { - MS_EXCEPTION_IF_NULL(*it); - if (IsIndependentNode(*it)) { - MS_LOG(INFO) << "deal independent op[" << (*it)->DebugString() << "]"; - CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it)); - it = cnodes.insert(it + 1, send_cnode_ptr); - - auto target = FindTargetOp(it, cnodes.end(), *(it - 1)); - if (target == cnodes.end()) { - MS_LOG(DEBUG) << "independ node[" << (*(it - 1))->fullname_with_scope() - << "] can't find target for insert recv op, no insert send/recv"; - it = cnodes.erase(it); - continue; - } - - // deal recv op - uint32_t stream_id = AnfAlgo::GetStreamId(*target); - CNodePtr recv_cnode_ptr = CreateRecvApplyKernel(graph_ptr, cur_event_id, stream_id); - (void)cnodes.insert(target, recv_cnode_ptr); - ++cur_event_id; - } - ++it; - } - graph_ptr->set_execution_order(cnodes); - total_event_num_ = cur_event_id; - MS_LOG(INFO) << "total event nums[" << total_event_num_ << "]"; - MS_LOG(INFO) << "end"; } bool AscendStreamAssign::IsTaskSink() { @@ -701,8 +882,8 @@ bool AscendStreamAssign::IsTaskSink() { void AscendStreamAssign::GetWaitStreams(vector *wait_active_stream_list) { MS_EXCEPTION_IF_NULL(wait_active_stream_list); - AscendStreamMng &stream_manager = AscendStreamMng::GetInstance(); - uint32_t total_stream_num = stream_manager.GetCurAllocStreamNum(); + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + uint32_t total_stream_num = resource_manager.get_cur_stream_num(); if (total_stream_num == 0) { MS_LOG(INFO) << "total_common_stream_num is zero"; return; @@ -713,7 +894,7 @@ void AscendStreamAssign::GetWaitStreams(vector *wait_active_stream_lis auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i); if (it == need_first_active_streams_.end()) { MS_LOG(INFO) << "wait common stream id = " << i; - (*wait_active_stream_list).push_back(i); + wait_active_stream_list->push_back(i); } } } @@ -723,94 +904,21 @@ bool AscendStreamAssign::IsHcom(const CNodePtr &apply_kernel) { return AnfAlgo::GetKernelType(apply_kernel) == HCCL_KERNEL; } -bool AscendStreamAssign::IsFusionHcom(const CNodePtr &cur_cnode_ptr) { - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - bool is_hcom = IsHcom(cur_cnode_ptr); - if (!is_hcom) { - return false; - } - - if (!AnfAlgo::HasNodeAttr(kAttrFusion, cur_cnode_ptr)) { - return false; - } - - if (AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrFusion) == 0) { - return false; - } - - return true; -} - void AscendStreamAssign::GetHcomStreams(std::vector *streams) { MS_EXCEPTION_IF_NULL(streams); - for (const auto &stream : hcom_stream_list_) { - (*streams).emplace_back(stream); + for (const auto &item : hcom_stream_map_) { + streams->emplace_back(item.first); } } -void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr &graph_ptr) { - MS_EXCEPTION_IF_NULL(graph_ptr); - CNodePtr cur_cnode_ptr = nullptr; - std::vector exe_orders; - std::vector independents; - std::vector others; - auto cnode_ptr_list = graph_ptr->execution_order(); - MS_LOG(INFO) << "before reorder, graph orders size:" << cnode_ptr_list.size(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - if (IsIndependentNode(cur_cnode_ptr)) { - independents.emplace_back(cur_cnode_ptr); - } else { - others.emplace_back(cur_cnode_ptr); - } - } - if (others.empty() || independents.empty()) { - MS_LOG(INFO) << "independent or others is empty, no need reorder"; - return; - } - - std::set processed; - for (size_t i = 0; i < others.size(); i++) { - auto begin = others.begin() + i; - auto end = begin + 1; - bool flag = false; - for (size_t j = 0; j < independents.size(); j++) { - auto cur_independent = independents[j]; - auto it = std::find(processed.begin(), processed.end(), cur_independent.get()); - if (it != processed.end()) { - continue; - } - auto res = FindTargetOp(begin, end, cur_independent); - if (res != end) { - flag = true; - exe_orders.emplace_back(cur_independent); - exe_orders.emplace_back(*begin); - processed.emplace(cur_independent.get()); - break; - } - } - if (!flag) { - exe_orders.emplace_back(*begin); - } - } - MS_LOG(INFO) << "after reorder, graph orders size:" << exe_orders.size(); - if (processed.size() != independents.size()) { - MS_LOG(WARNING) << "processed independent nodes size is not equal to exiting independent nodes size"; - return; - } - - graph_ptr->set_execution_order(exe_orders); -} - void AscendStreamAssign::Reset() { - total_event_num_ = 0; independent_stream_activated_ = false; + hcom_stream_activated_ = false; independent_stream_map_.clear(); + hcom_stream_map_.clear(); + common_stream_map_.clear(); processed_streams_.clear(); - hcom_stream_list_.clear(); need_first_active_streams_.clear(); - inner_parallel_streams_.clear(); } } // namespace ascend } // namespace device diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.h b/mindspore/ccsrc/device/ascend/ascend_stream_assign.h index bb918cfc79c7596e3af93c5354bc24a63b5dccd1..625ab6ad6e191703277112c694f346c980c9b631 100644 --- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.h +++ b/mindspore/ccsrc/device/ascend/ascend_stream_assign.h @@ -29,6 +29,7 @@ #include "runtime/rt_model.h" #include "runtime/stream.h" #include "session/kernel_graph.h" +#include "utils/contract.h" namespace mindspore { namespace device { @@ -38,35 +39,59 @@ using std::shared_ptr; using std::unordered_map; using std::unordered_set; using std::vector; -using CnodeKey = void *; const uint32_t kInvalidStreamId = UINT32_MAX; -class AscendStreamMng { +const uint32_t kInvalidEventId = UINT32_MAX; +class AscendResourceMng { public: - static AscendStreamMng &GetInstance() { - static AscendStreamMng instance; + static AscendResourceMng &GetInstance() { + static AscendResourceMng instance; return instance; } - void Reset() { - cur_stream_id = 0; - cur_stream_num = 0; + void ResetResource() { + cur_stream_num_ = 0; + cur_event_num_ = 0; } uint32_t ApplyNewStream() { - if (!cur_stream_num) { - cur_stream_num++; + if (!cur_stream_num_) { + uint32_t cur_stream_id = cur_stream_num_; + cur_stream_num_++; return cur_stream_id; } - cur_stream_num++; - cur_stream_id++; + uint32_t cur_stream_id = cur_stream_num_; + cur_stream_num_++; return cur_stream_id; } + uint32_t ApplyNewEvent() { + if (!cur_event_num_) { + uint32_t cur_event_id = cur_event_num_; + cur_event_num_++; + return cur_event_id; + } + uint32_t cur_event_id = cur_event_num_; + cur_event_num_++; + return cur_event_id; + } - uint32_t GetCurAllocStream() { return cur_stream_id; } - uint32_t GetCurAllocStreamNum() { return cur_stream_num; } + void DeleteEvent() { + if (!cur_event_num_) { + MS_LOG(WARNING) << "total event num is 0, no event to delete"; + } else { + --cur_event_num_; + } + } + uint32_t get_cur_stream_num() { return cur_stream_num_; } + uint32_t GetCurAllocStreamId() { + if (!cur_stream_num_) { + MS_LOG(EXCEPTION) << "stream nums is 0, no stream id should be get"; + } + return cur_stream_num_ - 1; + } + uint32_t get_cur_event_num() { return cur_event_num_; } private: - uint32_t cur_stream_num{0}; - uint32_t cur_stream_id{0}; + uint32_t cur_stream_num_{0}; + uint32_t cur_event_num_{0}; }; class AscendStreamAssign { @@ -79,39 +104,42 @@ class AscendStreamAssign { AscendStreamAssign(const AscendStreamAssign &) = delete; AscendStreamAssign &operator=(const AscendStreamAssign &) = delete; - uint32_t total_event_num() const { return total_event_num_; } + void AssignStream(const NotNull &graph_ptr); void GetHcomStreams(std::vector *streams); - - void AssignStream(const std::shared_ptr &graph_ptr); void GetWaitStreams(vector *wait_active_stream_list); - CNodePtr CreateSendApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id, - uint32_t stream_id); - CNodePtr CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id, - uint32_t stream_id); + CNodePtr CreateSendApplyKernel(const NotNull &graph_ptr, uint32_t event_id, uint32_t stream_id); + CNodePtr CreateRecvApplyKernel(const NotNull &graph_ptr, uint32_t event_id, uint32_t stream_id); private: AscendStreamAssign() = default; ~AscendStreamAssign() = default; void Reset(); - void CheckStreamAssign(const std::shared_ptr &graph_ptr); - void AssignAllNodesStream(const std::shared_ptr &graph_ptr); - void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, uint32_t *cur_index, - uint32_t *cur_stream_id); + void CheckResourceAssign(const NotNull &graph_ptr); + void CheckStreamAssign(const NotNull &graph_ptr); + void CheckEventAssign(const NotNull &graph_ptr); + void AssignAllNodesStream(const NotNull &graph_ptr); + void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr); + void AssignHcomStreamId(const CNodePtr &cur_cnode_ptr); void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr); - void UpdateAtomicAddrCleanStreamId(const std::shared_ptr &graph_ptr); - void FindHcomParallelStreams(const std::shared_ptr &graph_ptr); - void InsertStreamActive(const std::shared_ptr &graph_ptr); - void UpdateStreamSwitch(const std::shared_ptr &graph_ptr, const CNodePtr &switch_ptr, - const vector &independent_stream, vector *orders); - void InsertSendRecvForIndependent(const std::shared_ptr &graph_ptr); - void InsertSendRecvForHcomParallel(const std::shared_ptr &graph_ptr); - void InsertSendRecvForDiffHcom(const shared_ptr &graph_ptr); - void UpdateEventId(const std::shared_ptr &graph_ptr); - void GetNeedActiveStreams(const std::shared_ptr &graph_ptr); - void ReorderIndependentOrders(const std::shared_ptr &graph_ptr); + void UpdateAtomicAddrCleanStreamId(const NotNull &graph_ptr); + void FindHcomParallelStreams(const NotNull &graph_ptr); + void InsertStreamActive(const NotNull &graph_ptr); + void UpdateStreamSwitch(const NotNull &graph_ptr, const CNodePtr &switch_ptr, + vector *orders); + void InsertEventForIndependentParallel(const NotNull &graph_ptr); + void InsertEventForHcomParallel(const NotNull &graph_ptr); + void InsertEventCommonDependHcom(const NotNull &graph_ptr); + void InsertEventHcomDependCommon(const NotNull &graph_ptr); + void InsertEventHcomDependHcom(const NotNull &graph_ptr); + void InsertEventBetweenHcom(const NotNull &graph_ptr, const map> &hcom_index, + uint32_t first_hcom_stream, uint32_t last_hcom_stream); + bool IsSatisfiedHcom(const std::map> &hcom_index, const CNodePtr &node_ptr, size_t index); + + void GetProcessedStream(const NotNull &graph_ptr); + void GetNeedActiveStreams(const NotNull &graph_ptr); + void ReorderIndependentOrders(const NotNull &graph_ptr); bool IsTaskSink(); - bool IsFusionHcom(const CNodePtr &cur_cnode_ptr); bool IsHcom(const CNodePtr &cur_cnode_ptr); bool IsIndependentNode(const CNodePtr &node_ptr); bool IsProcessedStream(uint32_t stream_id); @@ -119,14 +147,13 @@ class AscendStreamAssign { const CNodePtr &node); void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector *parallel_streams); - uint32_t total_event_num_{0}; bool independent_stream_activated_{false}; + bool hcom_stream_activated_{false}; std::map independent_stream_map_{}; + std::map hcom_stream_map_{}; + std::map common_stream_map_{}; std::set processed_streams_{}; - std::set hcom_stream_list_{}; std::vector need_first_active_streams_{}; - std::vector> inner_parallel_streams_{}; - // new policy end }; } // namespace ascend diff --git a/mindspore/ccsrc/device/kernel_adjust.cc b/mindspore/ccsrc/device/kernel_adjust.cc index cfccfb35067be50aca30d47333536b4b8e35fed4..f4fe64b4df67696c28725b6178718be0c7217c21 100644 --- a/mindspore/ccsrc/device/kernel_adjust.cc +++ b/mindspore/ccsrc/device/kernel_adjust.cc @@ -103,8 +103,8 @@ CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr &kernel_graph_ptr) { - device::ascend::AscendStreamMng &stream_manager = device::ascend::AscendStreamMng::GetInstance(); - stream_manager.Reset(); + device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance(); + resource_manager.ResetResource(); if (!NeedInsertSwitch()) { return; } @@ -135,17 +135,16 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr } std::vector exec_order; - // getnext loop process // getnext loop stream switch op CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); MS_EXCEPTION_IF_NULL(getnext_switch_app); - uint32_t getnext_switch_stream_id = stream_manager.ApplyNewStream(); + uint32_t getnext_switch_stream_id = resource_manager.ApplyNewStream(); AnfAlgo::SetStreamId(getnext_switch_stream_id, getnext_switch_app.get()); exec_order.push_back(getnext_switch_app); // getnext op - uint32_t getnext_stream_id = stream_manager.ApplyNewStream(); + uint32_t getnext_stream_id = resource_manager.ApplyNewStream(); size_t i = 0; for (; i < orders.size(); i++) { auto node = orders[i]; @@ -160,7 +159,8 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(getnext_stream_id), getnext_switch_app); // getnext loop send - CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, kFirstEventId); + uint32_t getnext_event_id = resource_manager.ApplyNewEvent(); + CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, getnext_event_id); AnfAlgo::SetStreamId(getnext_stream_id, send.get()); exec_order.push_back(send); @@ -168,14 +168,14 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr // fpbp loop stream switch CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); MS_EXCEPTION_IF_NULL(fpbp_switch_app); - uint32_t fpbp_switch_stream_id = stream_manager.ApplyNewStream(); + uint32_t fpbp_switch_stream_id = resource_manager.ApplyNewStream(); AnfAlgo::SetStreamId(fpbp_switch_stream_id, fpbp_switch_app.get()); AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), fpbp_switch_app); exec_order.push_back(fpbp_switch_app); // fpbp loop recv - CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, kFirstEventId); - uint32_t fpbp_stream_id = stream_manager.ApplyNewStream(); + CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, getnext_event_id); + uint32_t fpbp_stream_id = resource_manager.ApplyNewStream(); AnfAlgo::SetStreamId(fpbp_stream_id, recv.get()); exec_order.push_back(recv); diff --git a/mindspore/ccsrc/device/kernel_adjust.h b/mindspore/ccsrc/device/kernel_adjust.h index 1a7436b3968574b13ec7e7605e2cd8993d97ac87..5dc559408a1f1baf90ff87e9c5ec7d4d2f091402 100644 --- a/mindspore/ccsrc/device/kernel_adjust.h +++ b/mindspore/ccsrc/device/kernel_adjust.h @@ -38,12 +38,8 @@ constexpr auto kIterLoopParamName = "iter_loop"; constexpr auto kZeroParamName = "zero"; constexpr auto kOneParamName = "one"; constexpr auto kStreamNeedActivedFirst = "stream_need_active_first"; +constexpr uint32_t kSecondStreamSwitchLabel = 2; -const uint32_t kFirstStreamSwitchLabel = 0; -const uint32_t kGetNextLabel = 1; -const uint32_t kSecondStreamSwitchLabel = 2; -const uint32_t kInvalidEventId = UINT32_MAX; -const uint32_t kFirstEventId = kInvalidEventId / 2; namespace device { class KernelAdjust { public: diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index 9ae16a1dbd617c787c89b6fa8e6097031353a541..5ed5d96acdee2813b9e10fa8a3511f50a351eaa6 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -305,7 +305,7 @@ GraphId AscendSession::CompileGraph(NotNull func_graph) { // adjust kernel AdjustKernel(root_graph); // assign stream - AssignStream(root_graph); + AssignStream(NOT_NULL(root_graph)); // insert profiling point device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get())); // build kernel @@ -377,7 +377,7 @@ void AscendSession::BuildGraph(GraphId graph_id) { // adjust execution order because merge child graph and other special operations AdjustKernel(graph); // Assign streams for control sink and hccl and so on - AssignStream(graph); + AssignStream(NOT_NULL(graph)); device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get())); // build kernel if node is cnode @@ -647,7 +647,7 @@ void AscendSession::RunOpAdjustKernel(const std::shared_ptr &kernel MS_LOG(INFO) << "Finish!"; } -void AscendSession::AssignStream(const std::shared_ptr &kernel_graph) const { +void AscendSession::AssignStream(NotNull kernel_graph) const { MS_LOG(INFO) << "Start!"; device::ascend::AscendStreamAssign::GetInstance().AssignStream(kernel_graph); MS_LOG(INFO) << "Finish!"; diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h index eaa01b8f8047cab2e8595c76ce811e3a898230cc..ec85d82439b8540c280fa13e78aa6383cc790345 100755 --- a/mindspore/ccsrc/session/ascend_session.h +++ b/mindspore/ccsrc/session/ascend_session.h @@ -76,7 +76,7 @@ class AscendSession : public SessionBasic { void HardwareOptimize(const std::shared_ptr &kernel_graph) const; void AdjustKernel(const std::shared_ptr &kernel_graph) const; void RunOpAdjustKernel(const std::shared_ptr &kernel_graph) const; - void AssignStream(const std::shared_ptr &kernel_graph) const; + void AssignStream(NotNull kernel_graph) const; void AssignLabel(NotNull kernel_graph) const; void BuildKernel(const std::shared_ptr &kernel_graph) const; void MemoryAlloc(KernelGraph *kernel_graph) const; diff --git a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc index fba52323cf4df167fe347b4b565aaf7ba2f53b2a..a6ec3a50b5cef75e03a4d54fb5744d6f7290b648 100755 --- a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc +++ b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc @@ -26,7 +26,7 @@ void AscendLabelAssign::AssignLabel(NotNull graph) { return 1; } uint32_t AscendLabelAssign::GetLabelNum(NotNull> graph) { return 1; } -void AscendStreamAssign::AssignStream(const KernelGraphPtr &graph) { return; } +void AscendStreamAssign::AssignStream(const NotNull &graph_ptr) { return; } void AscendStreamAssign::GetWaitStreams(vector *wait_active_stream_list) { return; }