From 8f04adf1c321367333a2baca8cc229dd9ac0758f Mon Sep 17 00:00:00 2001
From: hongxing
Date: Thu, 14 May 2020 23:28:12 +0200
Subject: [PATCH] feature : eliminate graph

---
 .../rec_core/rec_generate_strategy.cc         | 303 ++++++++++++------
 .../rec_core/rec_generate_strategy.h          |  45 ++-
 .../auto_parallel/rec_core/rec_graph.h        |   3 +-
 .../auto_parallel/rec_core/rec_parse_graph.cc |  67 ++++
 .../auto_parallel/rec_core/rec_parse_graph.h  |  15 +-
 .../ccsrc/parallel/step_auto_parallel.cc      |   9 +-
 6 files changed, 331 insertions(+), 111 deletions(-)

diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
index c32b7d5b1..124b64fb9 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
@@ -28,52 +28,56 @@
 
 namespace mindspore {
 namespace parallel {
-void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops) {
+void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                      const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list,
+                      const std::vector<std::vector<std::string>> &input_tensor_names,
+                      const std::shared_ptr<std::vector<size_t>> index_list) {
   MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(eli_list);
+  MS_EXCEPTION_IF_NULL(index_list);
+  GeneratePartitionedOperatorStrategy(graph, ops, index_list);
+  std::shared_ptr<std::vector<size_t>> no_stra_op_list(new std::vector<size_t>);
+  GenerateEliminatedOperatorStrategyForward(graph, ops, eli_list, input_tensor_names, index_list, no_stra_op_list);
+  GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list);
+}
 
-  for (size_t iter_ops = 0; iter_ops < ops.size(); iter_ops++) {
-    std::vector<std::vector<int32_t>> stra;
-    for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
-      stra.push_back(PrepareStrategy(graph, ops, iter_ops, iter_op_inputs));
-    }
-    // OneHot's scalar parameters were removed by entire_costgraph, we had to complete them.
-    if (ops[iter_ops]->type() == ONEHOT) {
-      std::vector<int32_t> s_Onehot = {};
-      stra.push_back(s_Onehot);
-      stra.push_back(s_Onehot);
+std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph,
+                                                const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                const size_t iter_graph, const size_t iter_ops) {
+  std::vector<std::vector<int32_t>> strategies;
+  for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
+    std::vector<int32_t> s;
+    auto attrs = ops[iter_ops]->attrs();
+    bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
+    bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
+    if (transpose_a && (iter_op_inputs == 0)) {
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
+    } else if (transpose_b && (iter_op_inputs == 1)) {
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
+    } else {
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
     }
-    StrategyPtr sp = std::make_shared<Strategy>(0, stra);
-    ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
+    strategies.push_back(s);
   }
+  return strategies;
 }
 
-std::vector<int32_t> PrepareMatMul(const std::shared_ptr<Graph> &graph,
-                                   const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_nodes,
-                                   const size_t iter_op_inputs) {
-  std::vector<int32_t> s;
-  auto attrs = ops[iter_nodes]->attrs();
-  bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
-  bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
-  if (transpose_a && (iter_op_inputs == 0)) {
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w));
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h));
-  } else if (transpose_b && (iter_op_inputs == 1)) {
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w));
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h));
-  } else {
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h));
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w));
-  }
-  return s;
+std::vector<std::vector<int32_t>> PrepareVirtualDataset(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                        const size_t iter_ops) {
+  std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(ops, iter_ops);
+  strategies[1][0] = strategies[0][0];
+  return strategies;
 }
 
-// std::vector<std::vector<int32_t>> PrepareVirtualDataset(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-//                                                         const size_t iter_ops) {
-//   std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(ops, iter_ops);
-//   strategies[1][0] = strategies[0][0];
-//   return strategies;
-// }
-
 std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                  const size_t iter_ops, std::vector<int32_t> s) {
   std::vector<std::vector<int32_t>> strategies;
@@ -99,9 +103,9 @@ std::vector<std::vector<int32_t>> PrepareOneHot(std::vector<int32_t> s) {
   return strategies;
 }
 
-std::vector<int32_t> MakeRecSearchStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                           const std::shared_ptr<Graph> &graph, const size_t iter_ops,
-                                           const size_t iter_op_inputs) {
+std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
+                                                        const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                        const size_t iter_graph, const size_t iter_ops) {
   if (ops.empty()) {
     MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
   }
@@ -111,35 +115,46 @@ std::vector<int32_t> MakeRecSearchStrategy(const std::vector<std::shared_p
   }
 
   StrategyPtr origin_strategy = ops[iter_ops]->strategy();
-  if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
-    MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
-  }
-
-  // size_t output_size = ops[iter_ops]->outputs_tensor_info()[0].shape().size();
-  size_t output_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
-
-  std::vector<int32_t> s = {};
-  if (output_size == 4) {
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_n));
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_c));
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_h));
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_w));
-  } else if (output_size == 2) {
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_h));
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_w));
-  } else if (output_size == 1) {
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_w));
-  } else if (output_size == 0) {
-    return s;
-  } else {
-    MS_LOG(ERROR) << "Tensor's output size is unexcepted.";
-  }
+  std::vector<std::vector<int32_t>> strategies;
+  for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
+    if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
+      MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
+    }
 
-  return s;
+    // size_t output_size = ops[iter_ops]->outputs_tensor_info()[0].shape().size();
+    size_t output_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
+
+    std::vector<int32_t> s;
+    if (output_size == 4) {
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_n));
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_c));
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
+    } else if (output_size == 2) {
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
+    } else if (output_size == 1) {
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
+    } else if (output_size == 0) {
+      s = {};
+    } else {
+      MS_LOG(ERROR) << "Tensor's output size is unexcepted.";
+    }
+
+    strategies.push_back(s);
+  }
+  return strategies;
 }
 
-std::vector<int32_t> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                              const size_t iter_ops, const size_t iter_op_inputs) {
+std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                           const size_t iter_ops) {
   if (ops.empty()) {
     MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
   }
@@ -149,28 +164,32 @@ std::vector<int32_t> MakeDataParallelStrategy(const std::vector<std::shared_p
   }
 
   StrategyPtr origin_strategy = ops[iter_ops]->strategy();
-  if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
-    MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
-  }
+  std::vector<std::vector<int32_t>> strategies;
+  for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
+    if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
+      MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
+    }
 
-  std::vector<int32_t> s;
-  size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
-  for (size_t dim = 0; dim < input_size; dim++) {
-    if (dim == 0 && input_size == 4) {
-      size_t max_device_num = g_device_manager->DeviceNum();
-      size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0];
-      s.push_back(std::min(max_device_num, target_tensor_batch));
-    } else {
-      s.push_back(1);
+    std::vector<int32_t> s;
+    size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
+    for (size_t dim = 0; dim < input_size; dim++) {
+      if (dim == 0 && input_size == 4) {
+        size_t max_device_num = g_device_manager->DeviceNum();
+        size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0];
+        s.push_back(std::min(max_device_num, target_tensor_batch));
+      } else {
+        s.push_back(1);
+      }
     }
-  }
-  return s;
+    strategies.push_back(s);
+  }
+  return strategies;
 }
 
-std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> &graph,
-                                     const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
-                                     const size_t iter_op_inputs) {
+std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &graph,
+                                                  const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                  const size_t iter_graph, const size_t iter_ops) {
   if (ops.empty()) {
     MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
   }
@@ -179,19 +198,35 @@ std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> &graph,
   }
 
   auto type = ops[iter_ops]->type();
+  if (type == VIRTUAL_DATA_SET) {
+    return PrepareVirtualDataset(ops, iter_ops);
+  }
   auto idx = DictOpType.find(type);
   if (idx == DictOpType.end()) {
-    return MakeDataParallelStrategy(ops, iter_ops, iter_op_inputs);
+    return MakeDataParallelStrategy(ops, iter_ops);
   }
 
   if (type == MATMUL) {
-    return PrepareMatMul(graph, ops, iter_ops, iter_op_inputs);
+    return PrepareMatMul(graph, ops, iter_graph, iter_ops);
   } else if (type == RESHAPE) {
-    return MakeDataParallelStrategy(ops, iter_ops, iter_op_inputs);
-  } else if (type == DIV || type == SUB || type == MUL) {
-    return MakeDataParallelStrategy(ops, iter_ops, iter_op_inputs);
+    return MakeDataParallelStrategy(ops, iter_ops);
  } else {
-    return MakeRecSearchStrategy(ops, graph, iter_ops, iter_op_inputs);
+    return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
+  }
+}
+
+void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> graph,
+                                         const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                         const std::shared_ptr<std::vector<size_t>> index_list) {
+  for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) {
+    std::vector<std::vector<int32_t>> strategies;
+    size_t iter_graph = index_list->at(iter_ops);
+    if (iter_graph == SIZE_MAX) {
+      continue;
+    }
+    strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops);
+    StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
+    ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
   }
 }
@@ -353,6 +388,25 @@ std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_p
   return s_Reduce;
 }
 
+std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                       const int incoming_op_index, const size_t iter_ops,
+                                                       const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
+  std::vector<int32_t> s;
+  s = PrepareIncomingOperatorInputStrategy(ops, incoming_op_index);
+  if (s.size() != 0) {
+    if (ops[incoming_op_index]->type() == SQUEEZE) {
+      s = ModifyStrategyIfSqueezeIncoming(ops, incoming_op_index, s);
+    }
+    if (ops[incoming_op_index]->type() == REDUCE_SUM || ops[incoming_op_index]->type() == REDUCE_MAX ||
+        ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) {
+      s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s);
+    }
+  } else {
+    no_stra_op_list->push_back(iter_ops);
+  }
+  return s;
+}
+
 std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                                  const size_t iter_ops, std::vector<int32_t> s) {
   std::vector<int32_t> s_empty = {};
@@ -389,6 +443,33 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
   return stra;
 }
 
+void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> graph,
+                                               const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                               const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list,
+                                               const std::vector<std::vector<std::string>> &input_tensor_names,
+                                               const std::shared_ptr<std::vector<size_t>> index_list,
+                                               const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
+  for (int eli_index = eli_list->size() - 1; eli_index >= 0; eli_index--) {
+    size_t iter_ops = eli_list->at(eli_index)[0];
+    std::vector<std::vector<int32_t>> stra;
+    std::vector<int32_t> s;
+    int incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops);
+    if (incoming_op_index != -1) {
+      auto iter_graph = index_list->at(incoming_op_index);
+      if (iter_graph != SIZE_MAX) {
+        s = CopyIncomingOperatorOutputStrategy(graph, ops, iter_ops, iter_graph);
+      } else {
+        s = CopyIncomingOperatorInputStrategy(ops, incoming_op_index, iter_ops, no_stra_op_list);
+      }
+    } else {
+      no_stra_op_list->push_back(iter_ops);
+    }
+    stra = GenerateStrategiesFromStrategy(ops, iter_ops, s);
+    StrategyPtr sp = std::make_shared<Strategy>(0, stra);
+    ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
+  }
+}
+
 std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                      const size_t iter_ops, std::vector<int32_t> s) {
   std::vector<int32_t> s_Squeeze;
@@ -427,5 +508,47 @@ std::vector<int32_t> ModifyStrategyIfReduceOutgoing(const std::vector<std::shared_p
   return s_Reduce;
 }
 
+std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                       const std::vector<std::vector<std::string>> &input_tensor_names,
+                                                       const size_t iter_ops) {
+  std::vector<int32_t> s;
+  bool found = false;
+  for (size_t i = 0; i < (size_t)input_tensor_names.size(); i++) {
+    for (size_t j = 1; j < (size_t)input_tensor_names[i].size(); j++) {
+      if (input_tensor_names[i][j] == input_tensor_names[iter_ops][0]) {
+        for (size_t k = 0; k < ops[i]->selected_strategy()->GetInputDim()[j - 1].size(); ++k) {
+          s.push_back(ops[i]->selected_strategy()->GetInputDim()[j - 1][k]);
+        }
+        found = true;
+        break;
+      }
+    }
+    if (found) break;
+  }
+  return s;
+}
+
+void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                const std::vector<std::vector<std::string>> &input_tensor_names,
+                                                const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
+  MS_EXCEPTION_IF_NULL(no_stra_op_list);
+  for (int iter_list = no_stra_op_list->size() - 1; iter_list >= 0; iter_list--) {
+    auto iter_ops = no_stra_op_list->at(iter_list);
+    std::vector<std::vector<int32_t>> stra;
+    std::vector<int32_t> s = CopyOutgoingOperatorInputStrategy(ops, input_tensor_names, iter_ops);
+    if (ops[iter_ops]->type() == SQUEEZE) {
+      s = ModifyStrategyIfSqueezeOutgoing(ops, iter_ops, s);
+    }
+    if (ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MAX ||
+        ops[iter_ops]->type() == REDUCE_MIN || ops[iter_ops]->type() == REDUCE_MEAN) {
+      s = ModifyStrategyIfReduceOutgoing(ops, iter_ops, s);
+    }
+    stra = GenerateStrategiesFromStrategy(ops, iter_ops, s);
+    StrategyPtr sp = std::make_shared<Strategy>(0, stra);
+    ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
+  }
+}
+
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
index 56985700b..b6a96d5f5 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
@@ -27,23 +27,29 @@
 
 namespace mindspore {
 namespace parallel {
-void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops);
-std::vector<int32_t> PrepareMatMul(const std::shared_ptr<Graph> &graph,
-                                   const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_nodes,
-                                   const size_t iter_op_inputs);
+void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                      const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list,
+                      const std::vector<std::vector<std::string>> &input_tensor_names,
+                      const std::shared_ptr<std::vector<size_t>> index_list);
+std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph,
+                                                const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                const size_t iter_graph, const size_t iter_ops);
 std::vector<std::vector<int32_t>> PrepareVirtualDataset(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                         const size_t iter_ops);
 std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                  const size_t iter_ops, std::vector<int32_t> s);
 std::vector<std::vector<int32_t>> PrepareOneHot(std::vector<int32_t> s);
-std::vector<int32_t> MakeRecSearchStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                           const std::shared_ptr<Graph> &graph, const size_t iter_ops,
-                                           const size_t iter_op_inputs);
-std::vector<int32_t> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                              const size_t iter_ops, const size_t iter_op_inputs);
-std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> &graph,
-                                     const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
-                                     const size_t iter_op_inputs);
+std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
+                                                        const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                        const size_t iter_graph, const size_t iter_ops);
+std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                           const size_t iter_ops);
+std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &graph,
+                                                  const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                  const size_t iter_graph, const size_t iter_ops);
+void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> graph,
+                                         const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                         const std::shared_ptr<std::vector<size_t>> index_list);
 int FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names, const size_t iter_ops);
 std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> graph,
                                                         const std::vector<std::shared_ptr<OperatorInfo>> &ops,
@@ -56,12 +62,27 @@ std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shar
 std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
 std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                     const int incoming_op_index, std::vector<int32_t> s);
+std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                       const int incoming_op_index, const size_t iter_ops,
+                                                       const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
 std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                                  const size_t iter_ops, std::vector<int32_t> s);
+void GenerateEliminatedOperatorStrategyForward(std::shared_ptr<Graph> graph,
+                                               const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                               const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list,
+                                               const std::vector<std::vector<std::string>> &input_tensor_names,
+                                               const std::shared_ptr<std::vector<size_t>> index_list,
+                                               const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
 std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                      const size_t iter_ops, std::vector<int32_t> s);
 std::vector<int32_t> ModifyStrategyIfReduceOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                     const size_t iter_ops, std::vector<int32_t> s);
+std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                       const std::vector<std::vector<std::string>> &input_tensor_names,
+                                                       const size_t iter_ops);
+void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                const std::vector<std::vector<std::string>> &input_tensor_names,
+                                                const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
index bc4f39436..a7bc1ae86 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
@@ -46,7 +46,8 @@ enum OperatorType {
   kRecMul,
   kRecDiv,
   kRecSqueeze,
-  kRecCast
+  kRecCast,
+  kRecReduce
 };
 
 enum InfoType { kApplication, kConstant };
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
index ada22fef9..99df43ca2 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
@@ -20,6 +20,7 @@
 #include
 #include
 #include
+#include
 
 #include "ir/value.h"
 #include "parallel/auto_parallel/rec_core/rec_graph.h"
@@ -161,5 +162,71 @@ size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &i
   MS_LOG(INFO) << "Get index failed, using SIZE_MAX insted";
   return SIZE_MAX;
 }
+
+void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> graph,
+                   const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list) {
+  std::vector<size_t> eli;
+  eli.push_back(node_index);
+  for (size_t i = 0; i < (size_t)graph->nodes[node_index].node_out.size(); i++) {
+    eli.push_back(graph->nodes[node_index].node_out[i]);
+  }
+  eli_list->push_back(eli);
+  for (auto input_index : graph->nodes[node_index].node_in) {
+    auto it = find(graph->nodes[input_index].node_out.begin(), graph->nodes[input_index].node_out.end(), node_index);
+    if (it != graph->nodes[input_index].node_out.end()) {
+      graph->nodes[input_index].node_out.erase(it);
+      for (auto output_index : graph->nodes[node_index].node_out) {
+        graph->nodes[input_index].node_out.push_back(output_index);
+      }
+    }
+  }
+  for (auto output_index : graph->nodes[node_index].node_out) {
+    auto it = find(graph->nodes[output_index].node_in.begin(), graph->nodes[output_index].node_in.end(), node_index);
+    if (it != graph->nodes[output_index].node_in.end()) {
+      graph->nodes[output_index].node_in.erase(it);
+      for (auto input_index : graph->nodes[node_index].node_in) {
+        graph->nodes[output_index].node_in.push_back(input_index);
+      }
+    }
+  }
+}
+
+std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
+                                      const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list,
+                                      const std::shared_ptr<std::vector<size_t>> index_list) {
+  MS_EXCEPTION_IF_NULL(graph);
+  const std::set<OperatorType> type_list = {
+    OperatorType::kRecOneHot, OperatorType::kRecReLU,      OperatorType::kRecLog,     OperatorType::kRecExp,
+    OperatorType::kRecAdd,    OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub,
+    OperatorType::kRecMul,    OperatorType::kRecDiv,       OperatorType::kRecSqueeze, OperatorType::kRecReduce,
+    OperatorType::kRecCast};
+  for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) {
+    auto type = graph->nodes[node_index].apply.op_type;
+    if (type_list.find(type) != type_list.end()) {
+      Eliminate_Aux(node_index, graph, eli_list);
+    }
+  }
+  index_list->reserve(graph->nodes.size());
+  for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) {
+    index_list->push_back(i);
+  }
+  for (size_t i = 0; i < (size_t)eli_list->size(); i++) {
+    if (eli_list->at(i)[0] >= index_list->size()) {
+      MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
+    }
+    index_list->at(eli_list->at(i)[0]) = SIZE_MAX;
+    for (size_t j = eli_list->at(i)[0] + 1; j < (size_t)index_list->size(); j++) {
+      index_list->at(j)--;
+    }
+  }
+  std::shared_ptr<Graph> new_graph(new Graph);
+  for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) {
+    if (index_list->at(i) > SIZE_MAX / 2) {
+      continue;
+    }
+    new_graph->nodes.push_back(graph->nodes[i]);
+  }
+  return new_graph;
+}
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
index d0feccb9b..6af1deea9 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
@@ -50,10 +50,10 @@ const std::map<std::string, OperatorType> DictOpType{
   {DIV, OperatorType::kRecElmWiseOp},
   {SQUEEZE, OperatorType::kRecSqueeze},
   {CAST, OperatorType::kRecCast},
-  {REDUCE_SUM, OperatorType::kRecCast},
-  {REDUCE_MAX, OperatorType::kRecCast},
-  {REDUCE_MIN, OperatorType::kRecCast},
-  {REDUCE_MEAN, OperatorType::kRecCast}};
+  {REDUCE_SUM, OperatorType::kRecReduce},
+  {REDUCE_MAX, OperatorType::kRecReduce},
+  {REDUCE_MIN, OperatorType::kRecReduce},
+  {REDUCE_MEAN, OperatorType::kRecReduce}};
 
 const TensorParam MakeTensor(int n, int c, int h, int w);
 
@@ -72,6 +72,13 @@ void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, s
 
 size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_names,
                                   const std::string &input_name);
+
+void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> graph,
+                   const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list);
+
+std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
+                                      const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list,
+                                      const std::shared_ptr<std::vector<size_t>> index_list);
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_
diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc
index 7a3d5a38f..f7a18c8b5 100644
--- a/mindspore/ccsrc/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc
@@ -1158,11 +1158,12 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
   for (auto it = tuple_getitem_list.begin(); it != tuple_getitem_list.end();) {
     input_tensor_names = RecInputTensorNames(it++, input_tensor_names);
   }
 
-
-  std::shared_ptr<std::vector<size_t>> ops_nodes_list(new std::vector<size_t>);
-
   std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names);
+  std::shared_ptr<std::vector<std::vector<size_t>>> eli_list(new std::vector<std::vector<size_t>>);
+  std::shared_ptr<std::vector<size_t>> index_list(new std::vector<size_t>);
+  graph = EliminateGraph(graph, eli_list, index_list);
+
   size_t num_device = g_device_manager->DeviceNum();
   double device_memory = entire_costgraph->GetDeviceMemory();
   if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) {
@@ -1172,7 +1173,7 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
     return FAILED;
   }
 
-  GenerateStrategy(graph, ops);
+  GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list);
 
   if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
     MS_LOG(INFO) << "Init selected strategy succeeded.";
-- 
GitLab
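
Supplementary illustration (not part of the commit above): the index bookkeeping in EliminateGraph is the least obvious part of the patch. Every eliminated operator keeps a slot in index_list that is set to SIZE_MAX, while the slots of all later operators are decremented, so each surviving operator maps onto its position in the compacted graph. The standalone sketch below reproduces only that remapping under simplified assumptions: plain index vectors stand in for the Graph and OperatorInfo structures, and the set of eliminated nodes is hard-coded, so every name here is illustrative rather than MindSpore API.

// Standalone sketch of the index_list remapping used by EliminateGraph.
// Builds with any C++11 compiler; no MindSpore dependencies.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Maps each original node index to its index in the compacted graph.
// Eliminated nodes are marked with SIZE_MAX and every later slot is
// decremented, mirroring the loop in EliminateGraph.
std::vector<size_t> BuildIndexList(size_t node_count, const std::vector<size_t> &eliminated) {
  std::vector<size_t> index_list;
  index_list.reserve(node_count);
  for (size_t i = 0; i < node_count; i++) {
    index_list.push_back(i);
  }
  for (size_t e : eliminated) {
    index_list[e] = SIZE_MAX;
    // Unconditional decrement, exactly like the patch: a slot that was already
    // SIZE_MAX may drift slightly below it, which is why the survivor test in
    // EliminateGraph is "index_list->at(i) > SIZE_MAX / 2" rather than "== SIZE_MAX".
    for (size_t j = e + 1; j < node_count; j++) {
      index_list[j]--;
    }
  }
  return index_list;
}

int main() {
  // Five original operators; operators 1 and 3 are eliminated.
  const auto index_list = BuildIndexList(5, {1, 3});
  for (size_t i = 0; i < index_list.size(); i++) {
    if (index_list[i] > SIZE_MAX / 2) {
      std::cout << "node " << i << " -> eliminated\n";
    } else {
      std::cout << "node " << i << " -> " << index_list[i] << "\n";
    }
  }
  // Prints: 0 -> 0, 1 -> eliminated, 2 -> 1, 3 -> eliminated, 4 -> 2.
  return 0;
}

Under this scheme, GeneratePartitionedOperatorStrategy in the patch can skip eliminated operators by checking for SIZE_MAX, while surviving operators use the remapped value as iter_graph into the compacted graph returned by EliminateGraph.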