Commit 9940c723 authored by mindspore-ci-bot, committed by Gitee

!3980 [AutoParallel] add GatherV2P strategy analysis for W&D

Merge pull request !3980 from Chong/wd
...@@ -176,21 +176,102 @@ Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, ...@@ -176,21 +176,102 @@ Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
s[axis] = 1; s[axis] = 1;
strategies.push_back(s); strategies.push_back(s);
auto pos = ops[iter_ops]->name().find("Info"); return strategies;
auto name = ops[iter_ops]->name().substr(0, pos); }
if (name == "GatherV2") {
// Builds the sharding strategy for a GatherV2P operator (parameter + indices inputs).
// The output shape is cut greedily across devices, and the resulting per-dimension
// cuts are mapped back onto the two operator inputs according to the gather axis.
// NOTE(review): reconstructed from a flattened two-column diff; old-column residue removed.
Strategys PrepareGatherV2P(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s) {
  Strategys strategies;
  // Shape of the gather output; mutated below while choosing the cuts.
  auto output_shape = ops[iter_ops]->outputs_tensor_info()[0].shape();

  // Rank the output dimensions: dimension 0 first, then the remaining dimensions
  // ordered from the largest extent to the smallest.
  Dimensions index(output_shape.size() - 1, 0);
  for (size_t i = 0; i < index.size(); i++) {
    index[i] = i;
  }
  std::sort(index.begin(), index.end(),
            [&output_shape](const int &a, const int &b) { return (output_shape[a + 1] > output_shape[b + 1]); });
  std::transform(std::begin(index), std::end(index), std::begin(index), [](int x) { return x + 1; });
  index.insert(index.begin(), 0);

  // Greedily halve the ranked dimensions until every device is used or no
  // dimension can be divided evenly any further.
  Dimensions strategie(output_shape.size(), 1);
  size_t num_device = g_device_manager->DeviceNum();
  size_t cut = 1;
  for (size_t i = 0; i < index.size(); i++) {
    while (output_shape[index[i]] % 2 == 0 && output_shape[index[i]] > 0 && cut < num_device) {
      output_shape[index[i]] /= 2;
      cut *= 2;
      strategie[index[i]] *= 2;
    }
    if (cut == num_device) {
      break;
    }
  }

  // Normalize the gather axis (operator input value #2) into [0, rank).
  auto axis_input = GetValue<int>(ops[iter_ops]->input_value().at(2));
  if (axis_input < 0) {
    axis_input += SizeToInt(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
  }
  int32_t axis = axis_input;
  if (axis >= SizeToInt(s.size())) {
    MS_LOG(EXCEPTION) << "Failure: GatherV2' axis out of range.";
  }

  if (axis == 0) {
    // Axis 0: keep the parameter's gathered dimension whole; the remaining
    // parameter dimensions take the cuts chosen for the trailing output dims.
    s.clear();
    s.push_back(1);
    for (size_t i = 1; i < ops[iter_ops]->inputs_tensor_info()[0].shape().size(); i++) {
      s.push_back(strategie[ops[iter_ops]->inputs_tensor_info()[1].shape().size() - 1 + i]);
    }
    strategies.push_back(s);
    // The indices tensor takes the cuts of the leading output dimensions.
    s.clear();
    for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) {
      s.push_back(strategie[i]);
    }
    strategies.push_back(s);
  } else if (axis == 1) {
    // Axis 1: cut the parameter's first dimension as chosen for output dim 0
    // and keep its gathered (second) dimension whole.
    s.clear();
    s.push_back(strategie[0]);
    s.push_back(1);
    strategies.push_back(s);
    // The indices tensor takes the cuts of the trailing output dimensions.
    s.clear();
    for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) {
      s.push_back(strategie[ops[iter_ops]->inputs_tensor_info()[0].shape().size() - 1 + i]);
    }
    strategies.push_back(s);
  } else {
    MS_LOG(EXCEPTION) << "Failure: GatherV2's axis is neither 0 nor 1.";
  }
  return strategies;
}
// Derives the output-tensor sharding strategy of a GatherV2P operator by greedily
// cutting its output shape across the available devices: dimension 0 is split
// first, then the remaining dimensions ordered from largest to smallest extent,
// each halved while it stays evenly divisible and devices remain unused.
Dimensions PrepareGatherV2POutputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                          const size_t incoming_op_index) {
  auto out_shape = ops[incoming_op_index]->outputs_tensor_info()[0].shape();

  // Rank the non-leading dimensions by extent, largest first, then shift the
  // ranks to actual dimension numbers and put dimension 0 at the front.
  Dimensions order(out_shape.size() - 1, 0);
  for (size_t pos = 0; pos < order.size(); pos++) {
    order[pos] = pos;
  }
  auto larger_first = [&out_shape](const int &lhs, const int &rhs) { return out_shape[lhs + 1] > out_shape[rhs + 1]; };
  std::sort(order.begin(), order.end(), larger_first);
  std::transform(order.begin(), order.end(), order.begin(), [](int d) { return d + 1; });
  order.insert(order.begin(), 0);

  // Halve each ranked dimension while it divides evenly and devices remain;
  // stop as soon as the accumulated cut covers the whole device count.
  Dimensions strategy(out_shape.size(), 1);
  size_t device_total = g_device_manager->DeviceNum();
  size_t used = 1;
  for (size_t pos = 0; pos < order.size() && used < device_total; pos++) {
    while (out_shape[order[pos]] % 2 == 0 && out_shape[order[pos]] > 0 && used < device_total) {
      out_shape[order[pos]] /= 2;
      used *= 2;
      strategy[order[pos]] *= 2;
    }
  }
  return strategy;
}
Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Dimensions s) { Dimensions s) {
int32_t axis = 0; int32_t axis = 0;
...@@ -401,10 +482,20 @@ Dimensions CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &grap ...@@ -401,10 +482,20 @@ Dimensions CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &grap
Dimensions PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, Dimensions PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index) { const size_t incoming_op_index) {
Dimensions s; Dimensions s;
if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2 || if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == TRANSPOSE) {
ops[incoming_op_index]->type() == TRANSPOSE) {
return s; return s;
} }
if (ops[incoming_op_index]->type() == GATHERV2) {
auto pos = ops[incoming_op_index]->name().find("Info");
auto name = ops[incoming_op_index]->name().substr(0, pos);
if (name == "GatherV2") {
return s;
} else if (name == "GatherV2P") {
return PrepareGatherV2POutputStrategy(ops, incoming_op_index);
} else {
MS_LOG(EXCEPTION) << "Failure: Unknown type of GatherV2." << std::endl;
}
}
auto strategy = ops[incoming_op_index]->selected_strategy(); auto strategy = ops[incoming_op_index]->selected_strategy();
if (strategy->GetInputNumber() == 0) { if (strategy->GetInputNumber() == 0) {
return s; return s;
...@@ -495,10 +586,13 @@ Dimensions GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, con ...@@ -495,10 +586,13 @@ Dimensions GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, con
if (input_value.back()->isa<ValueTuple>()) { if (input_value.back()->isa<ValueTuple>()) {
auto attr_axis = GetValue<std::vector<int>>(input_value.back()); auto attr_axis = GetValue<std::vector<int>>(input_value.back());
if (attr_axis.empty()) { if (attr_axis.empty()) {
MS_LOG(EXCEPTION) << "Failure: This output is a 0-D tensor." << std::endl; for (size_t i = 0; i < input_dim; i++) {
} dim_list.push_back(SizeToInt(i));
for (auto &axis : attr_axis) { }
axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); } else {
for (auto &axis : attr_axis) {
axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis);
}
} }
} else if (input_value.back()->isa<Int32Imm>()) { } else if (input_value.back()->isa<Int32Imm>()) {
int axis = GetValue<int>(input_value.back()); int axis = GetValue<int>(input_value.back());
...@@ -625,7 +719,15 @@ Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<Opera ...@@ -625,7 +719,15 @@ Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<Opera
return PrepareBiasAdd(s_ptr); return PrepareBiasAdd(s_ptr);
} }
if (ops[iter_ops]->type() == GATHERV2) { if (ops[iter_ops]->type() == GATHERV2) {
return PrepareGatherV2(ops, iter_ops, basic_stra); auto pos = ops[iter_ops]->name().find("Info");
auto name = ops[iter_ops]->name().substr(0, pos);
if (name == "GatherV2") {
return PrepareGatherV2(ops, iter_ops, basic_stra);
} else if (name == "GatherV2P") {
return PrepareGatherV2P(ops, iter_ops, basic_stra);
} else {
MS_LOG(EXCEPTION) << "Failure: Unknown type of GatherV2." << std::endl;
}
} }
if (ops[iter_ops]->type() == L2_NORMALIZE) { if (ops[iter_ops]->type() == L2_NORMALIZE) {
return PrepareL2Normalize(ops, iter_ops, basic_stra); return PrepareL2Normalize(ops, iter_ops, basic_stra);
......
...@@ -37,6 +37,9 @@ Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s); ...@@ -37,6 +37,9 @@ Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s);
Strategys PrepareOneHot(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, Strategys PrepareOneHot(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops); const size_t iter_graph, const size_t iter_ops);
Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s); Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
Strategys PrepareGatherV2P(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
Dimensions PrepareGatherV2POutputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index);
Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Dimensions s); Dimensions s);
Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph, Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册