diff --git a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/costmodel.h index 229f0fbf5e5f1f8d7c08119cce1ce07e34abd293..9e9003848b5ca234d6c0b726ca6f839926f59084 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/costmodel.h @@ -207,15 +207,13 @@ struct ContractEliminationDecision : public Decision { */ struct TriangleEliminationDecision : public Decision { TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost, - StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra, CostPtr r_node_cost) + StrategyPtr left_stra, CostPtr l_node_cost) : eliminated_op_strategy_(std::move(elimi_stra)), eliminated_op_cost_(std::move(elimi_op_cost)), left_edge_cost_(std::move(l_edge_cost)), right_edge_cost_(std::move(r_edge_cost)), left_node_strategy_(std::move(left_stra)), - left_node_cost_(std::move(l_node_cost)), - right_node_strategy_(std::move(right_stra)), - right_node_cost_(std::move(r_node_cost)) { + left_node_cost_(std::move(l_node_cost)) { type_ = DecisionType::TRIANGLE_ELIMINATION; } @@ -225,8 +223,6 @@ struct TriangleEliminationDecision : public Decision { CostPtr right_edge_cost_; StrategyPtr left_node_strategy_; CostPtr left_node_cost_; - StrategyPtr right_node_strategy_; - CostPtr right_node_cost_; MS_DECLARE_PARENT(TriangleEliminationDecision, Decision); }; diff --git a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc index 060caa4cca4fb6c0fe04f6d06c91766e721a14fa..dd21096fcc430f99c83b980b5345e6a9eae4c46d 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc @@ -76,7 +76,6 @@ Status GetStrategy(const CostGraphPtr& graph) { auto l_r_edge = triangle_pair.second; auto left_node = l_r_edge->prev_operator(); - auto right_node = l_r_edge->next_operator(); auto left_edge = eliminated_node->GetAliveSuccEdges()[0]; auto right_edge = eliminated_node->GetAliveSuccEdges()[1]; MS_EXCEPTION_IF_NULL(left_edge); @@ -86,8 +85,7 @@ Status GetStrategy(const CostGraphPtr& graph) { right_edge = tmp; } auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge); - auto elimi = - std::make_shared(eliminated_node, left_edge, left_node_cpy, right_edge, right_node); + auto elimi = std::make_shared(eliminated_node, left_edge, left_node_cpy, right_edge); eliminations.emplace_back(std::move(elimi)); } auto star_center = graph->CheckStarElimination(); @@ -183,14 +181,13 @@ Status RecoverStrategy(std::vector eliminations) { auto left_edge = elimination->left_edge_; auto eliminated_node = elimination->eliminated_node_; auto right_edge = elimination->right_edge_; - auto right_node = elimination->right_node_; auto decision = left_node->selected_cost()->decision_ptr_->cast(); eliminated_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_); left_edge->set_selected_cost(decision->left_edge_cost_); right_edge->set_selected_cost(decision->right_edge_cost_); + // Since Triangle is eliminated into 'left_node', only 'left_node' is needed to recover the strategy. 
left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_); - right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_); MS_LOG(INFO) << "Recover triangleElimination succeeded."; } else if ((*rit)->isa()) { auto elimination = (*rit)->cast(); @@ -204,9 +201,11 @@ Status RecoverStrategy(std::vector eliminations) { for (size_t i = 0; i < succ_edges.size(); ++i) { succ_edges[i]->set_selected_cost(decision->succ_edges_cost_list_[i]); } - for (size_t j = 0; j < succ_nodes.size(); ++j) { - succ_nodes[j]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[j], decision->succ_ops_cost_list_[j]); - } + MS_EXCEPTION_IF_NULL(succ_nodes[0]); + MS_EXCEPTION_IF_NULL(decision->succ_ops_stra_list_[0]); + MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]); + // Since Star is eliminated into 'succ_nodes[0]', only 'succ_nodes[0]' is needed to recover the strategy. + succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]); MS_LOG(INFO) << "Recover starElimination succeeded."; } else { MS_LOG(ERROR) << "Unknown Elimination type."; diff --git a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h index 0cb58c49da6992c8880fb2e1e23a27002e196d1a..6d43218e19fc103f710007c013d91cbd3f10a64d 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h @@ -102,20 +102,17 @@ struct ContractElimination : public Elimination { // Triangle Elimination struct TriangleElimination : public Elimination { - TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge, - OperatorInfoPtr r_node) + TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge) : Elimination(nullptr, Elimination::EliminationType::TRIANGLE), eliminated_node_(std::move(elim_node)), left_edge_(std::move(l_edge)), left_node_(std::move(l_node)), - right_edge_(std::move(r_edge)), - right_node_(std::move(r_node)) {} + right_edge_(std::move(r_edge)) {} OperatorInfoPtr eliminated_node_; EdgePtr left_edge_; OperatorInfoPtr left_node_; EdgePtr right_edge_; - OperatorInfoPtr right_node_; MS_DECLARE_PARENT(TriangleElimination, Elimination); }; diff --git a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc index cbd66f58a6d3c098a79280c3ca4805ea02655457..895646f409fa4346fd998ad76cb9187f63f9ddd5 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc @@ -119,6 +119,7 @@ Status Edge::GetRedistributionCost(const TensorLayout& prev_op_output_layout, co double forward_comm_cost = tensor_redistribution.forward_comm_cost(); double backward_comm_cost = tensor_redistribution.backward_comm_cost(); double computation_cost = tensor_redistribution.computation_cost(); + double mem_cost = tensor_redistribution.memory_cost(); // Now AllGather, ReduceScatter, AlltoAll don't support bool type MS_EXCEPTION_IF_NULL(type); @@ -134,6 +135,7 @@ Status Edge::GetRedistributionCost(const TensorLayout& prev_op_output_layout, co COST_MODEL_GAMMA * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_); (*cost)->communication_redis_forward_ = type_length * forward_comm_cost; (*cost)->communication_redis_backward_ = type_length * backward_comm_cost; + 
(*cost)->memory_with_reuse_ = mem_cost; return Status::SUCCESS; } @@ -158,8 +160,8 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList); CostPtrList selected_cost_list(all_cost_list.size(), nullptr); - std::function recursive = - [&](size_t k, double computation, double communication, double communication_without_para) { + std::function recursive = + [&](size_t k, double computation, double memory, double communication, double communication_without_para) { if (k == edges.size()) { auto decision = std::make_shared(selected_cost_list); CostPtr new_cost = std::make_shared(computation, communication); @@ -167,6 +169,7 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr new_cost->communication_without_parameter_ = communication_without_para; new_cost->communication_with_partial_para_ = communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); + new_cost->memory_with_reuse_ = memory; new_cost->decision_ptr_ = decision; result.push_back(new_cost); return; @@ -174,11 +177,12 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr for (auto& c : all_cost_list[k]) { MS_EXCEPTION_IF_NULL(c); selected_cost_list[k] = c; - recursive(k + 1, computation + c->computation_cost_, communication + c->communication_cost_, + recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_, + communication + c->communication_cost_, communication_without_para + c->communication_without_parameter_); } }; - recursive(0, 0, 0, 0); + recursive(0, 0.0, 0.0, 0.0, 0.0); SimplifyForDreasingCommunicationWithPartialPara(&result); return result; } @@ -218,6 +222,8 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr double communication_without_para = left_cost->communication_without_parameter_ + middle_cost->communication_without_parameter_ + right_cost->communication_without_parameter_; + double memory_cost = + left_cost->memory_with_reuse_ + middle_cost->memory_with_reuse_ + right_cost->memory_with_reuse_; auto decision = std::make_shared(op_strategy, left_cost, middle_cost, right_cost); auto cost = std::make_shared(computation, communication, decision); @@ -225,6 +231,7 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr cost->communication_without_parameter_ = communication_without_para; cost->communication_with_partial_para_ = communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); + cost->memory_with_reuse_ = memory_cost; ret_cost_list->emplace_back(std::move(cost)); } } @@ -267,5 +274,24 @@ void Edge::OpEliminationSetNewCost(const EdgePtr& e1, const OperatorInfoPtr& op, MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed."; } } + +Status Edge::CalculateMemoryCost() { + if (is_output_parameter_involve_ == -1) { + MS_LOG(ERROR) << "is_output_parameter_involve_ is unset."; + return FAILED; + } + if (is_output_parameter_involve_ == 0) { + // In this case, it is sure that the tensor redistribution along this edge is NOT parameter-involved, thus it is + // unnecessary to keep them in memory. 
+ for (auto& cost_kv : cost_map_) { + auto& cost_v = cost_kv.second; + if (!cost_v.empty()) { + cost_v[0]->memory_with_reuse_ = 0; + } + } + } + + return SUCCESS; +} } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h index bd882bb43fd8ba7543666da088cefda8a695724f..f9741257493b4dbfbc2548491e515d6dbe12fdde 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h @@ -133,7 +133,7 @@ class Edge { void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; } // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase. - Status CalculateMemoryCost() const { return SUCCESS; } + Status CalculateMemoryCost(); private: std::string edge_name_; diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc index 88a54662d38a43a589b90f0b791995482798bd77..82dd72303908cbdb6a686e78742d2bf5a14904d4 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc @@ -248,6 +248,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std:: MS_EXCEPTION_IF_NULL(cost2); MS_EXCEPTION_IF_NULL(cost3); double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_; + double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_; double commmunication = cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_; double communication_without_para = cost1->communication_without_parameter_ + @@ -260,6 +261,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std:: cost->communication_without_parameter_ = communication_without_para; cost->communication_with_partial_para_ = communication_without_para + COST_MODEL_GAMMA * (commmunication - communication_without_para); + cost->memory_with_reuse_ = memory; ret.push_back(cost); } } @@ -288,6 +290,7 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr& u) { new_cost->communication_with_partial_para_ = cost1->communication_without_parameter_ + COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_); + new_cost->memory_with_reuse_ = cost1->memory_with_reuse_; ret.push_back(new_cost); } } @@ -297,9 +300,14 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr& u) { } CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list, double memory) { - if (cost_list.empty() || cost_list[0]->computation_cost_ >= memory) { - return nullptr; + CostPtrList after_mem_filter; + // Filter out the valid costs + for (auto& a_cost : cost_list) { + if (a_cost->memory_with_reuse_ <= memory) { + after_mem_filter.emplace_back(std::move(a_cost)); + } } + std::function LocalCompare = [&](CostPtr init, const CostPtr& cost_x) { MS_EXCEPTION_IF_NULL(cost_x); if (init == nullptr || cost_x->computation_cost_ < memory) { @@ -308,7 +316,7 @@ CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list, return init; }; CostPtr ret = nullptr; - return std::accumulate(cost_list.begin(), cost_list.end(), ret, LocalCompare); + 
return std::accumulate(after_mem_filter.begin(), after_mem_filter.end(), ret, LocalCompare); }
 CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, double memory) { @@ -318,36 +326,46 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, d MS_LOG(ERROR) << "Final cost list is null."; return nullptr; }
- CostPtr ret = cost_list[0];
- MS_EXCEPTION_IF_NULL(ret);
- if (ret->computation_cost_ >= memory) {
- MS_LOG(ERROR) << "No available cost; the minimum cost is " << ret->computation_cost_
+ CostPtrList after_mem_filter;
+ double minimum_memory = DBL_MAX;
+ // Filter out the valid costs.
+ for (auto& a_cost : cost_list) {
+ if (a_cost->memory_with_reuse_ <= memory) {
+ after_mem_filter.emplace_back(std::move(a_cost));
+ } else if (a_cost->memory_with_reuse_ < minimum_memory) {
+ minimum_memory = a_cost->memory_with_reuse_;
+ }
+ }
+ if (after_mem_filter.empty()) {
+ MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory << ", the memory capacity is: " << memory << "."; return nullptr; }
+ // Init the returned value with the first cost.
+ CostPtr ret = after_mem_filter[0];
+ double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_with_partial_para_;
- MS_LOG(INFO) << "minimum: " << minimum << ", computation_cost_: " << ret->computation_cost_
+ MS_LOG(INFO) << "Cost 0: "
+ << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_ << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ << ", communication_cost_: " << ret->communication_cost_ << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
- for (size_t i = 1; i < cost_list.size(); ++i) {
- MS_EXCEPTION_IF_NULL(cost_list[i]);
- if (cost_list[i]->computation_cost_ >= memory) {
- MS_LOG(INFO) << "cost_list " << i << " computation_cost_: " << cost_list[i]->computation_cost_
- << ", is larger than the memory capacity: " << memory << ".";
- break;
- }
- MS_LOG(INFO) << "cost_list " << i << " computation_cost_: " << cost_list[i]->computation_cost_
- << ", communication_with_partial_para_: " << cost_list[i]->communication_with_partial_para_
- << ", communication_cost_: " << cost_list[i]->communication_cost_
- << ", communication_without_parameter_: " << cost_list[i]->communication_without_parameter_ << ".";
- auto tmp = costmodel_alpha_ * cost_list[i]->computation_cost_ +
- costmodel_beta_ * cost_list[i]->communication_with_partial_para_;
- MS_LOG(INFO) << "tmp: " << tmp;
+ MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
+ for (size_t i = 1; i < after_mem_filter.size(); ++i) {
+ MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
+ MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
+ << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
+ << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
+ << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
+ << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
+ << ".";
+ auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
+ costmodel_beta_ * after_mem_filter[i]->communication_with_partial_para_;
+ MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp; if (minimum > tmp) { minimum = tmp;
- ret = cost_list[i];
- MS_LOG(INFO) << "selected: " << i;
+ ret = after_mem_filter[i];
+ MS_LOG(INFO) << "Selected: "
<< i; } } return ret; @@ -356,17 +374,21 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, d CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector& all_cost_list, double available_memory) { CostPtrList selected_cost_list(all_cost_list.size(), nullptr); - double minimum = 0.0, total_memory = 0.0; + double minimum = DBL_MAX, total_memory = 0.0; CostPtrList ret(all_cost_list.size(), nullptr); + // Check whether valid costs exist. for (size_t i = 0; i < all_cost_list.size(); ++i) { if (all_cost_list[i][0] == nullptr) { MS_LOG(ERROR) << "The cost list " << i << " is empty."; return ret; } else { - total_memory += all_cost_list[i][0]->computation_cost_; - minimum += costmodel_alpha_ * all_cost_list[i][0]->computation_cost_ + - costmodel_beta_ * all_cost_list[i][0]->communication_with_partial_para_; - ret[i] = all_cost_list[i][0]; + double memory_i_cost = DBL_MAX; + for (size_t j = 0; j < all_cost_list[i].size(); ++j) { + if (all_cost_list[i][j]->memory_with_reuse_ < memory_i_cost) { + memory_i_cost = all_cost_list[i][j]->memory_with_reuse_; + } + } + total_memory += memory_i_cost; } } if (total_memory >= available_memory) { @@ -381,7 +403,7 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect double tmp_memory = 0.0, tmp_minimum = 0.0; for (size_t i = 0; i < selected_cost_list.size(); ++i) { MS_EXCEPTION_IF_NULL(selected_cost_list[i]); - tmp_memory += selected_cost_list[i]->computation_cost_; + tmp_memory += selected_cost_list[i]->memory_with_reuse_; tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->computation_cost_ + costmodel_beta_ * selected_cost_list[i]->communication_with_partial_para_; } @@ -816,6 +838,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const auto& tar_cost = tar_cost_list[k]; MS_EXCEPTION_IF_NULL(tar_cost); double computation = op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_; + double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_; double communication = op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_; double communication_without_para = op_cost->communication_without_parameter_ + @@ -829,6 +852,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const new_cost->communication_without_parameter_ = communication_without_para; new_cost->communication_with_partial_para_ = communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); + new_cost->memory_with_reuse_ = memory; MS_EXCEPTION_IF_NULL(tar_cost_list_new); tar_cost_list_new->emplace_back(std::move(new_cost)); } @@ -894,6 +918,8 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str MS_EXCEPTION_IF_NULL(tar_cost); double computation = contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_; + double memory = + contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_; double communication = contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_; double communication_without_para = contract_op_cost->communication_without_parameter_ + @@ -906,6 +932,7 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str new_cost->communication_without_parameter_ = communication_without_para; 
new_cost->communication_with_partial_para_ = communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); + new_cost->memory_with_reuse_ = memory; tar_cost_list_new->emplace_back(std::move(new_cost)); } } @@ -966,23 +993,22 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, for (auto& left_node_cost : left_node_clist_origin) { MS_EXCEPTION_IF_NULL(left_node_cost); double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ + - left_node_cost->computation_cost_ + right_edge_cost->computation_cost_ + - right_op_cost->computation_cost_; + left_node_cost->computation_cost_ + right_edge_cost->computation_cost_; + double new_memory = elimi_op_cost->memory_with_reuse_ + left_edge_cost->memory_with_reuse_ + + left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_; double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ + - left_node_cost->communication_cost_ + right_edge_cost->communication_cost_ + - right_op_cost->communication_cost_; + left_node_cost->communication_cost_ + right_edge_cost->communication_cost_; double new_commu_without = elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ + - left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_ + - right_op_cost->communication_without_parameter_; + left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_; - auto decision = - std::make_shared(elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost, - left_op_stra, left_node_cost, right_op_stra, right_op_cost); + auto decision = std::make_shared(elimi_op_stra, elimi_op_cost, left_edge_cost, + right_edge_cost, left_op_stra, left_node_cost); auto new_cost = std::make_shared(new_computation, new_commu_cost, decision); new_cost->communication_without_parameter_ = new_commu_without; new_cost->communication_with_partial_para_ = new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without); + new_cost->memory_with_reuse_ = new_memory; left_node_clist_new->emplace_back(std::move(new_cost)); } } @@ -1085,14 +1111,22 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_n succ_nodes_costs[0] = first_succ_node_cost; double computation_cost = merged_node_cost->computation_cost_, - commu_cost = merged_node_cost->communication_cost_, + memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_, commu_without = merged_node_cost->communication_without_parameter_; for (size_t i = 0; i < succ_nodes_stras.size(); ++i) { MS_EXCEPTION_IF_NULL(succ_edges_costs[i]); - computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_; - commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_; - commu_without += succ_edges_costs[i]->communication_without_parameter_ + - succ_nodes_costs[i]->communication_without_parameter_; + if (i == 0) { + computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_; + memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_; + commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_; + commu_without += succ_edges_costs[i]->communication_without_parameter_ + + succ_nodes_costs[i]->communication_without_parameter_; 
+ } else {
+ computation_cost += succ_edges_costs[i]->computation_cost_;
+ memory_cost += succ_edges_costs[i]->memory_with_reuse_;
+ commu_cost += succ_edges_costs[i]->communication_cost_;
+ commu_without += succ_edges_costs[i]->communication_without_parameter_;
+ } } auto decision = std::make_shared(merged_op_stra, merged_node_cost, succ_edges_costs, @@ -1100,6 +1134,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_n auto new_cost = std::make_shared(computation_cost, commu_cost, decision); new_cost->communication_without_parameter_ = commu_without; new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without);
+ new_cost->memory_with_reuse_ = memory_cost;
 first_succ_node_clist_new->emplace_back(std::move(new_cost)); } } @@ -1259,5 +1294,35 @@ OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string& p_name) c } return nullptr; }
+Status CostGraph::CorrectOpsMemoryCost() {
+ for (auto& one_op : ops_) {
+ if ((one_op->name().find(IDENTITY_INFO) != std::string::npos) && (one_op->is_output_parameter_involve() == 1)) {
+ if (one_op->GetAliveSuccEdges().size() > 1) {
+ // Handle the case where the TmpIdentity is used by multiple operators
+ std::map output_count;
+ for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
+ auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
+ output_count[output_index]++;
+ }
+ for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
+ auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
+ if (output_count[output_index] <= 1) {
+ continue;
+ }
+ auto next_op = one_op->GetAliveSuccEdges()[i]->next_operator();
+ MS_EXCEPTION_IF_NULL(next_op);
+ auto input_index = one_op->GetAliveSuccEdges()[i]->next_op_input_index();
+ if (next_op->CorrectMemoryCost(input_index) != SUCCESS) {
+ MS_LOG(ERROR) << "The operator name: " << one_op->name() << ", the next operator name: " << next_op->name()
+ << ", the output_index: " << output_index << ", the input_index: " << input_index << ".";
+ return FAILED;
+ }
+ output_count[output_index]--;
+ }
+ }
+ }
+ }
+ return SUCCESS;
+} } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h index c149534826474e811187f77519dd77c4719cc9f1..65aeb210ea075b7cd3de359dcd88c6370c4504f3 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h @@ -187,6 +187,9 @@ class CostGraph { size_t GetNumPairs() const { return edges_.size(); } Status InitSelectedStrategy(); OperatorInfoPtr FindTmpIdentityByParameterName(std::string&) const;
+ // When TmpIdentity is used by multiple operators, the corresponding parameter's memory cost should be calculated only
+ // once (instead of multiple times); this method is used to correct this. 
+ Status CorrectOpsMemoryCost();
 // Needed by rec_parser void add_inputs_tensor_name(const std::vector& inputs_tensor_name) { inputs_tensor_name_list_.push_back(inputs_tensor_name); diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc index 960e13281cf2bbf3dcc7816018950bf4310ee0c7..ecd42db6bbeaa4e074af7b9b5a727a3b745ede9a 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc @@ -17,6 +17,7 @@ #include "parallel/auto_parallel/operator_costmodel.h" #include +#include #include "parallel/device_matrix.h" #include "parallel/tensor_layout/tensor_redistribution.h" @@ -24,12 +25,44 @@ namespace mindspore { namespace parallel { void OperatorCost::set_is_parameter(const std::vector& is_parameter) { is_parameter_ = is_parameter; }
+void OperatorCost::set_is_parameter_involve(const std::vector& is_parameter_inv) {
+ is_parameter_involve_ = is_parameter_inv;
+}
+
+void OperatorCost::set_output_parameter_involve(int output_para) { output_parameter_involve_ = output_para; }
+ void OperatorCost::SetInputAndOutputTypeLength(const std::vector& input_lengths, const std::vector& output_lengths) { inputs_type_lengths_ = input_lengths; outputs_type_lengths_ = output_lengths; }
+double OperatorCost::GetMemoryCost(const std::vector& inputs,
+ const std::vector& outputs) const {
+ double result = 0.0;
+ if (output_parameter_involve_ == 1) {
+ // When this operator has multiple outputs, they all contribute to the memory.
+ for (size_t i = 0; i < outputs.size(); ++i) {
+ result += ListProduct(outputs[i].slice_shape()) * static_cast(outputs_type_lengths_[i]);
+ }
+ bool is_any_para_inv =
+ std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; });
+ if (is_any_para_inv) {
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ if (is_parameter_[i]) {
+ result += ListProduct(inputs[i].slice_shape()) * static_cast(inputs_type_lengths_[i]);
+ } else if (inputs_related_ && (!is_parameter_involve_[i])) {
+ // When the inputs of this operator are related, and they are not parameter-involved, then they are included
+ // in the memory cost.
+ result += ListProduct(inputs[i].slice_shape()) * static_cast(inputs_type_lengths_[i]);
+ }
+ }
+ }
+ }
+
+ return result;
+}
+ // return the per device communication cost in the forward phase. double MatMulCost::GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, const int32_t&) const { @@ -72,11 +105,11 @@ double MatMulCost::GetBackwardCommCost(const std::vector& inputs, co return result; }
-// Return the per device memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses double MatMulCost::GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, const int32_t&) const {
- // In forward phase, the memory cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
+ // In forward phase, the computation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
 double result = 0.0; TensorInfo output0 = outputs[0]; Shape input0_slice_shape = inputs[0].slice_shape(); @@ -91,11 +124,11 @@ double MatMulCost::GetForwardComputationCost(const std::vector& inpu return result; } -// Return the per device memory cost in the forward phase. 
The cost is calculated according to the bytes +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses double MatMulCost::GetBackwardComputationCost(const std::vector& inputs, const std::vector&, const int32_t& stage_id) const { - // In backward phase, the memory cost = (0 or 1) allreduce(slice(B)) + // In backward phase, the computation cost = (0 or 1) allreduce(slice(B)) double result = 0.0; if (is_parameter_[1]) { TensorInfo input1 = inputs[1]; // tensor B @@ -145,7 +178,7 @@ double ActivationCost::GetBackwardCommCost(const std::vector& inputs return result; } -// Return the per memory cost in the forward phase. The cost is calculated according to the bytes +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses double ActivationCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, const int32_t&) const { @@ -154,7 +187,7 @@ double ActivationCost::GetForwardComputationCost(const std::vector& return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); } -// Return the per memory cost in the forward phase. The cost is calculated according to the bytes +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses double ActivationCost::GetBackwardComputationCost(const std::vector&, const std::vector&, const int32_t&) const { @@ -189,17 +222,17 @@ double SoftmaxCost::GetBackwardCommCost(const std::vector& inputs, c return result; } -// Return the per memory cost in the forward phase. The cost is calculated according to the bytes +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses double SoftmaxCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, const int32_t&) const { - // In the forward phase, the memory cost = slice(A) + // In the forward phase, the computation cost = slice(A) TensorInfo input0 = inputs[0]; Shape input0_slice_shape = input0.slice_shape(); return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); } -// Return the per memory cost in the forward phase. The cost is calculated according to the bytes +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses double SoftmaxCost::GetBackwardComputationCost(const std::vector&, const std::vector&, @@ -221,17 +254,15 @@ double TmpIdentityCost::GetBackwardCommCost(const std::vector& inputs, +double TmpIdentityCost::GetForwardComputationCost(const std::vector&, const std::vector&, const int32_t&) const { - TensorInfo input0_info = inputs[0]; - Shape input0_slice_shape = input0_info.slice_shape(); - return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); + return 0.0; } -// Return the per memory cost in the backward phase. The cost is calculated according to the bytes +// Return the per device computation cost in the backward phase. 
The cost is calculated according to the bytes // this operator uses double TmpIdentityCost::GetBackwardComputationCost(const std::vector&, const std::vector&, @@ -239,6 +270,11 @@ double TmpIdentityCost::GetBackwardComputationCost(const std::vector&, const std::vector&) const { + return 0.0; +} + double BatchParallelCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, const int32_t&) const { @@ -284,11 +320,11 @@ double PReLUCost::GetBackwardCommCost(const std::vector& inputs, con return result; } -// Return the per memory cost in the forward phase. The cost is calculated according to the bytes +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses double PReLUCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, const int32_t&) const { - // In forward phase, the memory cost = slice(A) + slice(B) + // In forward phase, the computation cost = slice(A) + slice(B) Shape input0_slice_shape = inputs[0].slice_shape(); Shape input1_slice_shape = inputs[1].slice_shape(); double result = ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + @@ -296,12 +332,12 @@ double PReLUCost::GetForwardComputationCost(const std::vector& input return result; } -// Return the per memory cost in the backward phase. The cost is calculated according to the bytes +// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses double PReLUCost::GetBackwardComputationCost(const std::vector& inputs, const std::vector&, const int32_t& stage_id) const { - // In backward phase, the memory cost = (0 or 1) allreduce(slice(B)) + // In backward phase, the computation cost = (0 or 1) allreduce(slice(B)) double result = 0.0; if (is_parameter_[1]) { TensorInfo input1 = inputs[1]; // tensor B @@ -337,16 +373,16 @@ double OneHotCost::GetBackwardCommCost(const std::vector&, const std return 0.0; } -// Return the per memory cost in the forward phase. The cost is calculated according to the bytes +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses double OneHotCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, const int32_t&) const { - // In onehot's forward phase, the memory cost = slice(A) + // In onehot's forward phase, the computation cost = slice(A) Shape input0_slice_shape = inputs[0].slice_shape(); return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); } -// Return the per memory cost in the backward phase. The cost is calculated according to the bytes +// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses double OneHotCost::GetBackwardComputationCost(const std::vector&, const std::vector&, const int32_t&) const { @@ -367,12 +403,12 @@ double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector< return 0.0; } -// Return the per memory cost in the forward phase. The cost is calculated according to the bytes +// Return the per device computation cost in the forward phase. 
The cost is calculated according to the bytes // this operator uses double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, const int32_t&) const { - // In forward phase, the memory cost = slice(A) + slice(B) + // In forward phase, the computation cost = slice(A) + slice(B) Shape input0_slice_shape = inputs[0].slice_shape(); Shape input1_slice_shape = inputs[1].slice_shape(); double result = ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + @@ -380,7 +416,7 @@ double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::v return result; } -// Return the per memory cost in the backward phase. The cost is calculated according to the bytes +// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector&, const std::vector&, @@ -410,7 +446,7 @@ double ReshapeCost::GetBackwardCommCost(const std::vector&, const st return 0.0; } -// Return the per memory cost in the forward phase. The cost is calculated according to the bytes +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses double ReshapeCost::GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, const int32_t& stage_id) const { @@ -427,7 +463,7 @@ double ReshapeCost::GetForwardComputationCost(const std::vector& inp return (inputs_type_lengths_[0] * tensor_redistribution.computation_cost()); } -// Return the per memory cost in the backward phase. The cost is calculated according to the bytes +// Return the per device computation cost in the backward phase. 
The cost is calculated according to the bytes // this operator uses double ReshapeCost::GetBackwardComputationCost(const std::vector&, const std::vector&, diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h index 685cb259c310d5cb7157477cefc099857feeabce..7dc45bae7165f752359f1551fd052652f4966176 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h @@ -43,10 +43,20 @@ double ListProduct(std::vector vec) { // entries timing the length of each entry's data type class OperatorCost { public:
- OperatorCost() {
+ explicit OperatorCost(bool is_inputs_related) : inputs_related_(is_inputs_related) {
 // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) { is_parameter_.push_back(false);
+ is_parameter_involve_.push_back(false);
+ inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
+ outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
+ }
+ }
+ OperatorCost() : inputs_related_(false) {
+ // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
+ for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
+ is_parameter_.push_back(false);
+ is_parameter_involve_.push_back(false);
 inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH); outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH); } @@ -54,6 +64,8 @@ class OperatorCost { virtual ~OperatorCost() = default; void set_is_parameter(const std::vector& is_parameter);
+ void set_is_parameter_involve(const std::vector&);
+ void set_output_parameter_involve(int);
 void SetInputAndOutputTypeLength(const std::vector& input_lengths, const std::vector& output_lengths); std::vector inputs_type_lengths() const { return inputs_type_lengths_; } std::vector outputs_type_lengths() const { return outputs_type_lengths_; } @@ -72,8 +84,19 @@ class OperatorCost { const std::vector& outputs, const int32_t& stage_id) const = 0; virtual double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, const int32_t& stage_id) const = 0;
+ // per device PEAK memory cost in a training iteration
+ // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-involved),
+ // plus necessary inputs.
+ virtual double GetMemoryCost(const std::vector& inputs, const std::vector& outputs) const;
 protected:
+ // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or an output of a
+ // pre-operator that has parameters as input.
+ std::vector is_parameter_involve_;
+ int output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved
+ // Whether the inputs are related. For example, TensorAdd's two inputs are independent (not related), while
+ // Mul's two inputs are dependent (related). 
+ bool inputs_related_; // for each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter std::vector is_parameter_; // for each input and output, the followings record the number of bytes of each element @@ -85,7 +108,8 @@ using OperatorCostPtr = std::shared_ptr; class MatMulCost : public OperatorCost { public: - MatMulCost() = default; + explicit MatMulCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + MatMulCost() : OperatorCost(true) {} ~MatMulCost() override = default; // per device communication cost @@ -108,12 +132,12 @@ class MatMulCost : public OperatorCost { double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, const int32_t& stage_id) const override; }; - using MatMulCostPtr = std::shared_ptr; class ActivationCost : public OperatorCost { public: - ActivationCost() = default; + explicit ActivationCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + ActivationCost() : OperatorCost(false) {} ~ActivationCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, @@ -133,14 +157,14 @@ class ActivationCost : public OperatorCost { double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, const int32_t& stage_id) const override; }; - using ActivationCostPtr = std::shared_ptr; using TransposeCost = ActivationCost; using TransposeCostPtr = std::shared_ptr; class SoftmaxCost : public OperatorCost { public: - SoftmaxCost() = default; + explicit SoftmaxCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + SoftmaxCost() : OperatorCost(false) {} ~SoftmaxCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, @@ -160,12 +184,12 @@ class SoftmaxCost : public OperatorCost { double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, const int32_t&) const override; }; - using SoftmaxCostPtr = std::shared_ptr; class TmpIdentityCost : public OperatorCost { public: - TmpIdentityCost() = default; + explicit TmpIdentityCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + TmpIdentityCost() : OperatorCost(false) {} ~TmpIdentityCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, @@ -184,12 +208,15 @@ class TmpIdentityCost : public OperatorCost { const int32_t& stage_id) const override; double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, const int32_t& stage_id) const override; + // per device PEAK memory cost in a training iteration + double GetMemoryCost(const std::vector& inputs, const std::vector& outputs) const override; }; using TmpIdentityCostPtr = std::shared_ptr; class BatchParallelCost : public OperatorCost { public: - BatchParallelCost() = default; + explicit BatchParallelCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + BatchParallelCost() : OperatorCost(false) {} ~BatchParallelCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, @@ -217,7 +244,8 @@ using BatchParallelCostPtr = std::shared_ptr; class VirtualDatasetCost : public OperatorCost { public: - VirtualDatasetCost() = default; + explicit VirtualDatasetCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + VirtualDatasetCost() : OperatorCost(false) {} ~VirtualDatasetCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, @@ -244,12 +272,17 @@ 
class VirtualDatasetCost : public OperatorCost { const int32_t&) const override { return 0.0; } + // per device PEAK memory cost in a training iteration + double GetMemoryCost(const std::vector& inputs, const std::vector& outputs) const override { + return 0.0; + } }; using VirtualDatasetCostPtr = std::shared_ptr; class GeneratorBaseCost : public OperatorCost { public: - GeneratorBaseCost() = default; + explicit GeneratorBaseCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + GeneratorBaseCost() : OperatorCost(false) {} ~GeneratorBaseCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, @@ -283,7 +316,8 @@ using GeneratorBaseCostPtr = std::shared_ptr; class PReLUCost : public OperatorCost { public: - PReLUCost() = default; + explicit PReLUCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + PReLUCost() : OperatorCost(true) {} ~PReLUCost() override = default; // per device communication cost @@ -310,7 +344,8 @@ using PReLUCostPtr = std::shared_ptr; class OneHotCost : public OperatorCost { public: - OneHotCost() = default; + explicit OneHotCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + OneHotCost() : OperatorCost(true) {} ~OneHotCost() override = default; // per device communication cost @@ -337,7 +372,8 @@ using OneHotCostPtr = std::shared_ptr; class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost { public: - SoftmaxCrossEntropyWithLogitsCost() = default; + explicit SoftmaxCrossEntropyWithLogitsCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + SoftmaxCrossEntropyWithLogitsCost() : OperatorCost(false) {} ~SoftmaxCrossEntropyWithLogitsCost() override = default; // per device communication cost @@ -364,7 +400,8 @@ using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr; class ArithmeticCost : public OperatorCost { public: - ArithmeticCost() = default; + explicit ArithmeticCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + ArithmeticCost() : OperatorCost(false) {} ~ArithmeticCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, @@ -425,7 +463,8 @@ using BiasAddCostPtr = std::shared_ptr; class ReduceMethodCost : public OperatorCost { public: - ReduceMethodCost() = default; + explicit ReduceMethodCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + ReduceMethodCost() : OperatorCost(true) {} ~ReduceMethodCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, @@ -455,7 +494,8 @@ using ReduceMethodCostPtr = std::shared_ptr; class ReduceMeanCost : public ReduceMethodCost { public: - ReduceMeanCost() = default; + explicit ReduceMeanCost(bool is_inputs_related) : ReduceMethodCost(is_inputs_related) {} + ReduceMeanCost() : ReduceMethodCost(true) {} ~ReduceMeanCost() override = default; double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, @@ -465,7 +505,8 @@ using ReduceMeanCostPtr = std::shared_ptr; class GetNextCost : public OperatorCost { public: - GetNextCost() = default; + explicit GetNextCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + GetNextCost() : OperatorCost(false) {} ~GetNextCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, @@ -499,7 +540,8 @@ using GetNextCostPtr = std::shared_ptr; class DropOutCost : public OperatorCost { public: - DropOutCost() = default; + explicit DropOutCost(bool is_inputs_related) : OperatorCost(is_inputs_related) 
{} + DropOutCost() : OperatorCost(true) {} ~DropOutCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, @@ -530,7 +572,8 @@ using DropOutCostPtr = std::shared_ptr; class GatherV2Cost : public OperatorCost { public: - GatherV2Cost() = default; + explicit GatherV2Cost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + GatherV2Cost() : OperatorCost(true) {} ~GatherV2Cost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.h b/mindspore/ccsrc/parallel/ops_info/activation_info.h index b19e38b9106f478c3b6771b56b7e136dbfefa5f6..8dca036f9e8c655898b1d4a7e454e82bddf563a4 100644 --- a/mindspore/ccsrc/parallel/ops_info/activation_info.h +++ b/mindspore/ccsrc/parallel/ops_info/activation_info.h @@ -51,7 +51,7 @@ class Activation : public ActivationBase { public: Activation(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~Activation() override = default; Status GenerateStrategies(int32_t stage_id) override; Status SetCostUnderStrategy(const StrategyPtr& strategy) override; @@ -102,7 +102,7 @@ class Softmax : public ActivationBase { public: explicit Softmax(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~Softmax() override = default; Status GenerateStrategies(int32_t stage_id) override; Status SetCostUnderStrategy(const StrategyPtr& strategy) override; diff --git a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h index 435a7ce7933a4d0803ffa6311f286fe2d768add6..376a1fb4cfe93f29173db4c065cebc8a492cce60 100644 --- a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h +++ b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h @@ -32,8 +32,8 @@ namespace parallel { class ArithmeticBase : public OperatorInfo { public: ArithmeticBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) - : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + const PrimitiveAttrs& attrs, OperatorCostPtr cost) + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {} ~ArithmeticBase() override = default; Status Init(const StrategyPtr& strategy) override; Status InitForCostModel(const StrategyPtr& strategy) override; @@ -56,7 +56,7 @@ class ArithmeticBase : public OperatorInfo { class SubInfo : public ArithmeticBase { public: SubInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~SubInfo() override = default; }; @@ -64,21 +64,21 @@ class TensorAddInfo : public ArithmeticBase { public: TensorAddInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} + : ArithmeticBase(name, inputs_shape, 
outputs_shape, attrs, std::make_shared(false)) {} ~TensorAddInfo() override = default; }; class MulInfo : public ArithmeticBase { public: MulInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~MulInfo() override = default; }; class DivInfo : public ArithmeticBase { public: DivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~DivInfo() override = default; }; @@ -86,7 +86,7 @@ class RealDivInfo : public ArithmeticBase { public: RealDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~RealDivInfo() override = default; }; @@ -94,14 +94,14 @@ class FloorDivInfo : public ArithmeticBase { public: FloorDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~FloorDivInfo() override = default; }; class PowInfo : public ArithmeticBase { public: PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~PowInfo() override = default; }; @@ -109,7 +109,7 @@ class GreaterInfo : public ArithmeticBase { public: GreaterInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~GreaterInfo() override = default; }; @@ -117,7 +117,7 @@ class AssignSubInfo : public ArithmeticBase { public: AssignSubInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~AssignSubInfo() override = default; }; } // namespace parallel diff --git a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h index 093bfb8fad785d94ffe6a1be3bcb5bbfe334e438..4cedb9b7b82db3f8c1bcc5ad6f76c71bfc472123 100644 --- a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h +++ b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h @@ -29,9 +29,13 @@ namespace mindspore { namespace parallel { class BatchParallelInfo : public OperatorInfo { public: + BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, + const PrimitiveAttrs& attrs, OperatorCostPtr cost) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {} BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - 
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()), dev_num_(1) {} + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)), + dev_num_(1) {} ~BatchParallelInfo() override = default; Status Init(const StrategyPtr& strategy) override; @@ -58,7 +62,7 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo { public: SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs) {} + : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default; void ReComputeBatchSplitFlagList() override; }; diff --git a/mindspore/ccsrc/parallel/ops_info/bias_add_info.h b/mindspore/ccsrc/parallel/ops_info/bias_add_info.h index dea5c90c88aff6908aacffaddf95eacefb669b9b..e792858338b5ae3400b05d4ab2a21275b56a555f 100644 --- a/mindspore/ccsrc/parallel/ops_info/bias_add_info.h +++ b/mindspore/ccsrc/parallel/ops_info/bias_add_info.h @@ -34,7 +34,7 @@ class BiasAddInfo : public OperatorInfo { public: BiasAddInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~BiasAddInfo() override = default; Status Init(const StrategyPtr& strategy) override; diff --git a/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h b/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h index 00cc431463dd5d92347f1f562f3f05084eaae529..9ea496e0b02559a1bf75f5400b046bc8d13d42fb 100644 --- a/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h +++ b/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h @@ -18,6 +18,7 @@ #define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ #include +#include #include #include #include "ir/value.h" @@ -31,7 +32,7 @@ class EqualInfo : public ArithmeticBase { public: EqualInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~EqualInfo() override = default; }; @@ -39,7 +40,7 @@ class NotEqualInfo : public ArithmeticBase { public: NotEqualInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~NotEqualInfo() override = default; }; @@ -47,7 +48,7 @@ class MaximumInfo : public ArithmeticBase { public: MaximumInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~MaximumInfo() override = default; }; @@ -55,7 +56,7 @@ class MinimumInfo : public ArithmeticBase { public: MinimumInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} + : ArithmeticBase(name, 
inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~MinimumInfo() override = default; }; } // namespace parallel diff --git a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h index 7ebe677997deb259c3cb9313186eaa0e38634d34..3b154bd6db009993306228351d482abc26b63f8b 100644 --- a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h +++ b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h @@ -33,7 +33,7 @@ class DropoutDoMaskInfo : public OperatorInfo { public: DropoutDoMaskInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~DropoutDoMaskInfo() override = default; Status Init(const StrategyPtr& strategy) override; diff --git a/mindspore/ccsrc/parallel/ops_info/get_next_info.h b/mindspore/ccsrc/parallel/ops_info/get_next_info.h index 9a65eff035cdce20b11365bfa2cc071f29f3fcfb..ba209910b73080b24458bbc0e2f5e43b0cc98a90 100644 --- a/mindspore/ccsrc/parallel/ops_info/get_next_info.h +++ b/mindspore/ccsrc/parallel/ops_info/get_next_info.h @@ -32,7 +32,7 @@ class GetNextInfo : public OperatorInfo { public: GetNextInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~GetNextInfo() override = default; Status Init(const StrategyPtr &strategy) override; diff --git a/mindspore/ccsrc/parallel/ops_info/loss_info.h b/mindspore/ccsrc/parallel/ops_info/loss_info.h index f1c2537a39dcd464b7cb9601f8b635fdbc5a8e46..44fe22ce9064806ec48e5734983d97c82a587e28 100644 --- a/mindspore/ccsrc/parallel/ops_info/loss_info.h +++ b/mindspore/ccsrc/parallel/ops_info/loss_info.h @@ -36,7 +36,8 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo { public: SoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, + std::make_shared(false)) {} ~SoftmaxCrossEntropyWithLogitsInfo() override = default; Status Init(const StrategyPtr& strategy) override; Status InitForCostModel(const StrategyPtr& strategy) override; diff --git a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc b/mindspore/ccsrc/parallel/ops_info/matmul_info.cc index 848116d68a6106a8ee8d6e87aa2a9d3073ed1fae..e617ae6c2408ff1b8c1dcfdb252334f520431a79 100644 --- a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/matmul_info.cc @@ -593,11 +593,11 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr& // Here, we use the origin outputs_, because we only use the slice size of the output tensor. // It does not matter whether the output tensor is transposed or not. 
double computation_cost = - cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); - double communication_cost = cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); + operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); + double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); std::shared_ptr result = std::make_shared(computation_cost, communication_cost); result->communication_without_parameter_ = - cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); + operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); result->communication_with_partial_para_ = result->communication_without_parameter_ + COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); diff --git a/mindspore/ccsrc/parallel/ops_info/matmul_info.h b/mindspore/ccsrc/parallel/ops_info/matmul_info.h index 2d3312774dfb442be0aaad5d1ebb513601c60126..8a64fb7206f4e7e6506fd535f86cba7cc1c7754d 100644 --- a/mindspore/ccsrc/parallel/ops_info/matmul_info.h +++ b/mindspore/ccsrc/parallel/ops_info/matmul_info.h @@ -34,7 +34,7 @@ class MatMulBase : public OperatorInfo { public: MatMulBase(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~MatMulBase() override = default; Status Init(const StrategyPtr& strategy) override; diff --git a/mindspore/ccsrc/parallel/ops_info/onehot_info.h b/mindspore/ccsrc/parallel/ops_info/onehot_info.h index fec8d9632447752bca542f2f021b8b92f2acd537..a4f00ea0936de13586016a81240c2db258e0373d 100644 --- a/mindspore/ccsrc/parallel/ops_info/onehot_info.h +++ b/mindspore/ccsrc/parallel/ops_info/onehot_info.h @@ -33,7 +33,7 @@ class OneHotInfo : public OperatorInfo { public: OneHotInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~OneHotInfo() override = default; Status Init(const StrategyPtr& strategy) override; Status InitForCostModel(const StrategyPtr& strategy) override; diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/parallel/ops_info/operator_info.cc index 561628dbb253b2c119289796077e04cc1be7c660..23b6a5190a8bee4ca76ea0618eecbe993c8ede92 100644 --- a/mindspore/ccsrc/parallel/ops_info/operator_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/operator_info.cc @@ -1035,11 +1035,12 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) { return FAILED; } int32_t stage_id = strategy->GetInputStage(); - double computation_cost = cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); - double communication_cost = cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); + double computation_cost = + operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); + double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); std::shared_ptr result = std::make_shared(computation_cost, 
communication_cost); result->communication_without_parameter_ = - cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); + operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); result->communication_with_partial_para_ = result->communication_without_parameter_ + COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); @@ -1096,7 +1097,38 @@ Status OperatorInfo::set_is_parameter(const std::vector& is_parameter) { return FAILED; } is_parameter_ = is_parameter; - cost()->set_is_parameter(is_parameter); + operator_cost()->set_is_parameter(is_parameter); + return SUCCESS; +} + +Status OperatorInfo::CalculateMemoryCost() { + // First, set the 'is_parameter_involve_' and 'is_output_parameter_involve_' into OperatorCost, which are necessary to + // calculate memory cost. + if (is_parameter_involve_.size() != is_parameter_.size()) { + MS_LOG(ERROR) << "'is_parameter_' and 'is_parameter_involve_' do not have the same size."; + return FAILED; + } + operator_cost()->set_is_parameter_involve(is_parameter_involve_); + operator_cost()->set_output_parameter_involve(is_output_parameter_involve_); + // Set the memory cost in the 'strategy_cost_' + for (auto& swc : strategy_cost_) { + auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr); + swc->cost_list[0]->memory_with_reuse_ = mem_cost; + } + return SUCCESS; +} + +Status OperatorInfo::CorrectMemoryCost(size_t input_index) { + for (auto& swc : strategy_cost_) { + double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) * + static_cast(operator_cost()->inputs_type_lengths()[input_index]); + swc->cost_list[0]->memory_with_reuse_ -= parameter_mem_cost; + if (swc->cost_list[0]->memory_with_reuse_ < 0) { + MS_LOG(ERROR) << "The memory cost after correction is: " << swc->cost_list[0]->memory_with_reuse_ + << ", the parameter memory cost is: " << parameter_mem_cost; + return FAILED; + } + } return SUCCESS; } @@ -1193,7 +1225,7 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector& inpu } inputs_type_lengths_ = input_lengths; outputs_type_lengths_ = output_lengths; - cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths); + operator_cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths); return SUCCESS; } @@ -1221,7 +1253,7 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr& stra } double OperatorInfo::GetForwardMemoryCostFromCNode() { - return cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0); + return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0); } } // namespace parallel diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.h b/mindspore/ccsrc/parallel/ops_info/operator_info.h index 5fe89a16029f3d637c4a8c32e188d782da9d6189..19e0eeeda1e965860e868c64b4b54ebba8721106 100644 --- a/mindspore/ccsrc/parallel/ops_info/operator_info.h +++ b/mindspore/ccsrc/parallel/ops_info/operator_info.h @@ -60,7 +60,7 @@ class OperatorInfo { outputs_shape_(std::move(outputs_shape)), attrs_(std::move(attrs)), is_alive_(true), - cost_(cost), + operator_cost_(cost), outputs_type_() { std::vector not_parameteter(inputs_shape_.size(), false); is_parameter_ = not_parameteter; @@ -83,8 +83,8 @@ class OperatorInfo { // Given the stage_id (which indicates the number of devices), // generate all strategies for this operator virtual Status GenerateStrategies(int32_t stage_id) = 0; - const
OperatorCostPtr& cost() const { return cost_; } - void set_cost(const OperatorCostPtr& cost) { cost_ = cost; } + const OperatorCostPtr& operator_cost() const { return operator_cost_; } + void set_cost(const OperatorCostPtr& cost) { operator_cost_ = cost; } virtual Status SetCostUnderStrategy(const StrategyPtr& strategy) = 0; virtual std::shared_ptr>> GenerateBatchStrategies(); @@ -98,7 +98,7 @@ class OperatorInfo { std::vector> GetStrategyCost() { return strategy_cost_; } // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase. - Status CalculateMemoryCost() const { return SUCCESS; } + Status CalculateMemoryCost(); int ComputeOpAndPrevEdgeParameterInvolved(); ForwardOp forward_op() const { return forward_op_; } @@ -125,7 +125,7 @@ class OperatorInfo { void ReplaceSuccEdge(const std::shared_ptr& op, const std::shared_ptr& new_edge); void ReplacePreEdges(const std::shared_ptr& op, const std::shared_ptr& new_edge); void ReplaceSuccEdges(const std::shared_ptr& op, const std::shared_ptr& new_edge); - std::vector GetOutputTypeLengths() const { return cost()->outputs_type_lengths(); } + std::vector GetOutputTypeLengths() const { return operator_cost()->outputs_type_lengths(); } void SetSelectedStrategyAndCost(const StrategyPtr& s_strategy, const CostPtr& cost) { selected_strategy_ = s_strategy; selected_cost_ = cost; @@ -142,6 +142,10 @@ class OperatorInfo { void set_strategy(const StrategyPtr& strategy) { strategy_ = strategy; } void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); } const std::string& refkey_parameter_name() const { return refkey_parameter_name_; } + // When the output of a Parameter (require_grad) is used by multiple operators, the Parameter's cost is counted + // multiple times. This method corrects this, so that the cost is calculated only once.
+ Status CorrectMemoryCost(size_t input_index); + int is_output_parameter_involve() const { return is_output_parameter_involve_; } int used_devices() const { return used_devices_; } // needed by rec_parser void set_type(const std::string& type) { type_ = type; } @@ -234,7 +238,7 @@ class OperatorInfo { int32_t used_devices_ = -1; private: - OperatorCostPtr cost_; + OperatorCostPtr operator_cost_; std::vector outputs_type_; }; diff --git a/mindspore/ccsrc/parallel/ops_info/prelu_info.h b/mindspore/ccsrc/parallel/ops_info/prelu_info.h index bdfb11550b8678c25f82c726950ec5c9a11a6e2f..396407c1ee043c25a8a4a2ed8872d19c0fca769c 100644 --- a/mindspore/ccsrc/parallel/ops_info/prelu_info.h +++ b/mindspore/ccsrc/parallel/ops_info/prelu_info.h @@ -35,7 +35,7 @@ class PReLUInfo : public OperatorInfo { public: PReLUInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~PReLUInfo() override = default; Status Init(const StrategyPtr& strategy) override; Status InitForCostModel(const StrategyPtr& strategy) override; diff --git a/mindspore/ccsrc/parallel/ops_info/reduce_method_info.cc b/mindspore/ccsrc/parallel/ops_info/reduce_method_info.cc index aa64e72d05e1219b246d524adf7457ad120a3abc..44eab205881a65005c0d001e627b4f37ac724280 100644 --- a/mindspore/ccsrc/parallel/ops_info/reduce_method_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/reduce_method_info.cc @@ -109,7 +109,7 @@ Status ReduceMethod::GetAttrs() { } cross_batch_ = cross_batch_iter->second->cast()->value(); } - auto reducemethodcost = std::dynamic_pointer_cast(cost()); + auto reducemethodcost = std::dynamic_pointer_cast(operator_cost()); if (reducemethodcost == nullptr) { MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!"; return FAILED; diff --git a/mindspore/ccsrc/parallel/ops_info/reduce_method_info.h b/mindspore/ccsrc/parallel/ops_info/reduce_method_info.h index c2ddbc87ce32aaef4b10d32044e8ed69f9328940..2911bdfe10b6b1e83bdc64844138b4bf83eba10e 100644 --- a/mindspore/ccsrc/parallel/ops_info/reduce_method_info.h +++ b/mindspore/ccsrc/parallel/ops_info/reduce_method_info.h @@ -34,7 +34,7 @@ class ReduceMethod : public OperatorInfo { public: ReduceMethod(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~ReduceMethod() override = default; Status Init(const StrategyPtr &strategy) override; diff --git a/mindspore/ccsrc/parallel/ops_info/reshape_info.h b/mindspore/ccsrc/parallel/ops_info/reshape_info.h index 38192a5d017749adb18eb8a66bfd4e2ffaf4419f..3864d2b93d6e67f61bdbcb196c6d217e8c04b623 100644 --- a/mindspore/ccsrc/parallel/ops_info/reshape_info.h +++ b/mindspore/ccsrc/parallel/ops_info/reshape_info.h @@ -36,7 +36,7 @@ class ReshapeInfo : public OperatorInfo { public: ReshapeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()), + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)), dev_num_(0), input_layout_set_flag_(false), output_layout_set_flag_(false) {} diff --git 
a/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h b/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h index cf850683a61b8b312424ed48b637dbe9fc16c59d..3682fe334fcc3ea94a5844eb986a9834831630ae 100644 --- a/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h +++ b/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h @@ -34,7 +34,7 @@ class TmpIdentityInfo : public OperatorInfo { public: TmpIdentityInfo(const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs, const std::string& name = IDENTITY_INFO) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~TmpIdentityInfo() override = default; Status Init(const StrategyPtr& strategy) override; diff --git a/mindspore/ccsrc/parallel/ops_info/transpose_info.h b/mindspore/ccsrc/parallel/ops_info/transpose_info.h index 2714b352b6f4d99d1a1ad973f1a7487d2171729a..e4e2b90b7bcd67fa2e1de0851b13e681fc557f67 100644 --- a/mindspore/ccsrc/parallel/ops_info/transpose_info.h +++ b/mindspore/ccsrc/parallel/ops_info/transpose_info.h @@ -35,7 +35,7 @@ class TransposeInfo : public OperatorInfo { public: TransposeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~TransposeInfo() override = default; Status Init(const StrategyPtr& strategy) override; Status InitForCostModel(const StrategyPtr& strategy) override; diff --git a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h index b958adeabee0b8572911dd0bc07d0fff027ea6d3..bf17e678a3e8c1c8c12d49dfa2a397d9ad41d53f 100644 --- a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h +++ b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h @@ -32,7 +32,7 @@ class VirtualDatasetInfo : public OperatorInfo { public: VirtualDatasetInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~VirtualDatasetInfo() override = default; Status Init(const StrategyPtr& strategy) override; Status InitForCostModel(const StrategyPtr& strategy) override; diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc index ae604549722ab3f4f2adf96191cfae1a2f9254a0..a42ce612fb53da7c59e1093c37d77fe1a3eaa994 100644 --- a/mindspore/ccsrc/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc @@ -874,11 +874,15 @@ Status ParallelStrategySearch(const std::vector &all_nodes, const Fu if (entire_costgraph->ComputeOpsAndEdgesParameterInvolved() == SUCCESS) { // Calculate operators' memory usage if (entire_costgraph->CalculateOpsMemoryCost() != SUCCESS) { - MS_LOG(EXCEPTION) << "Correcting operators' cost for memory reuse failed."; + MS_LOG(EXCEPTION) << "Calculating operators' memory cost failed."; } // Calculate edges' memory usage if (entire_costgraph->CalculateEdgesMemoryCost() != SUCCESS) { - MS_LOG(EXCEPTION) << "Correcting edges' cost for memory reuse failed."; + MS_LOG(EXCEPTION) << "Calculating edges' memory cost failed."; + } + // Correct memory usage caused by TmpIdentity +
if (entire_costgraph->CorrectOpsMemoryCost() != SUCCESS) { + MS_LOG(EXCEPTION) << "Correcting operators' memory cost failed."; } } else { MS_LOG(EXCEPTION) << "Computing operators' parameter_involved failed."; diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc index 55e6a300e055d08b396e79f4740273f6f91b4656..be5eaa40bacd1e076f421c688131032f81461ada 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc @@ -159,6 +159,7 @@ Status TensorRedistribution::ComputeCost() { backward_comm_cost_ += prod; comm_cost_ += 2.0 * prod; computation_cost_ += prod; + memory_cost_ += prod; } else if (str == CONCAT_BY_AXIS) { // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape // computation cost = before_slice_shape @@ -175,20 +176,25 @@ Status TensorRedistribution::ComputeCost() { if (concat_dim == 0) { // computation cost = all_gather computation_cost_ += prod; + memory_cost_ += prod * dev_num; } else { // computation cost = all_gather + split + concat computation_cost_ += (prod + prod * dev_num + prod * dev_num); + memory_cost_ += (prod * dev_num + prod * dev_num + prod); } } else { // There is only computation cost in SplitByAxis. // computation cost = before_slice_shape computation_cost_ += prod; + // This addition may be erroneous + memory_cost_ += prod; } } if (reshape_flag()) { Shape prev_slice_shape = from_.slice_shape().array(); double prev_prod = std::accumulate(prev_slice_shape.begin(), prev_slice_shape.end(), 1, std::multiplies()); computation_cost_ += 2.0 * prev_prod; + memory_cost_ += 2.0 * prev_prod; } return Status::SUCCESS; } diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h index e933b9b8eb9208f8e01187028b0e175d672a1556..7e2b3682e60074b75f564c23bf79c0c49a2836e8 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h @@ -42,6 +42,7 @@ class TensorRedistribution { forward_comm_cost_(0.0), backward_comm_cost_(0.0), computation_cost_(0.0), + memory_cost_(0.0), construct_op_flag_(construct_op_flag), keep_reshape_(keep_reshape) {} Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list); @@ -54,6 +55,7 @@ class TensorRedistribution { double computation_cost() const { return computation_cost_; } double forward_comm_cost() const { return forward_comm_cost_; } double backward_comm_cost() const { return backward_comm_cost_; } + double memory_cost() const { return memory_cost_; } private: Status InferReshape(const TensorLayout& from_layout, const TensorLayout& to_layout, @@ -72,7 +74,12 @@ class TensorRedistribution { double forward_comm_cost_; // backward communication cost double backward_comm_cost_; + // computation_cost models the time spent on computation in this tensor redistribution, which is calculated from the + // inputs. double computation_cost_; + // memory_cost models the PEAK memory cost in a training iteration contributed by this tensor redistribution, which is + // calculated from the outputs.
+ double memory_cost_; bool construct_op_flag_; bool keep_reshape_; }; diff --git a/tests/ut/cpp/parallel/ops_info/activation_test.cc b/tests/ut/cpp/parallel/ops_info/activation_test.cc index a8f8425ae9815d34875347cac7c517b721cf02c6..9af72037991225b7c83f43632e6eebcb793ae5c3 100644 --- a/tests/ut/cpp/parallel/ops_info/activation_test.cc +++ b/tests/ut/cpp/parallel/ops_info/activation_test.cc @@ -84,9 +84,9 @@ TEST_F(TestActivation, test_activation_strategies) { act_ptr_->InitForCostModel(sp); std::vector inputs_info = act_ptr_->inputs_tensor_info(); std::vector outputs_info = act_ptr_->outputs_tensor_info(); - ASSERT_DOUBLE_EQ(act_ptr_->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()), + ASSERT_DOUBLE_EQ(act_ptr_->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()), cost.computation_cost_); - ASSERT_DOUBLE_EQ(act_ptr_->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()), + ASSERT_DOUBLE_EQ(act_ptr_->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()), cost.communication_cost_); } } @@ -109,9 +109,9 @@ TEST_F(TestActivation, test_softmax_strategies) { soft_ptr_->InitForCostModel(sp); std::vector inputs_info = soft_ptr_->inputs_tensor_info(); std::vector outputs_info = soft_ptr_->outputs_tensor_info(); - ASSERT_DOUBLE_EQ(soft_ptr_->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()), + ASSERT_DOUBLE_EQ(soft_ptr_->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()), cost.computation_cost_); - ASSERT_DOUBLE_EQ(soft_ptr_->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()), + ASSERT_DOUBLE_EQ(soft_ptr_->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()), cost.communication_cost_); } } diff --git a/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc b/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc index 2fece098e8118691089c9508a5b7e3361f858896..f710f51265f36fec92e6a590e76fb0f88561ca38 100644 --- a/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc @@ -569,7 +569,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies1) { matmul1->InitForCostModel(sp); std::vector inputs_info = matmul1->inputs_tensor_info(); std::vector outputs_info = matmul1->outputs_tensor_info(); - ASSERT_DOUBLE_EQ(matmul1->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()), + ASSERT_DOUBLE_EQ(matmul1->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()), cost.computation_cost_); break; } @@ -599,7 +599,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies2) { TensorInfo replica_input1_info(tly, input1_shape, input1_slice_shape); replica_inputs_info.push_back(replica_input1_info); - ASSERT_DOUBLE_EQ(matmul3->cost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()), + ASSERT_DOUBLE_EQ(matmul3->operator_cost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()), cost.computation_cost_); break; } diff --git a/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc b/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc index 8c956328a776ca2137051bcaf00d8e79ec0c5006..42d292c605db82f77ad1bd6a35de92361a7bd27a 100644 --- a/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc @@ -188,11 +188,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies) { tensor_add->InitForCostModel(sp); std::vector inputs_info = 
tensor_add->inputs_tensor_info(); std::vector outputs_info = tensor_add->outputs_tensor_info(); - double memory_cost0 = tensor_add->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()); + double memory_cost0 = tensor_add->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()); double memory_cost1 = cost.computation_cost_; bool memory = memory_cost0 - memory_cost1 <= 1.0; - double comm_cost0 = tensor_add->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()); + double comm_cost0 = tensor_add->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()); double comm_cost1 = cost.communication_cost_; bool comm = comm_cost0 - comm_cost1 <= 1.0; @@ -210,11 +210,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies1) { tensor_add1->InitForCostModel(sp); std::vector inputs_info = tensor_add1->inputs_tensor_info(); std::vector outputs_info = tensor_add1->outputs_tensor_info(); - double memory_cost0 = tensor_add1->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()); + double memory_cost0 = tensor_add1->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()); double memory_cost1 = cost.computation_cost_; bool memory = memory_cost0 - memory_cost1 <= 1.0; - double comm_cost0 = tensor_add1->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()); + double comm_cost0 = tensor_add1->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()); double comm_cost1 = cost.communication_cost_; bool comm = comm_cost0 - comm_cost1 <= 1.0; diff --git a/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc b/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc index 3971a2b47131db22062a1a49be11c41ed38ee9e4..eabac51e1737bde818ccfe3fe438de741c51b661 100644 --- a/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc +++ b/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc @@ -145,9 +145,9 @@ TEST_F(TestTmpIdentityInfo, test_generate_strategies) { identity_ptr->Init(sp); std::vector inputs_info = identity_ptr->inputs_tensor_info(); std::vector outputs_info = identity_ptr->outputs_tensor_info(); - ASSERT_DOUBLE_EQ(identity_ptr->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()), + ASSERT_DOUBLE_EQ(identity_ptr->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()), cost.computation_cost_); - ASSERT_DOUBLE_EQ(identity_ptr->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()), + ASSERT_DOUBLE_EQ(identity_ptr->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()), cost.communication_cost_); } }
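
Note: to make the intent of the OperatorInfo::CalculateMemoryCost / CorrectMemoryCost additions in the operator_info.cc hunk above easier to follow, here is a minimal standalone sketch of the correction idea. It is illustrative only, not the MindSpore implementation: the main() driver, the consumer count, the example slice shape, and the constant standing in for other activation memory are all assumptions.

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

using Shape = std::vector<int64_t>;

// Element count of a slice, analogous to ListProduct(...) in the patch.
double ListProduct(const Shape& slice_shape) {
  return std::accumulate(slice_shape.begin(), slice_shape.end(), 1.0, std::multiplies<double>());
}

int main() {
  const Shape parameter_slice = {128, 256};  // per-device slice of a shared Parameter (assumed example)
  const size_t type_length = 4;              // bytes per element, e.g. float32
  const int consumers = 3;                   // operators that all read the same Parameter (assumed)

  const double parameter_mem_cost = ListProduct(parameter_slice) * static_cast<double>(type_length);

  // CalculateMemoryCost-style total: each consumer's strategy initially counts the Parameter slice.
  double memory_with_reuse = consumers * parameter_mem_cost + 1.0e6;  // 1.0e6 stands in for other activations

  // CorrectMemoryCost-style correction: subtract the duplicated copies so the slice is counted once.
  for (int i = 1; i < consumers; ++i) {
    memory_with_reuse -= parameter_mem_cost;
    if (memory_with_reuse < 0) {
      std::cerr << "memory cost became negative after correction\n";
      return 1;
    }
  }
  std::cout << "corrected peak memory estimate: " << memory_with_reuse << " bytes\n";
  return 0;
}

How often the graph-level driver (CorrectOpsMemoryCost in the step_auto_parallel.cc hunk) subtracts one copy per extra consumer is an assumption of this sketch; the patch itself only shows the per-operator subtraction.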
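Note: similarly, a small standalone sketch of how the new memory_cost_ term in TensorRedistribution::ComputeCost accumulates per redistribution step. The weights mirror the '+' lines of that hunk, while the RedistOp enum, the MemoryCostOf helper, and the example values are illustrative assumptions rather than MindSpore APIs.

#include <iostream>

enum class RedistOp { AllToAll, AllGatherConcatDim0, AllGatherConcatOther, SplitByAxis };

// 'prod' is the number of elements in the slice being redistributed,
// 'dev_num' the number of devices participating in the collective.
double MemoryCostOf(RedistOp op, double prod, double dev_num) {
  switch (op) {
    case RedistOp::AllToAll:
      return prod;                                    // memory_cost_ += prod
    case RedistOp::AllGatherConcatDim0:
      return prod * dev_num;                          // the gathered tensor stays resident
    case RedistOp::AllGatherConcatOther:
      return prod * dev_num + prod * dev_num + prod;  // all_gather + split + concat buffers
    case RedistOp::SplitByAxis:
      return prod;                                    // flagged as possibly erroneous in the patch
  }
  return 0.0;
}

int main() {
  double memory_cost = 0.0;
  const double prod = 1024.0 * 1024.0;  // example slice size (assumed)
  const double dev_num = 8.0;           // example device count (assumed)
  memory_cost += MemoryCostOf(RedistOp::AllGatherConcatOther, prod, dev_num);
  memory_cost += MemoryCostOf(RedistOp::AllToAll, prod, dev_num);
  std::cout << "accumulated redistribution memory cost: " << memory_cost << "\n";
  return 0;
}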