diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
index 8c99df8345b3d34a4a3e3302ccfcc23b508e9f4b..e03d917b2700ad8b62069e9e4f0365460ba878fc 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
@@ -39,17 +39,18 @@ void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
   std::shared_ptr<std::vector<size_t>> no_stra_op_list(new std::vector<size_t>);
   GenerateEliminatedOperatorStrategyForward(graph, ops, eli_list, input_tensor_names, index_list, no_stra_op_list);
   GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list);
+  GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list);
 }
 
 std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph,
                                                 const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t iter_graph, const size_t iter_ops) {
   std::vector<std::vector<int32_t>> strategies;
+  auto attrs = ops[iter_ops]->attrs();
+  bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
+  bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
   for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
     std::vector<int32_t> s;
-    auto attrs = ops[iter_ops]->attrs();
-    bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
-    bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
     if (transpose_a && (iter_op_inputs == 0)) {
       s.push_back(
         static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
@@ -71,43 +72,20 @@ std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph,
   return strategies;
 }
 
-std::vector<std::vector<int32_t>> PrepareVirtualDataset(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                        const size_t iter_ops) {
-  std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(ops, iter_ops);
-  strategies[1][0] = strategies[0][0];
+std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &graph,
+                                               const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                               const size_t iter_graph, const size_t iter_ops) {
+  std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
+  strategies[1][0] = 1;
   return strategies;
 }
 
-std::vector<std::vector<int32_t>> PrepareScalarInputOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                             const size_t iter_ops, std::vector<int32_t> s) {
+std::vector<std::vector<int32_t>> PrepareBiasAdd(std::vector<int32_t> s) {
   std::vector<std::vector<int32_t>> strategies;
-
-  auto dev_num = g_device_manager->DeviceNum();
-  size_t cut_num = 1;
-  for (size_t iter_s = 0; iter_s < s.size(); iter_s++) {
-    cut_num *= s[iter_s];
-  }
-  if (cut_num != dev_num) {
-    std::vector<int32_t> s_max = s;
-    for (size_t dim = 0; dim < (size_t)ops[iter_ops]->inputs_tensor_info()[0].shape().size(); dim++) {
-      size_t shape = ops[iter_ops]->inputs_tensor_info()[0].shape()[dim] / s[dim];
-      while (cut_num < dev_num && shape % 2 == 0) {
-        shape = shape / 2;
-        s_max[dim] = s_max[dim] * 2;
-        cut_num = cut_num * 2;
-      }
-      if (cut_num == dev_num) {
-        break;
-      }
-    }
-    s = s_max;
-  }
-
   strategies.push_back(s);
   std::vector<int32_t> s_biasadd;
   s_biasadd.push_back(s[1]);
   strategies.push_back(s_biasadd);
-
   return strategies;
 }
@@ -131,16 +109,13 @@ std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
   StrategyPtr origin_strategy = ops[iter_ops]->strategy();
-
   std::vector<std::vector<int32_t>> strategies;
   for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
     if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
       MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
     }
-    // size_t output_size = ops[iter_ops]->outputs_tensor_info()[0].shape().size();
     size_t output_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
-
     std::vector<int32_t> s;
     if (output_size == 4) {
       s.push_back(
@@ -164,14 +139,14 @@ std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
-std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                           const size_t iter_ops) {
+std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
+                                                           const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                           const size_t iter_graph, const size_t iter_ops) {
   if (ops.empty()) {
     MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
   }
@@ -180,8 +155,9 @@ std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
   StrategyPtr origin_strategy = ops[iter_ops]->strategy();
-
   std::vector<std::vector<int32_t>> strategies;
+  size_t max_device_num = g_device_manager->DeviceNum();
+  size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0];
   for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
     if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
       MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
@@ -192,8 +168,6 @@ std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-        size_t max_device_num = g_device_manager->DeviceNum();
-        size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0];
         s.push_back(std::min(max_device_num, target_tensor_batch));
       } else {
         s.push_back(1);
@@ -202,9 +176,21 @@ std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+  graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0;
+  graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = 1.0;
+  graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0;
+  graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0;
+  if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 1) {
+    graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0 / std::min(max_device_num, target_tensor_batch);
+  } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) {
+    graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0 / std::min(max_device_num, target_tensor_batch);
+  } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 4) {
+    graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0 / std::min(max_device_num, target_tensor_batch);
+  }
+
   return strategies;
 }
@@ -217,20 +203,18 @@ std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &graph,
   if (iter_ops >= ops.size()) {
     MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
   }
+  MS_EXCEPTION_IF_NULL(ops[iter_ops]);
   auto type = ops[iter_ops]->type();
-  if (type == VIRTUAL_DATA_SET) {
-    return PrepareVirtualDataset(ops, iter_ops);
-  }
   auto idx = DictOpType.find(type);
   if (idx == DictOpType.end()) {
-    return MakeDataParallelStrategy(ops, iter_ops);
+    return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
   }
   if (type == MATMUL) {
     return PrepareMatMul(graph, ops, iter_graph, iter_ops);
-  } else if (type == RESHAPE) {
-    return MakeDataParallelStrategy(ops, iter_ops);
+  } else if (type == PRELU) {
+    return PreparePReLU(graph, ops, iter_graph, iter_ops);
   } else {
     return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
   }
@@ -242,28 +226,25 @@ void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> graph,
   for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) {
     std::vector<std::vector<int32_t>> strategies;
     size_t iter_graph = index_list->at(iter_ops);
-    if (iter_graph == SIZE_MAX) {
-      StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
-      ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
-      continue;
+    if (iter_graph != SIZE_MAX) {
+      strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops);
     }
-    strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops);
     StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
     ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
   }
 }
 
-int FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names,
-                                const size_t iter_ops) {
-  int incoming_op_index = -1;
-  for (size_t i = 1; i < (size_t)input_tensor_names[iter_ops].size(); i++) {
-    for (size_t j = 0; j < (size_t)input_tensor_names.size(); j++) {
+size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names,
+                                   const size_t iter_ops) {
+  size_t incoming_op_index = SIZE_MAX;
+  for (size_t i = 1; i < input_tensor_names[iter_ops].size(); i++) {
+    for (size_t j = 0; j < input_tensor_names.size(); j++) {
       if (input_tensor_names[iter_ops][i] == input_tensor_names[j][0]) {
         incoming_op_index = j;
         break;
       }
     }
-    if (incoming_op_index != -1) {
+    if (incoming_op_index != SIZE_MAX) {
       break;
     }
   }
@@ -298,12 +279,16 @@ std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> graph,
 std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                          const int incoming_op_index) {
+                                                          const size_t incoming_op_index) {
   std::vector<int32_t> s;
+  if (ops[incoming_op_index]->type() == RESHAPE) {
+    return s;
+  }
   auto strategy = ops[incoming_op_index]->selected_strategy();
   if (strategy->GetInputNumber() == 0) {
     return s;
   }
+
   for (size_t i = 0; i < (size_t)ops[incoming_op_index]->inputs_tensor_info().size(); i++) {
     if (ops[incoming_op_index]->inputs_tensor_info()[i].shape().size() == 0) {
       continue;
@@ -327,6 +312,7 @@ std::vector<int32_t> GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int iter_ops) {
   } else {
     MS_LOG(EXCEPTION) << "Failure: Axis type is invalid, neither tuple nor list." << std::endl;
   }
+
   for (auto &element : elements) {
     if (!element->isa<Int32Imm>()) {
       MS_LOG(EXCEPTION) << "Failure: Dimension indexes is not Int32." << std::endl;
@@ -338,12 +324,13 @@ std::vector<int32_t> GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int iter_ops) {
 }
 
 std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                     const int incoming_op_index, std::vector<int32_t> s) {
+                                                     const size_t incoming_op_index, std::vector<int32_t> s) {
   std::vector<int32_t> s_Squeeze;
   std::vector<int32_t> stra_dim_list;
   for (size_t i = 0; i < s.size(); i++) {
     stra_dim_list.push_back(i);
   }
+
   auto axis_list = GetAxisList(ops, incoming_op_index);
   for (auto axis : axis_list) {
     auto it = find(stra_dim_list.begin(), stra_dim_list.end(), axis);
@@ -355,6 +342,7 @@ std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 }
 
 std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                    const int incoming_op_index, std::vector<int32_t> s) {
+                                                    const size_t incoming_op_index, std::vector<int32_t> s) {
   std::vector<int32_t> s_Reduce;
   std::vector<int32_t> axis_list;
   for (size_t i = 0; i < s.size(); i++) {
     axis_list.push_back(i);
   }
+
   auto dim_list = GetDimList(ops, incoming_op_index);
   for (auto axis : dim_list) {
     auto it = find(axis_list.begin(), axis_list.end(), axis);
@@ -405,6 +394,7 @@ std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 
 std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                       const int incoming_op_index, const size_t iter_ops,
-                                                       const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
+                                                       const size_t iter_ops, const size_t incoming_op_index) {
   std::vector<int32_t> s;
   s = PrepareIncomingOperatorInputStrategy(ops, incoming_op_index);
+
   if (s.size() != 0) {
     if (ops[incoming_op_index]->type() == SQUEEZE) {
       s = ModifyStrategyIfSqueezeIncoming(ops, incoming_op_index, s);
@@ -429,27 +419,27 @@ std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 
 std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                                 const size_t iter_ops, std::vector<int32_t> s) {
+                                                                 const size_t iter_ops,
+                                                                 std::vector<int32_t> basic_stra) {
   std::vector<int32_t> s_empty = {};
   std::vector<std::vector<int32_t>> stra;
+  MS_EXCEPTION_IF_NULL(ops[iter_ops]);
 
-  if (s.size() == 0) {
+  if (basic_stra.size() == 0) {
     for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size();
          iter_op_inputs++) {
-      stra.push_back(s);
+      stra.push_back(basic_stra);
     }
     return stra;
   }
-  MS_EXCEPTION_IF_NULL(ops[iter_ops]);
-  if (ops[iter_ops]->type() == BIAS_ADD || ops[iter_ops]->type() == PRELU) {
-    return PrepareScalarInputOperator(ops, iter_ops, s);
+  if (ops[iter_ops]->type() == BIAS_ADD) {
+    return PrepareBiasAdd(basic_stra);
   }
   if (ops[iter_ops]->type() == ONEHOT) {
-    return PrepareOneHot(s);
+    return PrepareOneHot(basic_stra);
   }
-  auto dev_num = g_device_manager->DeviceNum();
   for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size();
        iter_op_inputs++) {
     if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() == 0) {
@@ -457,41 +447,19 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
       continue;
     }
 
-    size_t cut_num = 1;
-    for (size_t iter_s = 0; iter_s < s.size(); iter_s++) {
-      cut_num *= s[iter_s];
-    }
-    if (cut_num == dev_num) {
-      std::vector<int32_t> s_1 = s;
-      bool modified = false;
-      for (size_t j = 0; j < (size_t)ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); j++) {
-        if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape()[j] == 1) {
-          s_1[j] = 1;
-          modified = true;
-        }
+    std::vector<int32_t> tmp_stra = basic_stra;
+    bool modified = false;
+    for (size_t j = 0; j < (size_t)ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); j++) {
+      if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape()[j] == 1) {
+        tmp_stra[j] = 1;
+        modified = true;
       }
-      if (modified) {
-        stra.push_back(s_1);
-      } else {
-        stra.push_back(s);
-      }
-      continue;
     }
-
-    std::vector<int32_t> s_max = s;
-    for (size_t dim = 0; dim < (size_t)ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); dim++) {
-      size_t shape = ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape()[dim] / s[dim];
-      while (cut_num < dev_num && shape % 2 == 0) {
-        shape = shape / 2;
-        s_max[dim] = s_max[dim] * 2;
-        cut_num = cut_num * 2;
-      }
-      if (cut_num == dev_num) {
-        break;
-      }
+    if (modified) {
+      stra.push_back(tmp_stra);
+    } else {
+      stra.push_back(basic_stra);
     }
-
-    stra.push_back(s_max);
   }
   return stra;
 }
@@ -502,17 +470,17 @@ void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> graph,
                                                const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list,
                                                const std::vector<std::vector<std::string>> &input_tensor_names,
                                                const std::shared_ptr<std::vector<size_t>> index_list,
                                                const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
-  for (int eli_index = eli_list->size() - 1; eli_index >= 0; eli_index--) {
-    size_t iter_ops = eli_list->at(eli_index)[0];
+  for (size_t eli_index = eli_list->size(); eli_index > 0; eli_index--) {
+    size_t iter_ops = eli_list->at(eli_index - 1)[0];
     std::vector<std::vector<int32_t>> stra;
     std::vector<int32_t> s;
-    int incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops);
-    if (incoming_op_index != -1) {
+    size_t incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops);
+    if (incoming_op_index != SIZE_MAX && ops[iter_ops]->type() != ONEHOT) {
      auto iter_graph = index_list->at(incoming_op_index);
      if (iter_graph != SIZE_MAX) {
        s = CopyIncomingOperatorOutputStrategy(graph, ops, iter_ops, iter_graph);
      } else {
-        s = CopyIncomingOperatorInputStrategy(ops, incoming_op_index, iter_ops, no_stra_op_list);
+        s = CopyIncomingOperatorInputStrategy(ops, iter_ops, incoming_op_index);
      }
    }
@@ -534,7 +502,7 @@ std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-std::vector<int32_t> ModifyStrategyIfReduceOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                    const size_t iter_ops, std::vector<int32_t> s) {
-  std::vector<int32_t> dim_list = GetDimList(ops, iter_ops);
-  if (dim_list.size() == 0) {
-    return s;
+  size_t cut = 1;
+  for (size_t i = 0; i < s_Squeeze.size(); i++) {
+    cut *= s_Squeeze[i];
   }
-  std::vector<int32_t> s_Reduce;
-  size_t s_index = 0;
-  size_t dim_list_index = 0;
-  for (size_t i = 0; i < (size_t)(s.size() + dim_list.size()); i++) {
-    if (i == (size_t)dim_list[dim_list_index]) {
-      s_Reduce.push_back(1);
-      dim_list_index++;
-    } else {
-      s_Reduce.push_back(s[s_index]);
-      s_index++;
-    }
+  if (cut != g_device_manager->DeviceNum()) {
+    s_Squeeze.clear();
   }
-  return s_Reduce;
+
+  return s_Squeeze;
 }
 
 std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                        const std::vector<std::vector<std::string>> &input_tensor_names,
                                                        const size_t iter_ops) {
   std::vector<int32_t> s;
+  if (ops[iter_ops]->type() == REDUCE_MAX || ops[iter_ops]->type() == REDUCE_MIN ||
+      ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE) {
+    return s;
+  }
+
   bool found = false;
-  for (size_t i = 0; i < (size_t)input_tensor_names.size(); i++) {
-    for (size_t j = 1; j < (size_t)input_tensor_names[i].size(); j++) {
-      if (input_tensor_names[i][j] == input_tensor_names[iter_ops][0]) {
-        for (size_t k = 0; k < ops[i]->selected_strategy()->GetInputDim()[j - 1].size(); ++k) {
-          s.push_back(ops[i]->selected_strategy()->GetInputDim()[j - 1][k]);
-        }
+  size_t outgoing_op_index = SIZE_MAX;
+  size_t iter_op_inputs = SIZE_MAX;
+  for (size_t i = 0; i < input_tensor_names.size(); i++) {
+    for (size_t j = 1; j < input_tensor_names[i].size(); j++) {
+      if (input_tensor_names[i][j] == input_tensor_names[iter_ops][0] &&
+          ops[i]->selected_strategy()->GetInputNumber() != 0) {
+        outgoing_op_index = i;
+        iter_op_inputs = j - 1;
         found = true;
         break;
       }
     }
-    if (found) break;
+    if (found) {
+      break;
+    }
+  }
+
+  if (outgoing_op_index != SIZE_MAX && iter_op_inputs != SIZE_MAX) {
+    for (size_t k = 0; k < ops[outgoing_op_index]->selected_strategy()->GetInputDim()[iter_op_inputs].size(); ++k) {
+      s.push_back(ops[outgoing_op_index]->selected_strategy()->GetInputDim()[iter_op_inputs][k]);
+    }
   }
   return s;
 }
@@ -589,23 +560,70 @@ std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const std::vector<std::vector<std::string>> &input_tensor_names,
                                                 const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
-  MS_EXCEPTION_IF_NULL(no_stra_op_list);
-  for (int iter_list = no_stra_op_list->size() - 1; iter_list >= 0; iter_list--) {
-    auto iter_ops = no_stra_op_list->at(iter_list);
+  if (no_stra_op_list->size() == 0) {
+    return;
+  }
+  std::vector<size_t> no_stra_op_list_bis;
+
+  for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) {
+    auto iter_ops = no_stra_op_list->at(iter_list - 1);
     std::vector<std::vector<int32_t>> stra;
     std::vector<int32_t> s = CopyOutgoingOperatorInputStrategy(ops, input_tensor_names, iter_ops);
+
+    if (s.size() != 0 && ops[iter_ops]->type() == SQUEEZE) {
+      s = ModifyStrategyIfSqueezeOutgoing(ops, iter_ops, s);
+    }
+    if (s.size() != 0) {
+      stra = GenerateStrategiesFromStrategy(ops, iter_ops, s);
+    } else {
+      no_stra_op_list_bis.push_back(iter_ops);
+    }
+
+    StrategyPtr sp = std::make_shared<Strategy>(0, stra);
+    ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
+  }
+
+  no_stra_op_list->clear();
+  for (size_t i = 0; i < no_stra_op_list_bis.size(); i++) {
+    no_stra_op_list->push_back(no_stra_op_list_bis[i]);
+  }
+}
+
+void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> graph,
+                                       const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                       const std::vector<std::vector<std::string>> &input_tensor_names,
+                                       const std::shared_ptr<std::vector<size_t>> index_list,
+                                       const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
+  if (no_stra_op_list->size() == 0) {
+    return;
+  }
+
+  for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) {
+    auto iter_ops = no_stra_op_list->at(iter_list - 1);
+    std::vector<std::vector<int32_t>> stra;
+    std::vector<int32_t> s;
+    size_t incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops);
+    if (incoming_op_index != SIZE_MAX) {
+      auto iter_graph = index_list->at(incoming_op_index);
+      if (iter_graph != SIZE_MAX) {
+        s = CopyIncomingOperatorOutputStrategy(graph, ops, iter_ops, iter_graph);
+      } else {
+        s = CopyIncomingOperatorInputStrategy(ops, iter_ops, incoming_op_index);
+      }
+    }
+
     if (s.size() == 0) {
-      for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[0].shape().size(); i++) {
+      size_t max_dim_num = 0;
+      for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
+        if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() > max_dim_num) {
+          max_dim_num = ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size();
+        }
+      }
+      for (size_t i = 0; i < max_dim_num; i++) {
         s.push_back(1);
       }
     }
-    if (ops[iter_ops]->type() == SQUEEZE) {
-      s = ModifyStrategyIfSqueezeOutgoing(ops, iter_ops, s);
-    }
-    if (ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MAX ||
-        ops[iter_ops]->type() == REDUCE_MIN || ops[iter_ops]->type() == REDUCE_MEAN) {
-      s = ModifyStrategyIfReduceOutgoing(ops, iter_ops, s);
-    }
+
     stra = GenerateStrategiesFromStrategy(ops, iter_ops, s);
     StrategyPtr sp = std::make_shared<Strategy>(0, stra);
     ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
index db275dda1062ca4419b82e850424bb7e65770d35..adde07da8d8221eebb8628fd9b46fe637a7900ea 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
@@ -34,37 +34,38 @@ void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph,
                                                 const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t iter_graph, const size_t iter_ops);
-std::vector<std::vector<int32_t>> PrepareVirtualDataset(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                        const size_t iter_ops);
-std::vector<std::vector<int32_t>> PrepareScalarInputOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                             const size_t iter_ops, std::vector<int32_t> s);
+std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &graph,
+                                               const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                               const size_t iter_graph, const size_t iter_ops);
+std::vector<std::vector<int32_t>> PrepareBiasAdd(std::vector<int32_t> s);
 std::vector<std::vector<int32_t>> PrepareOneHot(std::vector<int32_t> s);
 std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
                                                         const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                         const size_t iter_graph, const size_t iter_ops);
-std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                           const size_t iter_ops);
+std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
+                                                           const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                           const size_t iter_graph, const size_t iter_ops);
 std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &graph,
                                                   const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                   const size_t iter_graph, const size_t iter_ops);
 void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> graph,
                                          const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                          const std::shared_ptr<std::vector<size_t>> index_list);
-int FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names, const size_t iter_ops);
+size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names,
+                                   const size_t iter_ops);
 std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> graph,
                                                         const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                         const size_t iter_ops, const size_t iter_graph);
 std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                          const int incoming_op_index);
+                                                          const size_t incoming_op_index);
 std::vector<int32_t> GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int iter_ops);
 std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                     const int incoming_op_index, std::vector<int32_t> s);
+                                                     const size_t incoming_op_index, std::vector<int32_t> s);
 std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
 std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                    const int incoming_op_index, std::vector<int32_t> s);
+                                                    const size_t incoming_op_index, std::vector<int32_t> s);
 std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                       const int incoming_op_index, const size_t iter_ops,
-                                                       const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
+                                                       const size_t iter_ops, const size_t incoming_op_index);
 std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                                  const size_t iter_ops, std::vector<int32_t> s);
 void GenerateEliminatedOperatorStrategyForward(std::shared_ptr<Graph> graph,
@@ -75,14 +76,17 @@ void GenerateEliminatedOperatorStrategyForward(std::shared_ptr<Graph> graph,
                                                const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
 std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                      const size_t iter_ops, std::vector<int32_t> s);
-std::vector<int32_t> ModifyStrategyIfReduceOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                    const size_t iter_ops, std::vector<int32_t> s);
 std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                        const std::vector<std::vector<std::string>> &input_tensor_names,
                                                        const size_t iter_ops);
 void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const std::vector<std::vector<std::string>> &input_tensor_names,
                                                 const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
+void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> graph,
+                                       const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                       const std::vector<std::vector<std::string>> &input_tensor_names,
+                                       const std::shared_ptr<std::vector<size_t>> index_list,
+                                       const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
index a7bc1ae86f0c1d9e546fd732e2271417c1b4a754..9fcb6e5f69a07a80b4b4d5c7c7ebb857812de719 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
@@ -47,7 +47,8 @@ enum OperatorType {
   kRecDiv,
   kRecSqueeze,
   kRecCast,
-  kRecReduce
+  kRecReduce,
+  kRecPReLU
 };
 
 enum InfoType { kApplication, kConstant };
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
index 823b1dca08771aafa051f033c52150b2c8e83a25..add0f5553e8f51b22b0dcd713ec557456e47854e 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
@@ -199,7 +199,7 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
     OperatorType::kRecOneHot, OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp,
     OperatorType::kRecAdd, OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub,
     OperatorType::kRecMul, OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce,
-    OperatorType::kRecCast};
+    OperatorType::kRecCast, OperatorType::kRecReshape};
   for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) {
     auto type = graph->nodes[node_index].apply.op_type;
     if (type_list.find(type) != type_list.end()) {
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
index e6398b955621df60120272679b57b357c011f380..34df09cb99740945babd52ff23c291e08f0dc7cc 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
@@ -55,7 +55,8 @@ const std::map<std::string, OperatorType> DictOpType{
   {"HSigmoid", OperatorType::kRecReLU},
   {GELU, OperatorType::kRecReLU},
   {TANH, OperatorType::kRecReLU},
-  {PRELU, OperatorType::kRecReLU},
+
+  {PRELU, OperatorType::kRecPReLU},
 
   {TENSOR_ADD, OperatorType::kRecElmWiseOp},
   {SUB, OperatorType::kRecElmWiseOp},
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
index ac8e52eed67c035f85b1813fd2c74a231f6d98df..c61da7f16fc3f29f58cc6ce37633f61dd5f1b6fe 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
@@ -83,7 +83,7 @@ double GetWeights(const Graph::NodeType &node) {
     auto cost_ptr = std::make_shared();
 
     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == OperatorType::kRecUnkownType) {
+  } else if (op.op_type == OperatorType::kRecUnkownType || op.op_type == OperatorType::kRecPReLU) {
     // For unknown type
     return 0.0;
   } else {
@@ -177,7 +177,7 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
     auto cost_ptr = std::make_shared();
 
     return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
-  } else if (node.apply.op_type == OperatorType::kRecUnkownType) {
+  } else if (node.apply.op_type == OperatorType::kRecUnkownType || node.apply.op_type == OperatorType::kRecPReLU) {
     // For unknown type
     StrategyRec default_strategy;
     return default_strategy;
diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc
index 3811efdd6abfcebc643ac80d108e0b9dd071bd94..d6c8a0751f00d405c65c90ade417f855d474b342 100644
--- a/mindspore/ccsrc/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc
@@ -464,6 +464,11 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
     }
     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
     if (!IsAutoParallelCareNode(cnode)) {
+      // Needed by rec_parser
+      PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+      if (prim->name() == TUPLE_GETITEM) {
+        entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), cnode->input(1)->UniqueId()));
+      }
       continue;
     }
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
@@ -522,6 +527,11 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
     }
     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
     if (!IsAutoParallelCareNode(cnode)) {
+      // Needed by rec_parser
+      PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+      if (prim->name() == TUPLE_GETITEM) {
+        entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), cnode->input(1)->UniqueId()));
+      }
       continue;
     }
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
@@ -1153,6 +1163,7 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
     MS_LOG(ERROR) << "Constructing nodes for cost graph failed.";
     return FAILED;
   }
+  auto ops = entire_costgraph->GetOperators();
   std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
   auto tuple_getitem_list = entire_costgraph->get_tuple_getitem_list();