提交 58403932 编写于 作者: G gukecai 提交者: yanghaoran

add sync bewteen hcom

上级 3277a63e
......@@ -291,6 +291,74 @@ void AscendStreamAssign::FindAllReduceParallel(const shared_ptr<session::KernelG
}
}
void AscendStreamAssign::InsertSendRecvForDiffHcom(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
MS_LOG(INFO) << "start";
MS_EXCEPTION_IF_NULL(graph_ptr);
auto cnode_ptr_list = graph_ptr->execution_order();
vector<uint32_t> fusion_hcom_index;
vector<CNodePtr> orders;
for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
auto cur_cnode = cnode_ptr_list[i];
if (IsHcom(cur_cnode)) {
fusion_hcom_index.emplace_back(i);
}
}
if (fusion_hcom_index.size() < 2) {
MS_LOG(INFO) << "fusion hcom size is less than 2, no need insert event between them";
return;
}
uint32_t first_index = fusion_hcom_index[0];
uint32_t last_index = fusion_hcom_index[fusion_hcom_index.size() - 1];
uint32_t cur_event_id = total_event_num_;
uint32_t pre_hcom_stream_id = UINT32_MAX;
std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_index, std::back_inserter(orders));
for (size_t i = first_index; i <= last_index; i++) {
auto cur_cnode = cnode_ptr_list[i];
auto it = std::find(fusion_hcom_index.begin(), fusion_hcom_index.end(), i);
if (it == fusion_hcom_index.end()) {
orders.emplace_back(cur_cnode);
continue;
}
auto cur_hcom_stream_id = AnfAlgo::GetStreamId(cur_cnode);
if (cur_hcom_stream_id == pre_hcom_stream_id) {
orders.emplace_back(cur_cnode);
continue;
}
if (i == first_index) {
// first fusion hcom
orders.emplace_back(cur_cnode);
auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
orders.emplace_back(send);
} else if (i == last_index) {
// last fusion hcom
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
orders.emplace_back(recv);
orders.emplace_back(cur_cnode);
cur_event_id++;
} else {
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
orders.emplace_back(recv);
cur_event_id++;
orders.emplace_back(cur_cnode);
auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
orders.emplace_back(send);
}
pre_hcom_stream_id = cur_hcom_stream_id;
}
std::copy(cnode_ptr_list.begin() + last_index + 1, cnode_ptr_list.end(), std::back_inserter(orders));
graph_ptr->set_execution_order(orders);
total_event_num_ = cur_event_id;
MS_LOG(INFO) << "after indsert between allreduce, total event nums[" << total_event_num_ << "]";
MS_LOG(INFO) << "end";
}
void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
MS_LOG(INFO) << "start";
MS_EXCEPTION_IF_NULL(graph_ptr);
......@@ -324,6 +392,9 @@ void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr<mindspor
graph_ptr->set_execution_order(cnodes);
total_event_num_ = cur_event_id;
MS_LOG(INFO) << "after insert send/recv for hcom parallel, total event nums[" << total_event_num_ << "]";
// Insert Send/Recv between Hcom(such as:AllReduce1 Send1 Common Recv1 AllReduce2)
InsertSendRecvForDiffHcom(graph_ptr);
MS_LOG(INFO) << "end";
}
......
......@@ -95,6 +95,7 @@ class AscendStreamAssign {
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void InsertSendRecvForDiffHcom(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr);
void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void ReorderIndependentOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册