diff --git a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/costmodel.h
index 8f0a79d0be2503f3eda4d4cab3d0621dcb112770..8b92e18cd853e78d389a67014a8acba21837f7a5 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/costmodel.h
@@ -211,13 +211,14 @@ struct ContractEliminationDecision : public Decision {
  */
 struct TriangleEliminationDecision : public Decision {
   TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost,
-                              StrategyPtr left_stra, CostPtr l_node_cost)
+                              StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra)
       : eliminated_op_strategy_(std::move(elimi_stra)),
         eliminated_op_cost_(std::move(elimi_op_cost)),
         left_edge_cost_(std::move(l_edge_cost)),
         right_edge_cost_(std::move(r_edge_cost)),
         left_node_strategy_(std::move(left_stra)),
-        left_node_cost_(std::move(l_node_cost)) {
+        left_node_cost_(std::move(l_node_cost)),
+        right_node_strategy_(std::move(right_stra)) {  // The right node's strategy is now kept with the decision.
     type_ = DecisionType::TRIANGLE_ELIMINATION;
   }
 
@@ -227,6 +228,7 @@ struct TriangleEliminationDecision : public Decision {
   CostPtr right_edge_cost_;
   StrategyPtr left_node_strategy_;
   CostPtr left_node_cost_;
+  StrategyPtr right_node_strategy_;  // Initialized from the constructor's right_stra argument.
   MS_DECLARE_PARENT(TriangleEliminationDecision, Decision);
 };
 
diff --git a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc
index 8d439f152283056994da219ccb71f2159d1212ff..72451fab57cafe4e5b4811f19c762a40b1660698 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc
@@ -85,7 +85,9 @@ Status GetStrategy(const CostGraphPtr &graph) {
         right_edge = tmp;
       }
       auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge);
-      auto elimi = std::make_shared(eliminated_node, left_edge, left_node_cpy, right_edge);
+      auto right_node =
l_r_edge->next_operator();
+      auto elimi =
+        std::make_shared(eliminated_node, left_edge, left_node_cpy, right_edge, right_node);
       eliminations.emplace_back(std::move(elimi));
     }
     auto star_center = graph->CheckStarElimination();
@@ -181,6 +183,7 @@ Status RecoverStrategy(std::vector eliminations) {
       auto left_edge = elimination->left_edge_;
       auto eliminated_node = elimination->eliminated_node_;
       auto right_edge = elimination->right_edge_;
+      auto right_node = elimination->right_node_;  // Only used for the log-only check further down.
       auto decision = left_node->selected_cost()->decision_ptr_->cast();
 
       eliminated_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_);
@@ -188,6 +191,7 @@ Status RecoverStrategy(std::vector eliminations) {
       right_edge->set_selected_cost(decision->right_edge_cost_);
       // Since Triangle is eliminated into 'left_node', only 'left_node' is needed to recover the strategy.
       left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_);
+      right_node->CheckSelectedStrategy(decision->right_node_strategy_);  // Logs at INFO level on a mismatch.
       MS_LOG(INFO) << "Recover triangleElimination succeeded.";
     } else if ((*rit)->isa()) {
       auto elimination = (*rit)->cast();
@@ -206,6 +210,9 @@ Status RecoverStrategy(std::vector eliminations) {
       MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]);
       // Since Star is eliminated into 'succ_nodes[0]', only 'succ_nodes[0]' is needed to recover the strategy.
succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]);
+      for (size_t k = 1; k < succ_nodes.size(); ++k) {  // Starts at 1: succ_nodes[0] is handled just above.
+        // NOTE(review): assumes succ_ops_stra_list_ has an entry for every element of succ_nodes - confirm.
+        succ_nodes[k]->CheckSelectedStrategy(decision->succ_ops_stra_list_[k]);
+      }
       MS_LOG(INFO) << "Recover starElimination succeeded.";
     } else {
       MS_LOG(ERROR) << "Unknown Elimination type.";
diff --git a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h
index efedba7d105b914b6efd83171c7506dcd31fe3a5..e3fbfba5a77f260c2ce2b99165785e525ed7c561 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h
@@ -102,17 +102,20 @@ struct ContractElimination : public Elimination {
 
 // Triangle Elimination
 struct TriangleElimination : public Elimination {
-  TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge)
+  TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge,
+                      OperatorInfoPtr r_node)
       : Elimination(nullptr, Elimination::EliminationType::TRIANGLE),
         eliminated_node_(std::move(elim_node)),
         left_edge_(std::move(l_edge)),
         left_node_(std::move(l_node)),
-        right_edge_(std::move(r_edge)) {}
+        right_edge_(std::move(r_edge)),
+        right_node_(std::move(r_node)) {}
 
   OperatorInfoPtr eliminated_node_;
   EdgePtr left_edge_;
   OperatorInfoPtr left_node_;
   EdgePtr right_edge_;
+  OperatorInfoPtr right_node_;  // Initialized from the constructor's r_node argument.
   MS_DECLARE_PARENT(TriangleElimination, Elimination);
 };
 
diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc
index 5419ab69c99e993e02cf342ca523f1eda138553e..043d3a9e23bd2580de86a43c43f50d781e86b5d1 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc
@@ -1111,8 +1111,8 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr
elimi_op_stra,
         elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
         left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
-      auto decision = std::make_shared(elimi_op_stra, elimi_op_cost, left_edge_cost,
-                                       right_edge_cost, left_op_stra, left_node_cost);
+      auto decision = std::make_shared(  // right_op_stra is now passed to the decision as well.
+        elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost, left_op_stra, left_node_cost, right_op_stra);
       auto new_cost = std::make_shared(new_computation, new_commu_cost, decision);
       new_cost->communication_without_parameter_ = new_commu_without;
       new_cost->communication_with_partial_para_ =
diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/parallel/ops_info/operator_info.cc
index 4a680a7c362321a4874f9511a8fea6465dbeb5b1..c0cc187e02954aaf489fe75381d600eccba6650a 100644
--- a/mindspore/ccsrc/parallel/ops_info/operator_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/operator_info.cc
@@ -546,10 +546,14 @@ std::vector> OperatorInfo::GetAliveSuccEdges() {
   for (auto &edge : succ_edges_) {
     if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) != std::string::npos)) {
       ret.push_back(edge);
+    } else if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(CAST) != std::string::npos)) {
+      // CAST is ordered in front of L2NORMALIZE
+      ret.push_back(edge);
     }
   }
   for (auto &edge : succ_edges_) {
-    if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos)) {
+    if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos) &&
+        (edge->next_operator()->name().find(CAST) == std::string::npos)) {  // CAST edges were appended by the first loop.
       ret.push_back(edge);
     }
   }
@@ -1279,10 +1283,18 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra
     CheckGlobalDeviceManager();
     auto total_device_num =
g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size();
     if (IntToSize(stra->GetInputDim()[0][0]) == total_device_num) {
-      cost->computation_cost_ -= 1.0;
-      cost->communication_cost_ -= 1.0;
-      cost->communication_with_partial_para_ -= 1.0;
-      cost->communication_without_parameter_ -= 1.0;
+      if (cost->computation_cost_ > 1.0) {  // Each cost is reduced by 1.0 only when it is above 1.0, so it stays positive.
+        cost->computation_cost_ -= 1.0;
+      }
+      if (cost->communication_cost_ > 1.0) {
+        cost->communication_cost_ -= 1.0;
+      }
+      if (cost->communication_with_partial_para_ > 1.0) {
+        cost->communication_with_partial_para_ -= 1.0;
+      }
+      if (cost->communication_without_parameter_ > 1.0) {
+        cost->communication_without_parameter_ -= 1.0;
+      }
     }
   }
 }
@@ -1290,5 +1302,17 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra
 double OperatorInfo::GetForwardMemoryCostFromCNode() {
   return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
 }
+
+// Diagnostic only: compares 's_strategy' with selected_strategy_ and, when they differ, logs both at INFO level.
+// This function assigns no members; a null 's_strategy' is rejected by MS_EXCEPTION_IF_NULL.
+void OperatorInfo::CheckSelectedStrategy(const StrategyPtr &s_strategy) {
+  MS_EXCEPTION_IF_NULL(s_strategy);
+  if (!s_strategy->IsEqual(selected_strategy_)) {
+    MS_LOG(INFO) << name() << "'s strategy may cause suboptimal, the determined strategy:";
+    PrintStrategy(selected_strategy_);
+    MS_LOG(INFO) << "The minimal strategy:";
+    PrintStrategy(s_strategy);
+  }
+}
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.h b/mindspore/ccsrc/parallel/ops_info/operator_info.h
index 30d618bd104fa1cbfc05579f9a7e8c1cb1d8f420..6888a88f7202f5a8ba7f70a3f2abe29e4131c9df 100644
--- a/mindspore/ccsrc/parallel/ops_info/operator_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/operator_info.h
@@ -138,6 +138,8 @@ class OperatorInfo {
   }
   StrategyPtr selected_strategy() const { return selected_strategy_; }
   CostPtr selected_cost() const { return selected_cost_; }
+  // Logs at INFO level when 's_strategy' differs from the selected strategy; defined in operator_info.cc.
+  void CheckSelectedStrategy(const StrategyPtr &s_strategy);
   Status InitSelectedStrategy(const StrategyPtr &s_strategy) { return Init(s_strategy); }
   void
set_input_value(const std::vector &input_value) { input_value_ = input_value; }
   const std::vector &input_value() const { return input_value_; }
diff --git a/mindspore/ccsrc/parallel/strategy.h b/mindspore/ccsrc/parallel/strategy.h
index fce99305a5e05a12879843782ae9d0b8df5120c7..bc62dd5308747f4824adbc78904f63021142a366 100644
--- a/mindspore/ccsrc/parallel/strategy.h
+++ b/mindspore/ccsrc/parallel/strategy.h
@@ -48,6 +48,14 @@ class Strategy {
   }
   void ResetInputs(const std::vector &input) { inputs_ = input; }
 
+  // Returns true iff 'another_stra' is non-null and has the same stage and the same input dimensions as this one.
+  bool IsEqual(const StrategyPtr &another_stra) const {
+    if (another_stra == nullptr) {
+      return false;
+    }
+    return (stage_ == another_stra->GetInputStage()) && (inputs_ == another_stra->GetInputDim());
+  }
+
  private:
   const int32_t stage_;
 
diff --git a/tests/ut/cpp/parallel/strategy_test.cc b/tests/ut/cpp/parallel/strategy_test.cc
index 21988e3517b18c09acd0712fbe97000ba4c662d0..9a2f92f018c8db83206e3f457bbf0cee73d9d9f4 100644
--- a/tests/ut/cpp/parallel/strategy_test.cc
+++ b/tests/ut/cpp/parallel/strategy_test.cc
@@ -64,5 +64,26 @@ TEST_F(TestStrategy, GetInputDim) {
   ASSERT_EQ(inputs, inputs_test);
 }
 
+TEST_F(TestStrategy, IsEqual) {
+  // stra2 matches stra1; stra3 differs in stage and dimensions; stra4 has one extra input.
+  int32_t stage1 = 0, stage2 = 0, stage3 = 1, stage4 = 0;
+  std::vector dimension1 = {8, 1};
+  std::vector dimension2 = {1, 8};
+  std::vector> inputs1 = {dimension1};
+  std::vector> inputs2 = {dimension1};
+  std::vector> inputs3 = {dimension2};
+  std::vector> inputs4 = {dimension1, dimension2};
+
+  StrategyPtr stra1 = std::make_shared(stage1, inputs1);
+  StrategyPtr stra2 = std::make_shared(stage2, inputs2);
+  StrategyPtr stra3 = std::make_shared(stage3, inputs3);
+  StrategyPtr stra4 = std::make_shared(stage4, inputs4);
+
+  ASSERT_TRUE(stra1->IsEqual(stra2));
+  ASSERT_TRUE(stra1->IsEqual(stra1));
+  ASSERT_FALSE(stra1->IsEqual(stra3));
+  ASSERT_FALSE(stra1->IsEqual(stra4));
+  ASSERT_FALSE(stra1->IsEqual(nullptr));
+}
 }  // namespace parallel
 }  // namespace mindspore