From 61bf4b18a2d7ecef4b06e7d513c0e7e1b1d9568d Mon Sep 17 00:00:00 2001 From: yuchaojie Date: Thu, 20 Aug 2020 15:48:01 +0800 Subject: [PATCH] fix_consecutive_allreduce_bug --- .../runtime/device/gpu/gpu_stream_assign.cc | 25 +++++++++++++------ .../runtime/device/gpu/gpu_stream_assign.h | 3 ++- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.cc index 78915f10d..38d100f2a 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.cc @@ -74,9 +74,11 @@ bool FindAllReduceStreamSwitchPos(const std::shared_ptr &k MS_LOG(WARNING) << "Can't find send node place before AllReduce node."; continue; } - SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter, - IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)}; - send_recv_pairs->push_back(pair1); + if (AnfAlgo::GetCNodeName(*mock_send_node_iter) != kAllReduceOpName) { + SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter, + IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)}; + send_recv_pairs->push_back(pair1); + } // Find node which uses AllReduce as input[0]. std::vector::iterator mock_recv_node_iter = FindRecvNodePos(iter, iter_end, *iter, kAllReduceStreamSwitch); @@ -84,9 +86,11 @@ bool FindAllReduceStreamSwitchPos(const std::shared_ptr &k MS_LOG(WARNING) << "Can't find recv node place after AllReduce node."; return false; } - SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1), - IntToSize(mock_recv_node_iter - iter_begin)}; - send_recv_pairs->push_back(pair2); + if (AnfAlgo::GetCNodeName(*mock_recv_node_iter) != kAllReduceOpName) { + SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1), + IntToSize(mock_recv_node_iter - iter_begin)}; + send_recv_pairs->push_back(pair2); + } } } return true; @@ -110,17 +114,22 @@ std::vector::iterator FindRecvNodePos(std::vector::iterator std::vector::iterator end, const CNodePtr mock_send_node, StreamSwitchType stream_switch_type) { MS_EXCEPTION_IF_NULL(mock_send_node); + auto ret = end; for (auto iter = begin; iter != end; iter++) { auto node = *iter; if (stream_switch_type == kAllReduceStreamSwitch) { for (auto input : node->inputs()) { if (mock_send_node == AnfAlgo::VisitKernel(input, 0).first) { - return iter; + if (AnfAlgo::GetCNodeName(node) != kAllReduceOpName) { + return iter; + } else if (ret == end) { + ret = iter; + } } } } } - return end; + return ret; } void InsertStreamSwitchNode(const std::shared_ptr &kernel_graph, diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.h b/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.h index f31bc7aa1..e4a8f872f 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.h @@ -41,7 +41,8 @@ struct StreamSwitchNode { if (offset < n.offset) { return true; } else if (offset == n.offset) { - return AnfAlgo::GetCNodeName(cnode) == kSendOpName ? true : false; + return (AnfAlgo::GetCNodeName(cnode) == kRecvOpName && AnfAlgo::GetCNodeName(n.cnode) == kSendOpName) ? false + : true; } else { return false; } -- GitLab