diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc index d62111c0107be9f6e82f8866eb73b8b1540f69a1..680d6f3ed6637f98800646e72adcd9271874dfbd 100644 --- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc @@ -438,12 +438,9 @@ std::vector GetRankFromGroup(const Group &group) { Status GatherV2PInfo::InferForwardCommunication() { forward_op_.clear(); - if (target_ != CPU) { - return SUCCESS; - } auto param_strategy = strategy_->GetInputDim().at(0); - // don't split axis, no need forward communication - if (param_strategy.at(IntToSize(axis_)) == 1) { + // don't split axis or target is not CPU, no need forward communication + if (target_ != CPU || param_strategy.at(IntToSize(axis_)) == 1) { return SUCCESS; } // split axis