From 390a86effb9e5d3fad6db21d6da89545492278de Mon Sep 17 00:00:00 2001
From: lichenever
Date: Fri, 22 May 2020 09:37:36 +0800
Subject: [PATCH] fix gatherv2

---
 .../parallel/ops_info/gather_v2_p_info.cc     | 152 +++++++++++++-----
 .../parallel/ops_info/gather_v2_p_info.h      |   2 +-
 tests/ut/python/parallel/test_gather_v2.py    |  14 +-
 .../parallel/test_gather_v2_primitive.py      |   4 +-
 4 files changed, 126 insertions(+), 46 deletions(-)

diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
index 5c7473cc9..81f276182 100644
--- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
@@ -48,7 +48,7 @@ Status GatherV2PInfo::GetAttrs() {
 }
 
 Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
-  if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}, is_auto_parallel_) != SUCCESS) {
+  if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
     if (is_auto_parallel_) {
       MS_LOG(DEBUG) << name_ << ": Invalid strategy.";
     } else {
@@ -84,12 +84,19 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
     return FAILED;
   }
 
-  // Don't support repeated calc
-  auto params_strategy = strategy->GetInputDim().at(0);
+  // param_strategy(axis) != 1, index can't be split
+  auto index_strategy = strategy->GetInputDim().at(1);
+  auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int>());
+  if ((param_strategy.at(IntToSize(axis_)) != 1) && (product_i != 1)) {
+    MS_LOG(ERROR) << name_ << ": param is split at dim (axis) " << axis_ << ", index can't be split.";
+    return FAILED;
+  }
+
+  // param_strategy(axis) != 1, don't support repeated calc
   CheckGlobalDeviceManager();
   size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
-  auto product = std::accumulate(params_strategy.begin(), params_strategy.end(), 1, std::multiplies<int>());
-  if (dev_num != IntToSize(product)) {
+  auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
+  if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) {
     MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc.";
     return FAILED;
   }
@@ -97,26 +104,66 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   return SUCCESS;
 }
 
+Status GatherV2PInfo::InferMirrorOps() {
+  mirror_ops_.clear();
+  Shape input_a_tensor_map = inputs_tensor_map_.at(0);
+  std::vector<Group> input_a_group;
+  if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << " : Create group for input a failed.";
+    return FAILED;
+  }
+
+  OperatorVector op_for_input_a, op_for_input_b, op_for_axis;
+  if (input_a_group.empty()) {
+    MS_LOG(INFO) << name_ << " : The mirror group is empty.";
+    return SUCCESS;
+  } else {
+    op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum());
+    MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name();
+  }
+
+  mirror_ops_.push_back(op_for_input_a);
+  mirror_ops_.push_back(op_for_input_b);
+  mirror_ops_.push_back(op_for_axis);
+
+  return SUCCESS;
+}
+
 Status GatherV2PInfo::InferDevMatrixShape() {
   dev_matrix_shape_.clear();
   out_dev_matrix_shape_.clear();
   // infer input dev_matrix_shape
-  auto params_strategy = strategy_->GetInputDim().at(0);
-  dev_matrix_shape_ = params_strategy;
+  auto param_strategy = strategy_->GetInputDim().at(0);
+  auto index_strategy = strategy_->GetInputDim().at(1);
+  dev_matrix_shape_ = param_strategy;
+
+  // param_strategy(axis) != 1: the dev matrix is the reversed param strategy
+  if (param_strategy.at(IntToSize(axis_)) != 1) {
+    std::reverse(dev_matrix_shape_.begin(), dev_matrix_shape_.end());
+  } else {
+    dev_matrix_shape_.insert(dev_matrix_shape_.end(), index_strategy.begin(), index_strategy.end());
+  }
 
   // infer out dev_matrix_shape
   // axis!=0, split axis
-  if (axis_ != 0 && params_strategy.at(IntToSize(axis_)) != 1) {
-    out_dev_matrix_shape_.push_back(params_strategy.at(0) * params_strategy.at(IntToSize(axis_)));
-    for (size_t i = 1; i < params_strategy.size(); ++i) {
+  if (axis_ != 0 && param_strategy.at(IntToSize(axis_)) != 1) {
+    out_dev_matrix_shape_.push_back(param_strategy.at(0) * param_strategy.at(IntToSize(axis_)));
+    for (size_t i = 1; i < param_strategy.size(); ++i) {
       if (i == IntToSize(axis_)) {
         out_dev_matrix_shape_.push_back(1);
       } else {
-        out_dev_matrix_shape_.push_back(params_strategy.at(i));
+        out_dev_matrix_shape_.push_back(param_strategy.at(i));
       }
     }
   } else {
-    out_dev_matrix_shape_ = params_strategy;
+    out_dev_matrix_shape_ = dev_matrix_shape_;
+  }
+  auto product_out =
+    std::accumulate(out_dev_matrix_shape_.begin(), out_dev_matrix_shape_.end(), 1, std::multiplies<int>());
+  CheckGlobalDeviceManager();
+  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
+  if (product_out == 1) {
+    out_dev_matrix_shape_.insert(out_dev_matrix_shape_.begin(), dev_num);
   }
 
   return SUCCESS;
@@ -124,28 +171,56 @@ Status GatherV2PInfo::InferDevMatrixShape() {
 
 Status GatherV2PInfo::InferTensorMap() {
   // infer input tensor map
+  // param_strategy(axis) != 1
   size_t param_size = inputs_shape_.at(0).size();
   size_t index_size = inputs_shape_.at(1).size();
-  std::vector<int32_t> tensor_map_index(index_size, -1);
+  size_t total_size = dev_matrix_shape_.size();
+  std::vector<int32_t> tensor_map_index;
   std::vector<int32_t> tensor_map_params;
-  for (size_t i = 0; i < param_size; ++i) {
-    tensor_map_params.push_back(SizeToInt(param_size - i - 1));
+  auto param_strategy = strategy_->GetInputDim().at(0);
+  if (param_strategy.at(IntToSize(axis_)) != 1) {
+    tensor_map_index.insert(tensor_map_index.begin(), index_size, -1);
+    for (size_t i = 0; i < param_size; ++i) {
+      tensor_map_params.push_back(SizeToInt(i));
+    }
+  } else {
+    // param_strategy(axis) == 1
+    for (size_t i = 0; i < param_size; ++i) {
+      tensor_map_params.push_back(SizeToInt(total_size - i - 1));
+    }
+    for (size_t i = 0; i < index_size; ++i) {
+      tensor_map_index.push_back(SizeToInt(index_size - i - 1));
+    }
   }
 
   // infer output tensor map
   std::vector<int32_t> tensor_map_out;
-  if (axis_ == 0) {
-    tensor_map_out.push_back(SizeToInt(param_size - 1));
-    tensor_map_out.insert(tensor_map_out.end(), index_size - 1, -1);
-    for (size_t i = 1; i < param_size; ++i) {
-      tensor_map_out.push_back(SizeToInt(param_size - i - 1));
-    }
-  } else {
+  if (param_strategy.at(IntToSize(axis_)) == 1) {
+    // param_strategy(axis) == 1
     for (size_t i = 0; i < param_size; ++i) {
       if (i == IntToSize(axis_)) {
-        tensor_map_out.insert(tensor_map_out.end(), index_size, -1);
+        for (size_t j = 0; j < index_size; ++j) {
+          tensor_map_out.push_back(SizeToInt(index_size - j - 1));
+        }
       } else {
-        tensor_map_out.push_back(SizeToInt(param_size - i - 1));
+        tensor_map_out.push_back(SizeToInt(total_size - i - 1));
+      }
+    }
+  } else {
+    // param_strategy(axis) != 1
+    if (axis_ == 0) {
+      tensor_map_out.insert(tensor_map_out.end(), 0);
+      tensor_map_out.insert(tensor_map_out.end(), index_size - 1, -1);
+      for (size_t i = 1; i < param_size; ++i) {
+        tensor_map_out.push_back(i);
+      }
+    } else {
+      for (size_t i = 0; i < param_size; ++i) {
+        if (i == IntToSize(axis_)) {
+          tensor_map_out.insert(tensor_map_out.end(), index_size, -1);
+        } else {
+          tensor_map_out.push_back(SizeToInt(param_size - i - 1));
+        }
+      }
     }
   }
@@ -209,7 +284,12 @@ Status GatherV2PInfo::InferBias() {
 
 Status GatherV2PInfo::InferGroup() {
   std::vector<Group> group_list;
-  if (CreateGroupByDim(IntToSize(axis_), &group_list) != SUCCESS) {
+  auto param_strategy = strategy_->GetInputDim().at(0);
+  size_t dim = IntToSize(axis_);
+  if (param_strategy.at(IntToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) {
+    dim = (axis_ + 1) % 2;
+  }
+  if (CreateGroupByDim(dim, &group_list) != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": Create group failed.";
     return FAILED;
   }
@@ -231,7 +311,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias_)});
   auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub});
   auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)});
-  auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), gen_g.virtual_input_node(), minimum});
+  auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum});
   auto gather_v2 =
     gen_g.PushBack({gen_g.NewOpInst(GATHERV2), gen_g.virtual_input_node(), minimum, CreatInt32Imm(axis_)});
   auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gather_v2});
@@ -250,8 +330,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
   OperatorAttrs attrs = {attr_op, attr_group};
   auto reduce_scatter = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul});
-  std::vector<std::pair<AnfNodePtr, int>> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1),
-                                                         std::make_pair(equal, 2)};
+  std::vector<std::pair<AnfNodePtr, int>> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)};
   replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>(
     std::make_pair(input_nodes, reduce_scatter));
 
@@ -309,11 +388,11 @@ Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
 
 Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
   is_auto_parallel_ = true;
   Shape input0_split(inputs_shape_[0].size(), 1);
-  Shapes splittable_inputs = {input0_split};
+  Shape input1_split(inputs_shape_[1].size(), 1);
+  Shapes splittable_inputs = {input0_split, input1_split};
   std::vector<StrategyPtr> sp_vector;
-  if (GenerateStrategiesForIndependentInputs(stage_id, {inputs_shape_.at(0)}, splittable_inputs, &sp_vector) !=
-      SUCCESS) {
+  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
     MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed.";
     return FAILED;
   }
@@ -331,12 +410,13 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
 std::shared_ptr<std::vector<std::vector<int32_t>>> GatherV2PInfo::GenerateBatchStrategies() {
   CheckGlobalDeviceManager();
   size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
-  Dimensions strategy;
-  strategy.push_back(SizeToInt(dev_num));
-  for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
-    strategy.push_back(1);
+  Dimensions param_strategy(inputs_shape_[0].size(), 1);
+  Dimensions index_strategy;
+  index_strategy.push_back(SizeToInt(dev_num));
+  for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
+    index_strategy.push_back(1);
   }
-  std::vector<Dimensions> strategy_v = {strategy};
+  std::vector<Dimensions> strategy_v = {param_strategy, index_strategy};
   return std::make_shared<std::vector<Dimensions>>(strategy_v);
 }
 }  // namespace parallel
diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
index f05c3c171..a87b9838c 100644
--- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
@@ -48,7 +48,7 @@ class GatherV2PInfo : public OperatorInfo {
 
  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
-  Status InferMirrorOps() override { return SUCCESS; }
+  Status InferMirrorOps() override;
   Status InferForwardCommunication() override { return SUCCESS; }
   Status InferTensorInfo() override;
   Status InferDevMatrixShape() override;
diff --git a/tests/ut/python/parallel/test_gather_v2.py b/tests/ut/python/parallel/test_gather_v2.py
index 6d943be51..838b617fa 100644
--- a/tests/ut/python/parallel/test_gather_v2.py
+++ b/tests/ut/python/parallel/test_gather_v2.py
@@ -61,7 +61,7 @@ class Net(nn.Cell):
 
 def test_gatherv2_semi_auto0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((1, 8),)
+    strategy1 = ((1, 8), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
     net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
     net.set_auto_parallel()
@@ -73,7 +73,7 @@ def test_gatherv2_semi_auto0():
 
 def test_gatherv2_semi_auto1():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((8, 1),)
+    strategy1 = ((8, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
     net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
     net.set_auto_parallel()
@@ -85,7 +85,7 @@ def test_gatherv2_semi_auto1():
 
 def test_gatherv2_semi_auto2():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((2, 4),)
+    strategy1 = ((2, 4), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
     net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
     net.set_auto_parallel()
@@ -97,7 +97,7 @@ def test_gatherv2_semi_auto2():
 
 def test_gatherv2_semi_auto3():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((1, 8),)
+    strategy1 = ((1, 8), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
     net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
     net.set_auto_parallel()
@@ -109,7 +109,7 @@ def test_gatherv2_semi_auto3():
 
 def test_gatherv2_semi_auto4():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((8, 1),)
+    strategy1 = ((8, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
     net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
     net.set_auto_parallel()
@@ -121,7 +121,7 @@ def test_gatherv2_semi_auto4():
 
 def test_gatherv2_semi_auto5():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((2, 4),)
+    strategy1 = ((2, 4), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
     net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
     net.set_auto_parallel()
@@ -155,7 +155,7 @@ def test_gatherv2_semi_auto7():
 
 def test_gatherv2_semi_auto8():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((8,),)
+    strategy1 = ((8,), (1, 1))
     strategy2 = ((4, 2), (4, 2))
     net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
     net.set_auto_parallel()
diff --git a/tests/ut/python/parallel/test_gather_v2_primitive.py b/tests/ut/python/parallel/test_gather_v2_primitive.py
index a41692375..9c2d0958d 100644
--- a/tests/ut/python/parallel/test_gather_v2_primitive.py
+++ b/tests/ut/python/parallel/test_gather_v2_primitive.py
@@ -221,14 +221,14 @@ def test_axis1_auto_batch_parallel():
 
 
 def test_axis1_batch_parallel():
-    gather_v2_strategy = ((device_number, 1),)
+    gather_v2_strategy = ((device_number, 1), (1, ))
     criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
     rank = 2
     net_trains(gather_v2_strategy, criterion, rank)
 
 
 def test_axis1_strategy1():
-    gather_v2_strategy = ((16, 2),)
+    gather_v2_strategy = ((16, 2), (1, ))
     rank = 17
     criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
     net_trains(gather_v2_strategy, criterion, rank)
-- 
GitLab
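
Editor's note: after this patch, a GatherV2 shard strategy must cover both
inputs, (param_strategy, index_strategy), instead of param alone. A minimal
usage sketch, modeled on the updated tests above; the GatherNet cell, the
8-device setup, and the concrete strategy values are illustrative
assumptions, not part of the patch:

    import mindspore.nn as nn
    from mindspore import context
    from mindspore.ops import operations as P


    class GatherNet(nn.Cell):
        """Hypothetical cell mirroring the tests' Net wrapper."""

        def __init__(self, axis=0, strategy=None):
            super(GatherNet, self).__init__()
            # The strategy now holds one tuple per input:
            # (param_strategy, index_strategy), not just param.
            self.gather_v2 = P.GatherV2().set_strategy(strategy)
            self.axis = axis

        def construct(self, param, index):
            return self.gather_v2(param, index, self.axis)


    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      parallel_mode="semi_auto_parallel")
    # param is split 8-way on the non-axis dimension and index is replicated;
    # CheckStrategy rejects a split index whenever param_strategy(axis) != 1.
    net = GatherNet(axis=0, strategy=((1, 8), (1, 1)))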
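Editor's note: the EQUAL rewiring in ComputeReplaceGraph (comparing sub with
minimum instead of the original index) is easier to see in scalar form. Below
is a rough NumPy sketch of the replaced graph's semantics; the function
sharded_gather_segment and all names in it are illustrative, and the final
ReduceScatter is emulated by summing the per-device partial results:

    import numpy as np

    def sharded_gather_segment(param_slice, index, bias, slice_size):
        # Each device owns rows [bias, bias + slice_size) of the full param.
        offset = index - bias                                  # Sub
        clamped = np.minimum(np.maximum(offset, 0),            # ReLU, then
                             slice_size - 1)                   # Minimum
        # The patch compares the raw offset (sub) with the clamped value, so
        # a row is kept only when it actually lives on this device's slice.
        mask = (offset == clamped)                             # Equal
        gathered = param_slice[clamped]                        # GatherV2 (local)
        return gathered * mask[:, None]                        # Cast + Mul

    # Two devices, a 4-row param split into two 2-row slices:
    param = np.arange(8, dtype=np.float32).reshape(4, 2)
    index = np.array([0, 3])
    partial0 = sharded_gather_segment(param[0:2], index, bias=0, slice_size=2)
    partial1 = sharded_gather_segment(param[2:4], index, bias=2, slice_size=2)
    # Summing partials stands in for the ReduceScatter in the replaced graph.
    assert np.allclose(partial0 + partial1, param[index])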