diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc
index adbc2005b8ab84bb6802e77eebad900080303fbc..9fb79ceee42762a69d033e4b1df5803db7a2b3eb 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc
@@ -703,5 +703,48 @@ StrategyRec CostBatchParallel::ChoseStr(const std::vector<double> &cost_op, Stra
   }
   return str;
 }
+
+// Chose strategy for CostSoftmaxCrossEntropyWithLogits
+StrategyRec CostSoftmaxCrossEntropyWithLogits::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
+  uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
+  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
+    return str;
+  }
+
+  switch (min_position) {
+    case 0:
+      str.inputTensor[0].str_n /= 2.0;
+      str.inputTensor[1].str_n /= 2.0;
+      str.cut_counter += 1;
+      str.cost = str.cost + cost_in_;
+      break;
+
+    case 1:
+      str.inputTensor[0].str_c /= 2.0;
+      str.inputTensor[1].str_c /= 2.0;
+      str.cut_counter += 1;
+      str.cost = str.cost + cost_in_;
+      break;
+
+    case 2:
+      str.inputTensor[0].str_h /= 2.0;
+      str.inputTensor[1].str_h /= 2.0;
+      str.outputTensor.str_w /= 2.0;
+      str.cut_counter += 1;
+      str.cost = str.cost + cost_in_;
+      break;
+
+    case 3:
+      str.inputTensor[0].str_w /= 2.0;
+      str.inputTensor[1].str_w /= 2.0;
+      str.cut_counter += 1;
+      str.cost = str.cost + cost_in_;
+      break;
+
+    default:
+      MS_LOG(EXCEPTION) << "Failure: CostSoftmax failed.";
+  }
+  return str;
+}
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h
index 55ed2bdf17a1cedb67283be089a709483f02c275..fb4fc27164cfb55f8cc7e8a5146529029763e3a3 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h
@@ -222,6 +222,12 @@ class CostBatchParallel {
 class
CostBatchNorm : public CostBatchParallel {};
 class CostOneHot : public CostBatchParallel {};
+class CostPRelu : public CostBatchParallel {};
+class CostSoftmax : public CostBatchParallel {};
+
+class CostSoftmaxCrossEntropyWithLogits : public CostBatchParallel {
+  StrategyRec ChoseStr(const std::vector<double> &cost_op, StrategyRec str);
+};
 
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // PARALLEL_AUTO_PARALLEL_REC_COST_H_
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
index ced24882365f6f9cdbb4332ac832413c5d22fc6c..19e07aae0252a75714b9664209e3dd6eebcf5ae7 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
@@ -127,14 +127,6 @@ std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &gr
   return strategies;
 }
 
-std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &graph,
-                                               const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                               const size_t iter_graph, const size_t iter_ops) {
-  std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
-  strategies[1][0] = 1;
-  return strategies;
-}
-
 std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vector<int32_t>> &s) {
   std::vector<std::vector<int32_t>> strategies;
   strategies.push_back(*s);
@@ -164,6 +156,32 @@ std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s) {
   return strategies;
 }
 
+std::vector<std::vector<int32_t>> PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                     const size_t iter_ops, std::vector<int32_t> s) {
+  int32_t axis = 0;
+  auto iter = ops[iter_ops]->attrs().find(AXIS);
+  if (iter != ops[iter_ops]->attrs().end()) {
+    MS_EXCEPTION_IF_NULL(iter->second);
+    if (iter->second->isa<Int32Imm>()) {
+      axis = iter->second->cast<Int32ImmPtr>()->value();
+    } else {
+      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " : The value of axis is not int.";
+    }
+  }
+
+  int32_t axis_index = axis;
+  if (axis < 0) {
+    size_t input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
+    axis_index = static_cast<int32_t>(input_dim) + axis;
+  }
+
+
s[IntToSize(axis_index)] = 1;
+
+  std::vector<std::vector<int32_t>> strategies;
+  strategies.push_back(s);
+  return strategies;
+}
+
 std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
                                                         const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                         const size_t iter_graph, const size_t iter_ops) {
@@ -279,13 +297,8 @@ std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &
   if (type == MATMUL) {
     return PrepareMatMul(graph, ops, iter_graph, iter_ops);
-  } else if (type == PRELU) {
-    return PreparePReLU(graph, ops, iter_graph, iter_ops);
   } else if (type == ONEHOT) {
     return PrepareOneHot(graph, ops, iter_graph, iter_ops);
-  } else if (type == SOFTMAX || type == LOG_SOFTMAX || type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS ||
-             type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
-    return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
   } else {
     return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
   }
@@ -510,6 +523,9 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
   if (ops[iter_ops]->type() == GATHERV2) {
     return PrepareGatherV2(s_ptr);
   }
+  if (ops[iter_ops]->type() == L2_NORMALIZE) {
+    return PrepareL2Normalize(ops, iter_ops, basic_stra);
+  }
 
   for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size();
        iter_op_inputs++) {
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
index 57e3ee7f71f8c3a534242f4e292e1a263a9a26fc..c9604b449f6dcb3eff64ec44d7229dd87072daaa 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
@@ -34,14 +34,13 @@ void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::share
 std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph,
                                                 const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t iter_graph, const size_t iter_ops);
-std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &graph,
-                                               const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                               const size_t
iter_ops);
 std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vector<int32_t>> &s);
 std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<Graph> &graph,
                                                 const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t iter_graph, const size_t iter_ops);
 std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s);
+std::vector<std::vector<int32_t>> PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                     const size_t iter_ops, std::vector<int32_t> s);
 std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
                                                         const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                         const size_t iter_graph, const size_t iter_ops);
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
index d578bd82ef1705d8a8118a6ec0fbfeaa3aefc488..647b857e1618c538758fb075f2f1fc3cdeed9414 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
@@ -38,6 +38,7 @@ enum OperatorType {
   kRecBiasAdd,
   kRecSoftmax,
   kRecSparseSoftmaxCrossEntropyWithLogits,
+  kRecSoftmaxCrossEntropyWithLogits,
   kRecOneHot,
   kRecLog,
   kRecExp,
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
index 07a2eb7c496b7a9038f09f0350fdb59a82d09b34..3e4eafe0a4c6f04704674da93309e6fd7ada21e9 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
@@ -250,12 +250,22 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
 
     new_graph->nodes.push_back(graph->nodes[i]);
     auto *node_in = &new_graph->nodes[index_list->at(i)].node_in;
-    for (size_t j = 0; j < node_in->size(); j++) {
-      node_in->at(j) = index_list->at(node_in->at(j));
+    for (size_t j = node_in->size(); j > 0; j--) {
+      bool IsEliminated = (index_list->at(node_in->at(j - 1)) == SIZE_MAX);
+      if (IsEliminated) {
+        node_in->erase(node_in->begin() + j - 1);
+      } else {
+        node_in->at(j - 1) = index_list->at(node_in->at(j - 1));
+      }
     }
     auto *node_out =
&new_graph->nodes[index_list->at(i)].node_out;
-    for (size_t j = 0; j < node_out->size(); j++) {
-      node_out->at(j) = index_list->at(node_out->at(j));
+    for (size_t j = node_out->size(); j > 0; j--) {
+      bool IsEliminated = (index_list->at(node_out->at(j - 1)) == SIZE_MAX);
+      if (IsEliminated) {
+        node_out->erase(node_out->begin() + j - 1);
+      } else {
+        node_out->at(j - 1) = index_list->at(node_out->at(j - 1));
+      }
     }
   }
   return new_graph;
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
index f39546dffc371e8eb765775568c5e28fce95a4dd..1be8e4c79631a9cb9c5fec3548b8504c6fbeb1a7 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
@@ -67,7 +67,7 @@ const std::map<std::string, OperatorType> DictOpType{
   {REAL_DIV, OperatorType::kRecElmWiseOp},
   {SOFTMAX, OperatorType::kRecSoftmax},
   {LOG_SOFTMAX, OperatorType::kRecSoftmax},
-  {SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSoftmax},
+  {SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSoftmaxCrossEntropyWithLogits},
   {SQRT, OperatorType::kRecElmWiseOp},
   {NEG, OperatorType::kRecElmWiseOp},
   {POW, OperatorType::kRecElmWiseOp},
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
index b5b823dc27bb53c2a52763e706feaed4c6b23f16..0f6e736d5288b2c7511c306eb1f845de269feb94 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
@@ -76,15 +76,16 @@ double GetWeights(const Graph::NodeType &node) {
 
     auto cost_ptr = std::make_shared<CostCommon>();
     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecOneHot) {
+  } else if (op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecOneHot ||
+             op.op_type ==
OperatorType::kRecPReLU || op.op_type == OperatorType::kRecSoftmax ||
+             op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits ||
+             op.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) {
     // For BatchParallel op
     auto cost_ptr = std::make_shared<CostBatchParallel>();
     return cost_ptr->GetMaxCostIn();
-  } else if (op.op_type == OperatorType::kRecUnkownType || op.op_type == OperatorType::kRecPReLU ||
-             op.op_type == OperatorType::kRecSoftmax ||
-             op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
-    // For unprocessed type
+  } else if (op.op_type == OperatorType::kRecUnkownType) {
+    // For Unkown type
     return 0.0;
   } else {
     MS_LOG(EXCEPTION) << "Failure: GetOperatorWeight failed.";
   }
@@ -170,14 +171,18 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
 
     auto cost_ptr = std::make_shared<CostCommon>();
     return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
-  } else if (node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecOneHot) {
+  } else if (node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecOneHot ||
+             node.apply.op_type == OperatorType::kRecPReLU || node.apply.op_type == kRecSoftmax ||
+             node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
     // For BatchParallel type
     auto cost_ptr = std::make_shared<CostBatchParallel>();
     return cost_ptr->GetOptimalStr(node);
-  } else if (node.apply.op_type == OperatorType::kRecUnkownType || node.apply.op_type == OperatorType::kRecPReLU ||
-             node.apply.op_type == OperatorType::kRecSoftmax ||
-             node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
-    // For unprocessed type
+  } else if (node.apply.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) {
+    // For SoftmaxCrossEntropyWithLogits type
+    auto cost_ptr = std::make_shared<CostSoftmaxCrossEntropyWithLogits>();
+    return cost_ptr->GetOptimalStr(node);
+  } else if (node.apply.op_type == OperatorType::kRecUnkownType) {
+    // For Unkown type
     StrategyRec default_strategy;
     return
default_strategy;
   } else {