From 9c8e750c9e9b24bd9f042b64b0869bf2bb2a7f8a Mon Sep 17 00:00:00 2001 From: hongxing Date: Wed, 20 May 2020 03:21:29 +0200 Subject: [PATCH] maximize strategy dynamically --- .../rec_core/rec_generate_strategy.cc | 93 ++++++++++++------- 1 file changed, 61 insertions(+), 32 deletions(-) diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc index c63bb64f5..6b5bb9720 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc @@ -81,16 +81,33 @@ std::vector> PrepareVirtualDataset(const std::vector> PrepareBiasAdd(const std::vector> &ops, const size_t iter_ops, std::vector s) { std::vector> strategies; - for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { - if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() == 1) { - auto max = s[max_element(s.begin(), s.end()) - s.begin()]; - std::vector s_single; - s_single.push_back(max); - strategies.push_back(s_single); - continue; + + auto dev_num = g_device_manager->DeviceNum(); + size_t cut_num = 1; + for (size_t iter_s = 0; iter_s < s.size(); iter_s++) { + cut_num *= s[iter_s]; + } + if (cut_num != dev_num) { + std::vector s_max = s; + for (size_t dim = 0; dim < (size_t)ops[iter_ops]->inputs_tensor_info()[0].shape().size(); dim++) { + size_t shape = ops[iter_ops]->inputs_tensor_info()[0].shape()[dim] / s[dim]; + while (cut_num < dev_num && shape % 2 == 0) { + shape = shape / 2; + s_max[dim] = s_max[dim] * 2; + cut_num = cut_num * 2; + } + if (cut_num == dev_num) { + break; + } } - strategies.push_back(s); + s = s_max; } + + strategies.push_back(s); + std::vector s_biasadd; + s_biasadd.push_back(s[1]); + strategies.push_back(s_biasadd); + return strategies; } @@ -423,36 +440,48 @@ std::vector> GenerateStrategiesFromStrategy(const std::vect } auto dev_num = g_device_manager->DeviceNum(); - size_t cut_num = 1; - for (size_t i = 0; i < s.size(); i++) { - cut_num *= s[i]; - } - if (cut_num < dev_num) { - size_t diff = dev_num / cut_num; - if (s[0] * diff > dev_num) { - MS_LOG(EXCEPTION) << "Failure: Can not continue to partition in the N-dimension of the element-wise operator."; - } - s[0] = s[0] * diff; - } - - for (size_t i = 0; i < (size_t)ops[iter_ops]->inputs_tensor_info().size(); i++) { - if (ops[iter_ops]->inputs_tensor_info()[i].shape().size() == 0) { + for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size(); + iter_op_inputs++) { + if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() == 0) { stra.push_back(s_empty); continue; } - std::vector s_1 = s; - bool modified = false; - for (size_t j = 0; j < (size_t)ops[iter_ops]->inputs_tensor_info()[i].shape().size(); j++) { - if (ops[iter_ops]->inputs_tensor_info()[i].shape()[j] == 1) { - s_1[j] = 1; - modified = true; + + size_t cut_num = 1; + for (size_t iter_s = 0; iter_s < s.size(); iter_s++) { + cut_num *= s[iter_s]; + } + if (cut_num == dev_num) { + std::vector s_1 = s; + bool modified = false; + for (size_t j = 0; j < (size_t)ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); j++) { + if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape()[j] == 1) { + s_1[j] = 1; + modified = true; + } + } + if (modified) { + stra.push_back(s_1); + } else { + stra.push_back(s); } + continue; } - if (modified) { - stra.push_back(s_1); - } else { - stra.push_back(s); + + std::vector s_max = s; + for (size_t dim = 0; dim < (size_t)ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); dim++) { + size_t shape = ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape()[dim] / s[dim]; + while (cut_num < dev_num && shape % 2 == 0) { + shape = shape / 2; + s_max[dim] = s_max[dim] * 2; + cut_num = cut_num * 2; + } + if (cut_num == dev_num) { + break; + } } + + stra.push_back(s_max); } return stra; } -- GitLab