提交 b69b1ca8 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4830 [gpu] Fix bug with consecutive AllReduce nodes

Merge pull request !4830 from yuchaojie/gpu_allreduce
......@@ -74,9 +74,11 @@ bool FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &k
MS_LOG(WARNING) << "Can't find send node place before AllReduce node.";
continue;
}
SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter,
IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)};
send_recv_pairs->push_back(pair1);
if (AnfAlgo::GetCNodeName(*mock_send_node_iter) != kAllReduceOpName) {
SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter,
IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)};
send_recv_pairs->push_back(pair1);
}
// Find node which uses AllReduce as input[0].
std::vector<CNodePtr>::iterator mock_recv_node_iter =
FindRecvNodePos(iter, iter_end, *iter, kAllReduceStreamSwitch);
......@@ -84,9 +86,11 @@ bool FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &k
MS_LOG(WARNING) << "Can't find recv node place after AllReduce node.";
return false;
}
SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1),
IntToSize(mock_recv_node_iter - iter_begin)};
send_recv_pairs->push_back(pair2);
if (AnfAlgo::GetCNodeName(*mock_recv_node_iter) != kAllReduceOpName) {
SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1),
IntToSize(mock_recv_node_iter - iter_begin)};
send_recv_pairs->push_back(pair2);
}
}
}
return true;
......@@ -110,17 +114,22 @@ std::vector<CNodePtr>::iterator FindRecvNodePos(std::vector<CNodePtr>::iterator
std::vector<CNodePtr>::iterator end, const CNodePtr mock_send_node,
StreamSwitchType stream_switch_type) {
MS_EXCEPTION_IF_NULL(mock_send_node);
auto ret = end;
for (auto iter = begin; iter != end; iter++) {
auto node = *iter;
if (stream_switch_type == kAllReduceStreamSwitch) {
for (auto input : node->inputs()) {
if (mock_send_node == AnfAlgo::VisitKernel(input, 0).first) {
return iter;
if (AnfAlgo::GetCNodeName(node) != kAllReduceOpName) {
return iter;
} else if (ret == end) {
ret = iter;
}
}
}
}
}
return end;
return ret;
}
void InsertStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &kernel_graph,
......
......@@ -41,7 +41,8 @@ struct StreamSwitchNode {
if (offset < n.offset) {
return true;
} else if (offset == n.offset) {
return AnfAlgo::GetCNodeName(cnode) == kSendOpName ? true : false;
return (AnfAlgo::GetCNodeName(cnode) == kRecvOpName && AnfAlgo::GetCNodeName(n.cnode) == kSendOpName) ? false
: true;
} else {
return false;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册