Commit 390a86ef authored by lichenever

fix gatherv2

Parent 9f079d44
@@ -48,7 +48,7 @@ Status GatherV2PInfo::GetAttrs() {
 }
 
 Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
-  if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}, is_auto_parallel_) != SUCCESS) {
+  if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
     if (is_auto_parallel_) {
       MS_LOG(DEBUG) << name_ << ": Invalid strategy.";
     } else {
@@ -84,12 +84,19 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
     return FAILED;
   }
 
-  // Don't support repeated calc
-  auto params_strategy = strategy->GetInputDim().at(0);
+  // param_strategy(axis) != 1, index can't be splited
+  auto index_strategy = strategy->GetInputDim().at(1);
+  auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int>());
+  if ((param_strategy.at(IntToSize(axis_)) != 1) && (product_i != 1)) {
+    MS_LOG(ERROR) << name_ << ": param is splited at dim (axis)" << axis_ << " ,index can't be splited.";
+    return FAILED;
+  }
+
+  // param_strategy(axis) != 1, Don't support repeated calc
   CheckGlobalDeviceManager();
   size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
-  auto product = std::accumulate(params_strategy.begin(), params_strategy.end(), 1, std::multiplies<int>());
-  if (dev_num != IntToSize(product)) {
+  auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
+  if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) {
     MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc.";
     return FAILED;
   }
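The two new checks compose as follows: splitting `param` along `axis` forbids any split of `index`, and a total strategy product smaller than the device count (repeated calculation) is only tolerated when `axis` itself is unsplit. Below is a minimal standalone sketch of that validation rule using plain `std::vector<int>` in place of the framework's strategy types; the function name and signature are illustrative, not MindSpore API.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Illustrative re-statement of the new CheckStrategy rules; not the framework code.
bool CheckGatherStrategy(const std::vector<int> &param_strategy,
                         const std::vector<int> &index_strategy,
                         size_t axis, size_t dev_num) {
  int product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1,
                                  std::multiplies<int>());
  // Rule 1: if param is split along axis, index must not be split at all.
  if (param_strategy.at(axis) != 1 && product_i != 1) return false;

  int product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1,
                                  std::multiplies<int>());
  // Rule 2: repeated calc (product != dev_num) is only allowed when axis is unsplit.
  if (static_cast<size_t>(product_p) != dev_num && param_strategy.at(axis) != 1) return false;
  return true;
}

int main() {
  assert(CheckGatherStrategy({8, 1}, {1, 1}, 0, 8));    // axis split, index unsplit: ok
  assert(!CheckGatherStrategy({8, 1}, {2, 1}, 0, 16));  // axis and index both split: rejected
  assert(CheckGatherStrategy({1, 4}, {2, 1}, 0, 8));    // axis unsplit: repeated calc tolerated
  return 0;
}
```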
@@ -97,26 +104,66 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   return SUCCESS;
 }
 
+Status GatherV2PInfo::InferMirrorOps() {
+  mirror_ops_.clear();
+  Shape input_a_tensor_map = inputs_tensor_map_.at(0);
+  std::vector<Group> input_a_group;
+  if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << " : Create group for input a failed.";
+    return FAILED;
+  }
+
+  OperatorVector op_for_input_a, op_for_input_b, op_for_axis;
+  if (input_a_group.empty()) {
+    MS_LOG(INFO) << name_ << " : The mirror group is empty.";
+    return SUCCESS;
+  } else {
+    op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum());
+    MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name();
+  }
+  mirror_ops_.push_back(op_for_input_a);
+  mirror_ops_.push_back(op_for_input_b);
+  mirror_ops_.push_back(op_for_axis);
+
+  return SUCCESS;
+}
 Status GatherV2PInfo::InferDevMatrixShape() {
   dev_matrix_shape_.clear();
   out_dev_matrix_shape_.clear();
   // infer input dev_matrix_shape
-  auto params_strategy = strategy_->GetInputDim().at(0);
-  dev_matrix_shape_ = params_strategy;
+  auto param_strategy = strategy_->GetInputDim().at(0);
+  auto index_strategy = strategy_->GetInputDim().at(1);
+  dev_matrix_shape_ = param_strategy;
+
+  // param_strategy(axis)!=1,
+  if (param_strategy.at(IntToSize(axis_)) != 1) {
+    std::reverse(dev_matrix_shape_.begin(), dev_matrix_shape_.end());
+  } else {
+    dev_matrix_shape_.insert(dev_matrix_shape_.end(), index_strategy.begin(), index_strategy.end());
+  }
 
   // infer out dev_matrix_shape
   // axis!=0, split axis
-  if (axis_ != 0 && params_strategy.at(IntToSize(axis_)) != 1) {
-    out_dev_matrix_shape_.push_back(params_strategy.at(0) * params_strategy.at(IntToSize(axis_)));
-    for (size_t i = 1; i < params_strategy.size(); ++i) {
+  if (axis_ != 0 && param_strategy.at(IntToSize(axis_)) != 1) {
+    out_dev_matrix_shape_.push_back(param_strategy.at(0) * param_strategy.at(IntToSize(axis_)));
+    for (size_t i = 1; i < param_strategy.size(); ++i) {
       if (i == IntToSize(axis_)) {
         out_dev_matrix_shape_.push_back(1);
       } else {
-        out_dev_matrix_shape_.push_back(params_strategy.at(i));
+        out_dev_matrix_shape_.push_back(param_strategy.at(i));
       }
     }
   } else {
-    out_dev_matrix_shape_ = params_strategy;
+    out_dev_matrix_shape_ = dev_matrix_shape_;
   }
+  auto product_out =
+    std::accumulate(out_dev_matrix_shape_.begin(), out_dev_matrix_shape_.end(), 1, std::multiplies<int>());
+  CheckGlobalDeviceManager();
+  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
+  if (product_out == 1) {
+    out_dev_matrix_shape_.insert(out_dev_matrix_shape_.begin(), dev_num);
+  }
 
   return SUCCESS;
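The device-matrix construction now branches on whether `axis` is split: if it is, the matrix is the reversed param strategy; otherwise it is the param strategy with the index strategy appended. A standalone sketch of that derivation on plain vectors (names are illustrative, not the framework's types):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative re-derivation of the new dev-matrix rule; not the framework code.
std::vector<int> InferDevMatrix(const std::vector<int> &param_strategy,
                                const std::vector<int> &index_strategy, size_t axis) {
  std::vector<int> dev_matrix = param_strategy;
  if (param_strategy.at(axis) != 1) {
    // axis is split: the device matrix is the reversed param strategy
    std::reverse(dev_matrix.begin(), dev_matrix.end());
  } else {
    // axis unsplit: param strategy followed by the index strategy
    dev_matrix.insert(dev_matrix.end(), index_strategy.begin(), index_strategy.end());
  }
  return dev_matrix;
}

int main() {
  for (int d : InferDevMatrix({8, 1}, {1, 1}, 0)) std::cout << d << ' ';  // prints: 1 8
  std::cout << '\n';
  for (int d : InferDevMatrix({1, 4}, {2, 1}, 0)) std::cout << d << ' ';  // prints: 1 4 2 1
  std::cout << '\n';
  return 0;
}
```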
@@ -124,28 +171,56 @@ Status GatherV2PInfo::InferDevMatrixShape() {
 Status GatherV2PInfo::InferTensorMap() {
   // infer input tensor map
+  // param_strategy(axis) != 1
   size_t param_size = inputs_shape_.at(0).size();
   size_t index_size = inputs_shape_.at(1).size();
-  std::vector<int32_t> tensor_map_index(index_size, -1);
+  size_t total_size = dev_matrix_shape_.size();
+  std::vector<int32_t> tensor_map_index;
   std::vector<int32_t> tensor_map_params;
-  for (size_t i = 0; i < param_size; ++i) {
-    tensor_map_params.push_back(SizeToInt(param_size - i - 1));
+  auto param_strategy = strategy_->GetInputDim().at(0);
+  if (param_strategy.at(IntToSize(axis_)) != 1) {
+    tensor_map_index.insert(tensor_map_index.begin(), index_size, -1);
+    for (size_t i = 0; i < param_size; ++i) {
+      tensor_map_params.push_back(SizeToInt(i));
+    }
+  } else {
+    // param_strategy(axis) == 1
+    for (size_t i = 0; i < param_size; ++i) {
+      tensor_map_params.push_back(SizeToInt(total_size - i - 1));
+    }
+    for (size_t i = 0; i < index_size; ++i) {
+      tensor_map_index.push_back(SizeToInt(index_size - i - 1));
+    }
   }
 
   // infer output tensor map
   std::vector<int32_t> tensor_map_out;
-  if (axis_ == 0) {
-    tensor_map_out.push_back(SizeToInt(param_size - 1));
-    tensor_map_out.insert(tensor_map_out.end(), index_size - 1, -1);
-    for (size_t i = 1; i < param_size; ++i) {
-      tensor_map_out.push_back(SizeToInt(param_size - i - 1));
-    }
-  } else {
+  if (param_strategy.at(IntToSize(axis_)) == 1) {
+    // param_strategy(axis) == 1
     for (size_t i = 0; i < param_size; ++i) {
       if (i == IntToSize(axis_)) {
-        tensor_map_out.insert(tensor_map_out.end(), index_size, -1);
+        for (size_t j = 0; j < index_size; ++j) {
+          tensor_map_out.push_back(SizeToInt(index_size - j - 1));
+        }
       } else {
-        tensor_map_out.push_back(SizeToInt(param_size - i - 1));
+        tensor_map_out.push_back(SizeToInt(total_size - i - 1));
+      }
+    }
+  } else {
+    // param_strategy(axis) != 1
+    if (axis_ == 0) {
+      tensor_map_out.insert(tensor_map_out.end(), 0);
+      tensor_map_out.insert(tensor_map_out.end(), index_size - 1, -1);
+      for (size_t i = 1; i < param_size; ++i) {
+        tensor_map_out.push_back(i);
+      }
+    } else {
+      for (size_t i = 0; i < param_size; ++i) {
+        if (i == IntToSize(axis_)) {
+          tensor_map_out.insert(tensor_map_out.end(), index_size, -1);
+        } else {
+          tensor_map_out.push_back(SizeToInt(param_size - i - 1));
+        }
       }
     }
   }
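For a concrete feel of the unsplit-axis branch: with a 2-D `param`, a 2-D `index`, `axis = 0`, `param_strategy = (1, 4)`, and `index_strategy = (2, 1)`, the device matrix is `(1, 4, 2, 1)`, so `total_size = 4` and the maps come out as `param -> (3, 2)`, `index -> (1, 0)`, `output -> (1, 0, 2)`. A small sketch reproducing just that branch (illustrative, not the framework code):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the param_strategy(axis) == 1 branch of InferTensorMap.
int main() {
  const size_t param_size = 2, index_size = 2, axis = 0;
  const size_t total_size = 4;  // dev_matrix = param_strategy ++ index_strategy = (1, 4, 2, 1)
  std::vector<int32_t> map_params, map_index, map_out;
  for (size_t i = 0; i < param_size; ++i) map_params.push_back(static_cast<int32_t>(total_size - i - 1));
  for (size_t i = 0; i < index_size; ++i) map_index.push_back(static_cast<int32_t>(index_size - i - 1));
  for (size_t i = 0; i < param_size; ++i) {
    if (i == axis) {
      // the gathered axis is replaced by the index dimensions
      for (size_t j = 0; j < index_size; ++j) map_out.push_back(static_cast<int32_t>(index_size - j - 1));
    } else {
      map_out.push_back(static_cast<int32_t>(total_size - i - 1));
    }
  }
  auto dump = [](const char *tag, const std::vector<int32_t> &m) {
    std::cout << tag;
    for (auto v : m) std::cout << ' ' << v;
    std::cout << '\n';
  };
  dump("param:", map_params);  // param:  3 2
  dump("index:", map_index);   // index:  1 0
  dump("out:  ", map_out);     // out:    1 0 2
  return 0;
}
```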
@@ -209,7 +284,12 @@ Status GatherV2PInfo::InferBias() {
 Status GatherV2PInfo::InferGroup() {
   std::vector<Group> group_list;
-  if (CreateGroupByDim(IntToSize(axis_), &group_list) != SUCCESS) {
+  auto param_strategy = strategy_->GetInputDim().at(0);
+  size_t dim = IntToSize(axis_);
+  if (param_strategy.at(IntToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) {
+    dim = (axis_ + 1) % 2;
+  }
+  if (CreateGroupByDim(dim, &group_list) != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": Create group failed.";
     return FAILED;
   }
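The dimension flip follows from InferDevMatrixShape: when a 2-D `param` is split along `axis`, the device matrix was reversed, so the communication group must be built along the other dimension, hence `dim = (axis_ + 1) % 2`. A one-line check of that mapping (illustrative only):

```cpp
#include <cassert>
#include <cstddef>

int main() {
  // For a 2-D param with the axis split, the group dim flips to the other dimension.
  auto group_dim = [](size_t axis) { return (axis + 1) % 2; };
  assert(group_dim(0) == 1);
  assert(group_dim(1) == 0);
  return 0;
}
```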
@@ -231,7 +311,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias_)});
   auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub});
   auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)});
-  auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), gen_g.virtual_input_node(), minimum});
+  auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum});
   auto gather_v2 =
     gen_g.PushBack({gen_g.NewOpInst(GATHERV2), gen_g.virtual_input_node(), minimum, CreatInt32Imm(axis_)});
   auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gather_v2});
@@ -250,8 +330,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
   OperatorAttrs attrs = {attr_op, attr_group};
   auto reduce_scatter = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul});
-  std::vector<std::pair<AnfNodePtr, int>> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1),
-                                                         std::make_pair(equal, 2)};
+  std::vector<std::pair<AnfNodePtr, int>> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)};
   replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>(
     std::make_pair(input_nodes, reduce_scatter));
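The replace graph implements a masked local gather: shift the global index by this shard's bias (SUB), clamp it into [0, slice_size - 1] (RELU then MINIMUM), gather locally, zero out rows whose shifted index fell outside the shard (the EQUAL mask, which this commit now computes from `sub` itself instead of a separate graph input), and combine the partial results across shards. A numeric sketch of the per-shard math on plain arrays (illustrative only; the elementwise sum at the end stands in for the collective's reduce step):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Per-shard masked gather, mirroring sub -> relu -> minimum -> equal -> gather -> mul.
std::vector<float> LocalMaskedGather(const std::vector<float> &local_param,
                                     const std::vector<int32_t> &global_index,
                                     int32_t bias, int32_t slice_size) {
  std::vector<float> out;
  for (int32_t idx : global_index) {
    int32_t sub = idx - bias;                                      // SUB
    int32_t clamped = std::min(std::max(sub, 0), slice_size - 1);  // RELU + MINIMUM
    bool in_range = (sub == clamped);                              // EQUAL(sub, minimum)
    float gathered = local_param.at(static_cast<size_t>(clamped)); // GATHERV2
    out.push_back(in_range ? gathered : 0.0f);                     // CAST + MUL mask
  }
  return out;
}

int main() {
  // Two shards of param = {10, 20, 30, 40}: shard 0 holds {10, 20}, shard 1 holds {30, 40}.
  std::vector<int32_t> index = {0, 3};
  auto shard0 = LocalMaskedGather({10, 20}, index, /*bias=*/0, /*slice_size=*/2);
  auto shard1 = LocalMaskedGather({30, 40}, index, /*bias=*/2, /*slice_size=*/2);
  // Summing the per-shard partials recovers the full gather: prints 10 40
  for (size_t i = 0; i < index.size(); ++i) std::cout << shard0[i] + shard1[i] << ' ';
  std::cout << '\n';
  return 0;
}
```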
@@ -309,11 +388,11 @@ Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
 Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
   is_auto_parallel_ = true;
   Shape input0_split(inputs_shape_[0].size(), 1);
-  Shapes splittable_inputs = {input0_split};
+  Shape input1_split(inputs_shape_[1].size(), 1);
+  Shapes splittable_inputs = {input0_split, input1_split};
   std::vector<StrategyPtr> sp_vector;
-  if (GenerateStrategiesForIndependentInputs(stage_id, {inputs_shape_.at(0)}, splittable_inputs, &sp_vector) !=
-      SUCCESS) {
+  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
     MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed.";
     return FAILED;
   }
@@ -331,12 +410,13 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
 std::shared_ptr<std::vector<std::vector<int32_t>>> GatherV2PInfo::GenerateBatchStrategies() {
   CheckGlobalDeviceManager();
   size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
-  Dimensions strategy;
-  strategy.push_back(SizeToInt(dev_num));
-  for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
-    strategy.push_back(1);
+  Dimensions param_strategy(inputs_shape_[0].size(), 1);
+  Dimensions index_strategy;
+  index_strategy.push_back(SizeToInt(dev_num));
+  for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
+    index_strategy.push_back(1);
   }
-  std::vector<Dimensions> strategy_v = {strategy};
+  std::vector<Dimensions> strategy_v = {param_strategy, index_strategy};
   return std::make_shared<std::vector<std::vector<int32_t>>>(strategy_v);
 }
 }  // namespace parallel
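With two inputs, the batch strategy now leaves `param` fully replicated and splits only the leading dimension of `index` across all devices. A sketch of the resulting tuples (illustrative shapes and types):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Sketch of the new batch strategy: replicate param, batch-split index.
int main() {
  const int dev_num = 8;
  const size_t param_dims = 2, index_dims = 2;  // e.g. a 2-D param and a 2-D index
  std::vector<int> param_strategy(param_dims, 1);  // (1, 1): fully replicated
  std::vector<int> index_strategy{dev_num};        // leading dim split across all devices
  for (size_t i = 1; i < index_dims; ++i) index_strategy.push_back(1);  // (8, 1)
  std::cout << "param:";
  for (int v : param_strategy) std::cout << ' ' << v;
  std::cout << "  index:";
  for (int v : index_strategy) std::cout << ' ' << v;
  std::cout << '\n';  // param: 1 1  index: 8 1
  return 0;
}
```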
...
@@ -48,7 +48,7 @@ class GatherV2PInfo : public OperatorInfo {
  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
-  Status InferMirrorOps() override { return SUCCESS; }
+  Status InferMirrorOps() override;
   Status InferForwardCommunication() override { return SUCCESS; }
   Status InferTensorInfo() override;
   Status InferDevMatrixShape() override;
...
@@ -61,7 +61,7 @@ class Net(nn.Cell):
 def test_gatherv2_semi_auto0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((1, 8),)
+    strategy1 = ((1, 8), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
     net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
     net.set_auto_parallel()
@@ -73,7 +73,7 @@ def test_gatherv2_semi_auto0():
 def test_gatherv2_semi_auto1():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((8, 1),)
+    strategy1 = ((8, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
     net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
     net.set_auto_parallel()
@@ -85,7 +85,7 @@ def test_gatherv2_semi_auto1():
 def test_gatherv2_semi_auto2():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((2, 4),)
+    strategy1 = ((2, 4), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
     net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
     net.set_auto_parallel()
@@ -97,7 +97,7 @@ def test_gatherv2_semi_auto2():
 def test_gatherv2_semi_auto3():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((1, 8),)
+    strategy1 = ((1, 8), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
     net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
     net.set_auto_parallel()
@@ -109,7 +109,7 @@ def test_gatherv2_semi_auto3():
 def test_gatherv2_semi_auto4():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((8, 1),)
+    strategy1 = ((8, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
     net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
     net.set_auto_parallel()
@@ -121,7 +121,7 @@ def test_gatherv2_semi_auto4():
 def test_gatherv2_semi_auto5():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((2, 4),)
+    strategy1 = ((2, 4), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
     net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
     net.set_auto_parallel()
@@ -155,7 +155,7 @@ def test_gatherv2_semi_auto7():
 def test_gatherv2_semi_auto8():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((8,),)
+    strategy1 = ((8,), (1, 1))
     strategy2 = ((4, 2), (4, 2))
     net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
     net.set_auto_parallel()
...
@@ -221,14 +221,14 @@ def test_axis1_auto_batch_parallel():
 def test_axis1_batch_parallel():
-    gather_v2_strategy = ((device_number, 1),)
+    gather_v2_strategy = ((device_number, 1), (1, ))
     criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
     rank = 2
     net_trains(gather_v2_strategy, criterion, rank)
 
 def test_axis1_strategy1():
-    gather_v2_strategy = ((16, 2),)
+    gather_v2_strategy = ((16, 2), (1, ))
     rank = 17
     criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
     net_trains(gather_v2_strategy, criterion, rank)