diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
index 68b776155ac2bc798c488bd70d605329c76d4783..c637870d92c08e02476332fc45ea20704363bc34 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
@@ -614,7 +614,6 @@ std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::sh
 std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                                  const size_t iter_ops,
                                                                  std::vector<int32_t> basic_stra) {
-  std::vector<int32_t> s_empty = {};
   std::vector<std::vector<int32_t>> stra;
   MS_EXCEPTION_IF_NULL(ops[iter_ops]);
 
@@ -636,9 +635,99 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
   if (ops[iter_ops]->type() == L2_NORMALIZE) {
     return PrepareL2Normalize(ops, iter_ops, basic_stra);
   }
 
+  if (ops[iter_ops]->type() == TENSOR_ADD || ops[iter_ops]->type() == SUB || ops[iter_ops]->type() == MUL ||
+      ops[iter_ops]->type() == DIV) {
+    return CheckBroadcast(ops, iter_ops, basic_stra);
+  }
+
+  return CheckDivisible(ops, iter_ops, basic_stra);
+}
+
+// Function to deal with ops with broadcasting, such as TensorAdd/Sub/Mul/Div.
+std::vector<std::vector<int32_t>> CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                 const size_t iter_ops, std::vector<int32_t> s) {
+  std::vector<std::vector<int32_t>> stra;
+
+  size_t first_tensor_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
+  size_t second_tensor_dim = ops[iter_ops]->inputs_tensor_info()[1].shape().size();
+
+  // Do broadcasting in the second tensor.
+  if (second_tensor_dim < first_tensor_dim) {
+    bool broadcast_first_tensor = false;
+    // Push back the first tensor's strategy.
+    stra.push_back(s);
+    // Push back the second tensor's strategy after applying broadcast.
+    stra.push_back(ApplyBroadcast(ops, iter_ops, s, second_tensor_dim, first_tensor_dim, broadcast_first_tensor));
+  } else if (second_tensor_dim > first_tensor_dim) {  // Do broadcasting in the first tensor.
+    bool broadcast_first_tensor = true;
+    // Push back the first tensor's strategy after applying broadcast.
+    stra.push_back(ApplyBroadcast(ops, iter_ops, s, first_tensor_dim, second_tensor_dim, broadcast_first_tensor));
+    // Push back the second tensor's strategy.
+    stra.push_back(s);
+  } else {  // Broadcasting can be ignored, or no broadcasting needs to be applied.
+    stra = CheckDivisible(ops, iter_ops, s);
+  }
+
+  return stra;
+}
+
+std::vector<int32_t> ApplyBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
+                                    std::vector<int32_t> s, size_t target_tensor_dim, size_t refer_tensor_dim,
+                                    bool broadcast_first_tensor) {
+  std::vector<int32_t> s_empty = {};
+  std::vector<int32_t> s_broadcast;
+  int target_tensor_index = 0;
+  int refer_tensor_index = 0;
+
+  // Index the target and refer tensors.
+  if (broadcast_first_tensor) {
+    target_tensor_index = 0;
+    refer_tensor_index = 1;
+  } else {
+    target_tensor_index = 1;
+    refer_tensor_index = 0;
+  }
+
+  // When the target tensor has an empty shape.
+  if (target_tensor_dim == 0) {
+    return s_empty;
+  } else if (target_tensor_dim == 1) {  // When the target tensor has a single dim.
+    bool broadcast_dim_found = false;
+    for (size_t iter = 0; iter < refer_tensor_dim; iter++) {
+      // Find and copy that dim's strategy from the refer tensor.
+      if ((ops[iter_ops]->inputs_tensor_info()[refer_tensor_index].shape()[iter] ==
+           ops[iter_ops]->inputs_tensor_info()[target_tensor_index].shape()[0]) &&
+          (ops[iter_ops]->inputs_tensor_info()[refer_tensor_index].shape()[iter] > 1) &&
+          (refer_tensor_dim == s.size())) {
+        s_broadcast.push_back(s.at(iter));
+        broadcast_dim_found = true;
+        break;
+      }
+    }
+    // Cannot decide which dim it is, push back one.
+    if (broadcast_dim_found == false) {
+      s_broadcast.push_back(1);
+    }
+  } else {
+    // Cannot decide which dim needs to do broadcast, push back one(s).
+    for (size_t iter = 0; iter < target_tensor_dim; iter++) {
+      s_broadcast.push_back(1);
+    }
+  }
+  return s_broadcast;
+}
+
+// Check whether the operator can be divided by the current strategy.
+std::vector<std::vector<int32_t>> CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                 const size_t iter_ops, std::vector<int32_t> basic_stra) {
+  std::vector<int32_t> s_empty = {};
+  std::vector<std::vector<int32_t>> stra;
+
+  // For all the input tensors.
   for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size();
        iter_op_inputs++) {
+    // If the input tensor is empty, push back an empty strategy.
     if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() == 0) {
       stra.push_back(s_empty);
       continue;
@@ -646,6 +735,8 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
 
     std::vector<int32_t> tmp_stra = basic_stra;
     bool modified = false;
+
+    // Make sure each dim of the tensor's shape is greater than 1. If not, set that dim's strategy to 1 instead.
     for (size_t j = 0; j < (size_t)ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); j++) {
       if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape()[j] == 1) {
         tmp_stra[j] = 1;
@@ -658,6 +749,7 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
       stra.push_back(basic_stra);
     }
   }
+
   return stra;
 }
 
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h
index 9acd05e0a98281ee3759ca36b5399a20e435f53f..bd8de641a22b1dd7d9a8b0c219862a992418a0ca 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h
@@ -42,6 +42,13 @@ std::vector<std::vector<int32_t>> PrepareGatherV2(const std::vector<std::shared_
                                                   const size_t iter_ops, std::vector<int32_t> s);
 std::vector<std::vector<int32_t>> PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                      const size_t iter_ops, std::vector<int32_t> s);
+std::vector<std::vector<int32_t>> CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                 const size_t iter_ops, std::vector<int32_t> s);
+std::vector<int32_t> ApplyBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
+                                    std::vector<int32_t> s, size_t target_tensor_dim, size_t refer_tensor_dim,
+                                    bool broadcast_first_tensor);
+std::vector<std::vector<int32_t>> CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                 const size_t iter_ops, std::vector<int32_t> s);
 std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
                                                         const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                         const size_t iter_graph, const size_t iter_ops);
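To make the intended behavior of the new CheckBroadcast/ApplyBroadcast path concrete, here is a minimal standalone sketch in plain C++. It is not the MindSpore OperatorInfo API: the helper name ApplyBroadcastSketch and the shape/strategy values are hypothetical, chosen only to illustrate the rule that a lower-rank input copies the matching dim's cut from the higher-rank input's strategy when that dim can be identified, and otherwise falls back to a cut of 1.

// Standalone sketch, not the MindSpore API: mirrors the strategy choice made by
// CheckBroadcast/ApplyBroadcast for two hypothetical input shapes and a base strategy.
#include <cstdint>
#include <iostream>
#include <vector>

// Simplified stand-in for ApplyBroadcast: derive the lower-rank ("target") tensor's
// strategy from the higher-rank ("refer") tensor's shape and strategy.
std::vector<int32_t> ApplyBroadcastSketch(const std::vector<int32_t> &refer_shape,
                                          const std::vector<int32_t> &target_shape,
                                          const std::vector<int32_t> &refer_stra) {
  std::vector<int32_t> s_broadcast;
  if (target_shape.empty()) {
    return s_broadcast;  // Empty tensor: empty strategy.
  }
  if (target_shape.size() == 1) {
    // Try to match the single dim against a non-1 dim of the refer tensor
    // and reuse that dim's cut; otherwise fall back to 1 (no cut).
    for (size_t i = 0; i < refer_shape.size() && i < refer_stra.size(); ++i) {
      if (refer_shape[i] == target_shape[0] && refer_shape[i] > 1) {
        s_broadcast.push_back(refer_stra[i]);
        return s_broadcast;
      }
    }
    s_broadcast.push_back(1);
    return s_broadcast;
  }
  // Rank >= 2 but still smaller than the refer tensor: dims cannot be aligned safely,
  // so do not cut any of them.
  return std::vector<int32_t>(target_shape.size(), 1);
}

int main() {
  // Hypothetical TensorAdd: first input [64, 32], second input [32], base strategy [1, 4].
  std::vector<int32_t> first_shape = {64, 32};
  std::vector<int32_t> second_shape = {32};
  std::vector<int32_t> basic_stra = {1, 4};

  std::vector<std::vector<int32_t>> stra;
  stra.push_back(basic_stra);                                                   // First input keeps the base strategy.
  stra.push_back(ApplyBroadcastSketch(first_shape, second_shape, basic_stra));  // Second input reuses dim 1's cut: [4].

  for (const auto &s : stra) {
    for (int32_t d : s) std::cout << d << ' ';
    std::cout << '\n';
  }
  return 0;
}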
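CheckDivisible's per-input adjustment can be sketched the same way (again a hypothetical standalone snippet, not the actual code path): any dim whose shape is 1 cannot be divided across devices, so its strategy entry is forced to 1, while all other dims keep the base strategy. This is also why the equal-rank case of CheckBroadcast can simply fall through to CheckDivisible.

// Standalone sketch of the per-input adjustment performed by CheckDivisible.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int32_t> AdjustStrategyForInput(const std::vector<int32_t> &shape,
                                            const std::vector<int32_t> &basic_stra) {
  if (shape.empty()) {
    return {};  // Empty input tensor: empty strategy.
  }
  std::vector<int32_t> tmp_stra = basic_stra;
  for (size_t j = 0; j < shape.size() && j < tmp_stra.size(); ++j) {
    if (shape[j] == 1) {
      tmp_stra[j] = 1;  // A dim of size 1 cannot be divided across devices.
    }
  }
  return tmp_stra;
}

int main() {
  // Hypothetical element-wise op with inputs [64, 1, 32] and [64, 16, 32], base strategy [4, 2, 1].
  std::vector<int32_t> basic_stra = {4, 2, 1};
  std::vector<std::vector<int32_t>> stra;
  stra.push_back(AdjustStrategyForInput({64, 1, 32}, basic_stra));   // -> [4, 1, 1]
  stra.push_back(AdjustStrategyForInput({64, 16, 32}, basic_stra));  // -> [4, 2, 1]

  for (const auto &s : stra) {
    for (int32_t d : s) std::cout << d << ' ';
    std::cout << '\n';
  }
  return 0;
}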