diff --git a/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc index f5cf5069bebc7bbb0af176735c8c204d64a9e174..190f589bb5d726da005074b52c8b0f23169af1e5 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc @@ -23,8 +23,8 @@ namespace mindspore { namespace parallel { void Simplify(CostPtrList* clist_ptrs) { - // Sort the cost_list with the memory_cost increasing, and communication_cost decreasing order. This method - // excludes the cost with greater memory_cost and greater communication_cost. + // Sort the cost_list with the computation_cost_ increasing, and communication_cost decreasing order. This method + // excludes the cost with greater computation_cost_ and greater communication_cost. // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>} if (!COST_MODEL_SIMPLIFY_CALCULATION) { return; @@ -33,7 +33,7 @@ void Simplify(CostPtrList* clist_ptrs) { std::vector id(clist_ptrs->size()); std::iota(id.begin(), id.end(), size_t(0)); std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) { - return clist_ptrs->at(x)->memory_cost_ < clist_ptrs->at(y)->memory_cost_; + return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_; }); CostPtrList ret; for (size_t i = 0; i < clist_ptrs->size(); ++i) { @@ -45,8 +45,8 @@ void Simplify(CostPtrList* clist_ptrs) { } void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList* clist_ptrs) { - // Sort the cost_list with the memory_cost increasing, and communication_with_partial_para_cost decreasing order. - // This method excludes the cost with greater memory_cost and greater communication_without_para_cost. + // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing + // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost. 
if (!COST_MODEL_SIMPLIFY_CALCULATION) { return; } @@ -54,7 +54,7 @@ void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList* clist_ptrs) { std::vector id(clist_ptrs->size()); std::iota(id.begin(), id.end(), size_t(0)); std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) { - return clist_ptrs->at(x)->memory_cost_ < clist_ptrs->at(y)->memory_cost_; + return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_; }); CostPtrList ret; for (size_t i = 0; i < clist_ptrs->size(); ++i) { diff --git a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/costmodel.h index 361c19573ff6e6160ae8cda86b7d4c643b27c2fa..229f0fbf5e5f1f8d7c08119cce1ce07e34abd293 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/costmodel.h @@ -44,14 +44,18 @@ using RedistributionOpListPtr = std::shared_ptr& decision_ = nullptr) - : memory_cost_(memory), communication_cost_(commuication), decision_ptr_(std::move(decision_)) { + Cost(double computation, double commuication, const std::shared_ptr& decision_ = nullptr) + : computation_cost_(computation), communication_cost_(commuication), decision_ptr_(std::move(decision_)) { + memory_with_reuse_ = 0.0; communication_without_parameter_ = 0.0; communication_with_partial_para_ = 0.0; communication_redis_forward_ = 0.0; communication_redis_backward_ = 0.0; } - double memory_cost_; + // 'memory_with_reuse_' calculates the peak memory usage in a training phase + double memory_with_reuse_; + // 'computation_cost_' models the training time of an iteration in a training phase + double computation_cost_; // 'communication_cost_' includes communications from operators (forward and backward) and edges double communication_cost_; // communication_without_parameter_ = communication_cost_ - (backward communication from operators) diff --git a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h index c9b6a0731776bfc19016938757e824d2e4169edf..0cb58c49da6992c8880fb2e1e23a27002e196d1a 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h @@ -35,7 +35,7 @@ namespace parallel { // interpretation of 6 operations in costmodel.h. // Phase 2: Search the cost_list in the final graph, and determine the optimal one // Create the cost_list for the final graph, and choose the optimal one: one the minimum quantity -// COST_MODEL_ALPHA * memory_cost + COST_MODEL_BETA * communication_cost +// COST_MODEL_ALPHA * computation_cost + COST_MODEL_BETA * communication_cost // Phase 3: Recover the original CostGraph, the determine strategy for each operator // After determining the optimal cost for the final graph, the algorithm recovers the original graph by applying // the 4 operations in the reverse order in the Phase 1. 
Because each operation decision contains the strategy, diff --git a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc index 6381049f170318ef114043fb4f03a51cc759387c..653f6c903dc49740f80ed72d1bf723cf4983df4b 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc @@ -69,7 +69,7 @@ Status Edge::InitEdgeCost() { MS_LOG(EXCEPTION) << "Failure: redistribution cost calculation failed"; } MS_EXCEPTION_IF_NULL(cost); - MS_LOG(DEBUG) << "The redistribution cost: memory_cost: " << cost->memory_cost_ + MS_LOG(DEBUG) << "The redistribution cost: computation_cost: " << cost->computation_cost_ << ", communication_cost: " << cost->communication_cost_ << ", communication_without_parameter_: " << cost->communication_without_parameter_ << ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << "."; @@ -117,9 +117,9 @@ Status Edge::GetRedistributionCost(const TensorLayout& prev_op_output_layout, co double comm_cost = tensor_redistribution.comm_cost(); double forward_comm_cost = tensor_redistribution.forward_comm_cost(); double backward_comm_cost = tensor_redistribution.backward_comm_cost(); - double mem_cost = tensor_redistribution.mem_cost(); + double computation_cost = tensor_redistribution.computation_cost(); - *cost = std::make_shared(type_length * mem_cost, type_length * comm_cost); + *cost = std::make_shared(type_length * computation_cost, type_length * comm_cost); (*cost)->communication_without_parameter_ = type_length * comm_cost; (*cost)->communication_with_partial_para_ = (*cost)->communication_without_parameter_ + @@ -150,26 +150,26 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList); CostPtrList selected_cost_list(all_cost_list.size(), nullptr); - std::function recursive = [&](size_t k, double memory, double communication, - double communication_without_para) { - if (k == edges.size()) { - auto decision = std::make_shared(selected_cost_list); - CostPtr new_cost = std::make_shared(memory, communication); - MS_EXCEPTION_IF_NULL(new_cost); - new_cost->communication_without_parameter_ = communication_without_para; - new_cost->communication_with_partial_para_ = - communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); - new_cost->decision_ptr_ = decision; - result.push_back(new_cost); - return; - } - for (auto& c : all_cost_list[k]) { - MS_EXCEPTION_IF_NULL(c); - selected_cost_list[k] = c; - recursive(k + 1, memory + c->memory_cost_, communication + c->communication_cost_, - communication_without_para + c->communication_without_parameter_); - } - }; + std::function recursive = + [&](size_t k, double computation, double communication, double communication_without_para) { + if (k == edges.size()) { + auto decision = std::make_shared(selected_cost_list); + CostPtr new_cost = std::make_shared(computation, communication); + MS_EXCEPTION_IF_NULL(new_cost); + new_cost->communication_without_parameter_ = communication_without_para; + new_cost->communication_with_partial_para_ = + communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); + new_cost->decision_ptr_ = decision; + result.push_back(new_cost); + return; + } + for (auto& c : all_cost_list[k]) { + MS_EXCEPTION_IF_NULL(c); + selected_cost_list[k] = c; + recursive(k + 1, computation + 
c->computation_cost_, communication + c->communication_cost_, + communication_without_para + c->communication_without_parameter_); + } + }; recursive(0, 0, 0, 0); SimplifyForDreasingCommunicationWithPartialPara(&result); return result; @@ -203,7 +203,8 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr MS_EXCEPTION_IF_NULL(middle_cost); for (auto& right_cost : right_cost_list) { MS_EXCEPTION_IF_NULL(right_cost); - double memory = left_cost->memory_cost_ + middle_cost->memory_cost_ + right_cost->memory_cost_; + double computation = + left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_; double communication = left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_; double communication_without_para = left_cost->communication_without_parameter_ + @@ -211,7 +212,7 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr right_cost->communication_without_parameter_; auto decision = std::make_shared(op_strategy, left_cost, middle_cost, right_cost); - auto cost = std::make_shared(memory, communication, decision); + auto cost = std::make_shared(computation, communication, decision); MS_EXCEPTION_IF_NULL(cost); cost->communication_without_parameter_ = communication_without_para; cost->communication_with_partial_para_ = diff --git a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h index 1fa49029fae7bc95b36d73d43cce39d51c326b01..eb89466d7cdfcd3a7186711c66d983e98e716914 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h @@ -133,7 +133,7 @@ class Edge { void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; } // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase. 
- Status CorrectStrategyCostForMemoryReuse() const { return SUCCESS; } + Status CalculateMemoryCost() const { return SUCCESS; } private: std::string edge_name_; diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc index 59b9d9e992cf6662dc567135bcf8958d2e2ac03f..88a54662d38a43a589b90f0b791995482798bd77 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc @@ -247,7 +247,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std:: MS_EXCEPTION_IF_NULL(cost1); MS_EXCEPTION_IF_NULL(cost2); MS_EXCEPTION_IF_NULL(cost3); - double memory = cost1->memory_cost_ + cost2->memory_cost_ + cost3->memory_cost_; + double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_; double commmunication = cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_; double communication_without_para = cost1->communication_without_parameter_ + @@ -255,7 +255,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std:: cost3->communication_without_parameter_; auto decision = std::make_shared(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3); - auto cost = std::make_shared(memory, commmunication, decision); + auto cost = std::make_shared(computation, commmunication, decision); MS_EXCEPTION_IF_NULL(cost); cost->communication_without_parameter_ = communication_without_para; cost->communication_with_partial_para_ = @@ -282,7 +282,7 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr& u) { for (const auto& cost1 : clist1) { MS_EXCEPTION_IF_NULL(cost1); auto decision = std::make_shared(u_strategy_ptr, cost1); - auto new_cost = std::make_shared(cost1->memory_cost_, cost1->communication_cost_, decision); + auto new_cost = std::make_shared(cost1->computation_cost_, cost1->communication_cost_, decision); MS_EXCEPTION_IF_NULL(new_cost); new_cost->communication_without_parameter_ = cost1->communication_without_parameter_; new_cost->communication_with_partial_para_ = @@ -297,12 +297,12 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr& u) { } CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list, double memory) { - if (cost_list.empty() || cost_list[0]->memory_cost_ >= memory) { + if (cost_list.empty() || cost_list[0]->computation_cost_ >= memory) { return nullptr; } std::function LocalCompare = [&](CostPtr init, const CostPtr& cost_x) { MS_EXCEPTION_IF_NULL(cost_x); - if (init == nullptr || cost_x->memory_cost_ < memory) { + if (init == nullptr || cost_x->computation_cost_ < memory) { init = cost_x; } return init; @@ -313,36 +313,36 @@ CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list, CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, double memory) { // Select the cost with minimum training time. 
Currently, the training time is modeled as = - // costmodel_alpha_ * memory_cost + costmodel_beta_ * communication_with_partial_para_ + // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_with_partial_para_ if (cost_list.empty()) { MS_LOG(ERROR) << "Final cost list is null."; return nullptr; } CostPtr ret = cost_list[0]; MS_EXCEPTION_IF_NULL(ret); - if (ret->memory_cost_ >= memory) { - MS_LOG(ERROR) << "No available cost; the minimum cost is " << ret->memory_cost_ + if (ret->computation_cost_ >= memory) { + MS_LOG(ERROR) << "No available cost; the minimum cost is " << ret->computation_cost_ << ", the memory capacity is: " << memory << "."; return nullptr; } - double minimum = costmodel_alpha_ * ret->memory_cost_ + costmodel_beta_ * ret->communication_with_partial_para_; - MS_LOG(INFO) << "minimum: " << minimum << ", memory_cost_: " << ret->memory_cost_ + double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_with_partial_para_; + MS_LOG(INFO) << "minimum: " << minimum << ", computation_cost_: " << ret->computation_cost_ << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ << ", communication_cost_: " << ret->communication_cost_ << ", communication_without_parameter_: " << ret->communication_without_parameter_ << "."; for (size_t i = 1; i < cost_list.size(); ++i) { MS_EXCEPTION_IF_NULL(cost_list[i]); - if (cost_list[i]->memory_cost_ >= memory) { - MS_LOG(INFO) << "cost_list " << i << " memory_cost_: " << cost_list[i]->memory_cost_ + if (cost_list[i]->computation_cost_ >= memory) { + MS_LOG(INFO) << "cost_list " << i << " computation_cost_: " << cost_list[i]->computation_cost_ << ", is larger than the memory capacity: " << memory << "."; break; } - MS_LOG(INFO) << "cost_list " << i << " memory_cost_: " << cost_list[i]->memory_cost_ + MS_LOG(INFO) << "cost_list " << i << " computation_cost_: " << cost_list[i]->computation_cost_ << ", communication_with_partial_para_: " << cost_list[i]->communication_with_partial_para_ << ", communication_cost_: " << cost_list[i]->communication_cost_ << ", communication_without_parameter_: " << cost_list[i]->communication_without_parameter_ << "."; - auto tmp = - costmodel_alpha_ * cost_list[i]->memory_cost_ + costmodel_beta_ * cost_list[i]->communication_with_partial_para_; + auto tmp = costmodel_alpha_ * cost_list[i]->computation_cost_ + + costmodel_beta_ * cost_list[i]->communication_with_partial_para_; MS_LOG(INFO) << "tmp: " << tmp; if (minimum > tmp) { minimum = tmp; @@ -363,8 +363,8 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect MS_LOG(ERROR) << "The cost list " << i << " is empty."; return ret; } else { - total_memory += all_cost_list[i][0]->memory_cost_; - minimum += costmodel_alpha_ * all_cost_list[i][0]->memory_cost_ + + total_memory += all_cost_list[i][0]->computation_cost_; + minimum += costmodel_alpha_ * all_cost_list[i][0]->computation_cost_ + costmodel_beta_ * all_cost_list[i][0]->communication_with_partial_para_; ret[i] = all_cost_list[i][0]; } @@ -381,8 +381,8 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect double tmp_memory = 0.0, tmp_minimum = 0.0; for (size_t i = 0; i < selected_cost_list.size(); ++i) { MS_EXCEPTION_IF_NULL(selected_cost_list[i]); - tmp_memory += selected_cost_list[i]->memory_cost_; - tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->memory_cost_ + + tmp_memory += selected_cost_list[i]->computation_cost_; + tmp_minimum += costmodel_alpha_ * 
selected_cost_list[i]->computation_cost_ + costmodel_beta_ * selected_cost_list[i]->communication_with_partial_para_; } MS_LOG(INFO) << "tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ", minimum: " << minimum @@ -394,6 +394,7 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect } return; } + MS_LOG(DEBUG) << "The value minimum: " << minimum << ", available_memory: " << available_memory << "."; for (auto& c : all_cost_list[k]) { selected_cost_list[k] = c; @@ -814,7 +815,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const for (size_t k = 0; k < tar_cost_list.size(); ++k) { auto& tar_cost = tar_cost_list[k]; MS_EXCEPTION_IF_NULL(tar_cost); - double memory = op_cost->memory_cost_ + edge_cost->memory_cost_ + tar_cost->memory_cost_; + double computation = op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_; double communication = op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_; double communication_without_para = op_cost->communication_without_parameter_ + @@ -823,7 +824,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const auto decision = std::make_shared(op_strategy, op_cost, edge_cost, tar_op_strategy, tar_cost); - auto new_cost = std::make_shared(memory, communication, decision); + auto new_cost = std::make_shared(computation, communication, decision); MS_EXCEPTION_IF_NULL(new_cost); new_cost->communication_without_parameter_ = communication_without_para; new_cost->communication_with_partial_para_ = @@ -891,7 +892,8 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str for (size_t k = 0; k < tar_cost_list.size(); ++k) { auto& tar_cost = tar_cost_list[k]; MS_EXCEPTION_IF_NULL(tar_cost); - double memory = contract_op_cost->memory_cost_ + edge_cost->memory_cost_ + tar_cost->memory_cost_; + double computation = + contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_; double communication = contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_; double communication_without_para = contract_op_cost->communication_without_parameter_ + @@ -900,7 +902,7 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str auto decision = std::make_shared(contract_op_stra, contract_op_cost, edge_cost, target_op_stra, tar_cost); - auto new_cost = std::make_shared(memory, communication, decision); + auto new_cost = std::make_shared(computation, communication, decision); new_cost->communication_without_parameter_ = communication_without_para; new_cost->communication_with_partial_para_ = communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); @@ -963,9 +965,9 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, MS_EXCEPTION_IF_NULL(left_edge_cost); for (auto& left_node_cost : left_node_clist_origin) { MS_EXCEPTION_IF_NULL(left_node_cost); - double new_memory_cost = elimi_op_cost->memory_cost_ + left_edge_cost->memory_cost_ + - left_node_cost->memory_cost_ + right_edge_cost->memory_cost_ + - right_op_cost->memory_cost_; + double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ + + left_node_cost->computation_cost_ + right_edge_cost->computation_cost_ + + right_op_cost->computation_cost_; double new_commu_cost = elimi_op_cost->communication_cost_ + 
left_edge_cost->communication_cost_ + left_node_cost->communication_cost_ + right_edge_cost->communication_cost_ + right_op_cost->communication_cost_; @@ -977,7 +979,7 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, auto decision = std::make_shared(elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost, left_op_stra, left_node_cost, right_op_stra, right_op_cost); - auto new_cost = std::make_shared(new_memory_cost, new_commu_cost, decision); + auto new_cost = std::make_shared(new_computation, new_commu_cost, decision); new_cost->communication_without_parameter_ = new_commu_without; new_cost->communication_with_partial_para_ = new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without); @@ -1082,11 +1084,12 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_n succ_edges_costs[0] = first_succ_edge_cost; succ_nodes_costs[0] = first_succ_node_cost; - double memory_cost = merged_node_cost->memory_cost_, commu_cost = merged_node_cost->communication_cost_, + double computation_cost = merged_node_cost->computation_cost_, + commu_cost = merged_node_cost->communication_cost_, commu_without = merged_node_cost->communication_without_parameter_; for (size_t i = 0; i < succ_nodes_stras.size(); ++i) { MS_EXCEPTION_IF_NULL(succ_edges_costs[i]); - memory_cost += succ_edges_costs[i]->memory_cost_ + succ_nodes_costs[i]->memory_cost_; + computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_; commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_; commu_without += succ_edges_costs[i]->communication_without_parameter_ + succ_nodes_costs[i]->communication_without_parameter_; @@ -1094,7 +1097,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_n auto decision = std::make_shared(merged_op_stra, merged_node_cost, succ_edges_costs, succ_nodes_stras, succ_nodes_costs); - auto new_cost = std::make_shared(memory_cost, commu_cost, decision); + auto new_cost = std::make_shared(computation_cost, commu_cost, decision); new_cost->communication_without_parameter_ = commu_without; new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without); first_succ_node_clist_new->emplace_back(std::move(new_cost)); @@ -1210,36 +1213,6 @@ Status CostGraph::InitSelectedStrategy() { return SUCCESS; } -Status CostGraph::CorrectOpsStrategyCostForMultiOutputUse() { - for (auto& op : ops_) { - MS_EXCEPTION_IF_NULL(op); - if (op->GetAliveSuccEdges().size() > 1) { - // Filter out the case of a output being used by multiple operators - std::map output_count; - for (size_t i = 0; i < op->GetAliveSuccEdges().size(); ++i) { - auto output_index = op->GetAliveSuccEdges()[i]->prev_op_output_index(); - output_count[output_index]++; - } - for (size_t i = 0; i < op->GetAliveSuccEdges().size(); ++i) { - auto output_index = op->GetAliveSuccEdges()[i]->prev_op_output_index(); - if (output_count[output_index] <= 1) { - continue; - } - auto next_op = op->GetAliveSuccEdges()[i]->next_operator(); - MS_EXCEPTION_IF_NULL(next_op); - auto input_index = op->GetAliveSuccEdges()[i]->next_op_input_index(); - if (next_op->CorrectStrategyCostForMultiOutputUse(input_index) != SUCCESS) { - MS_LOG(ERROR) << "The operator name: " << op->name() << ", the next operator name: " << next_op->name() - << ", the output_index: " << output_index << ", the input_index: " << input_index << "."; - return FAILED; - } - 
output_count[output_index]--; - } - } - } - return SUCCESS; -} - Status CostGraph::ComputeOpsAndEdgesParameterInvolved() { for (auto& op : ops_) { MS_EXCEPTION_IF_NULL(op); @@ -1252,23 +1225,23 @@ Status CostGraph::ComputeOpsAndEdgesParameterInvolved() { return SUCCESS; } -Status CostGraph::CorrectOpsStrategyCostForMemoryReuse() { +Status CostGraph::CalculateOpsMemoryCost() { for (auto& op : ops_) { MS_EXCEPTION_IF_NULL(op); - if (op->CorrectStrategyCostForMemoryReuse() != SUCCESS) { - MS_LOG(ERROR) << "Correcting Operator: " << op->name() << " cost for memory reuse failed."; + if (op->CalculateMemoryCost() != SUCCESS) { + MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed."; return FAILED; } } return SUCCESS; } -Status CostGraph::CorrectEdgesStrategyCostForMemoryReuse() { +Status CostGraph::CalculateEdgesMemoryCost() { for (auto& edge_pair : edges_) { const auto& edges = edge_pair.second; for (auto& one_edge : edges) { - if (one_edge->CorrectStrategyCostForMemoryReuse() != SUCCESS) { - MS_LOG(ERROR) << "Correcting Edge: " << one_edge->edge_name() << " cost for memory reuse failed."; + if (one_edge->CalculateMemoryCost() != SUCCESS) { + MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed."; return FAILED; } } diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h index e4cbdffb6100f2779602986f891fe9bfb05783b2..c149534826474e811187f77519dd77c4719cc9f1 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h @@ -175,16 +175,12 @@ class CostGraph { void CreateStarEliminationSubCostList(const StrategyPtr&, const CostPtrList&, const CostPtrList&, const StrategyPtr&, const CostPtrList&, std::vector, CostPtrList&, CostPtrList&, CostPtrList*); - - // When a output of a operator is being used by multiple operators, the memory cost of this part should be calculated - // only once. This method is for correcting the 'strategy_cost_' for operators - Status CorrectOpsStrategyCostForMultiOutputUse(); // When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then // the memory cost can be resused. - Status CorrectOpsStrategyCostForMemoryReuse(); + Status CalculateOpsMemoryCost(); // When the input of the edge is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then // the memory cost can be resused. - Status CorrectEdgesStrategyCostForMemoryReuse(); + Status CalculateEdgesMemoryCost(); Status ComputeOpsAndEdgesParameterInvolved(); std::vector GetOperators() const { return ops_; } diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc index 6958932fd6d69473e675de64cf96c3e294b0292b..7c17b499b11147244320a5cb3856e60a2129af94 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc @@ -74,8 +74,8 @@ double MatMulCost::GetBackwardCommCost(const std::vector& inputs, co // Return the per device memory cost in the forward phase. 
The cost is calculated according to the bytes // this operator uses -double MatMulCost::GetForwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t&) const { +double MatMulCost::GetForwardComputationCost(const std::vector& inputs, + const std::vector& outputs, const int32_t&) const { // In forward phase, the memory cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C)) double result = 0.0; TensorInfo output0 = outputs[0]; @@ -93,8 +93,8 @@ double MatMulCost::GetForwardMemoryCost(const std::vector& inputs, c // Return the per device memory cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double MatMulCost::GetBackwardMemoryCost(const std::vector& inputs, const std::vector&, - const int32_t& stage_id) const { +double MatMulCost::GetBackwardComputationCost(const std::vector& inputs, const std::vector&, + const int32_t& stage_id) const { // In backward phase, the memory cost = (0 or 1) allreduce(slice(B)) double result = 0.0; if (is_parameter_[1]) { @@ -147,8 +147,8 @@ double ActivationCost::GetBackwardCommCost(const std::vector& inputs // Return the per memory cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double ActivationCost::GetForwardMemoryCost(const std::vector& inputs, const std::vector&, - const int32_t&) const { +double ActivationCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, + const int32_t&) const { TensorInfo input0_info = inputs[0]; Shape input0_slice_shape = input0_info.slice_shape(); return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); @@ -156,8 +156,8 @@ double ActivationCost::GetForwardMemoryCost(const std::vector& input // Return the per memory cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double ActivationCost::GetBackwardMemoryCost(const std::vector&, const std::vector&, - const int32_t&) const { +double ActivationCost::GetBackwardComputationCost(const std::vector&, const std::vector&, + const int32_t&) const { return 0.0; } @@ -191,8 +191,8 @@ double SoftmaxCost::GetBackwardCommCost(const std::vector& inputs, c // Return the per memory cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double SoftmaxCost::GetForwardMemoryCost(const std::vector& inputs, const std::vector&, - const int32_t&) const { +double SoftmaxCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, + const int32_t&) const { // In the forward phase, the memory cost = slice(A) TensorInfo input0 = inputs[0]; Shape input0_slice_shape = input0.slice_shape(); @@ -201,8 +201,9 @@ double SoftmaxCost::GetForwardMemoryCost(const std::vector& inputs, // Return the per memory cost in the forward phase. 
The cost is calculated according to the bytes // this operator uses -double SoftmaxCost::GetBackwardMemoryCost(const std::vector&, - const std::vector&, const int32_t&) const { +double SoftmaxCost::GetBackwardComputationCost(const std::vector&, + const std::vector&, + const int32_t&) const { return 0.0; } @@ -222,9 +223,9 @@ double TmpIdentityCost::GetBackwardCommCost(const std::vector& inputs, - const std::vector&, - const int32_t&) const { +double TmpIdentityCost::GetForwardComputationCost(const std::vector& inputs, + const std::vector&, + const int32_t&) const { TensorInfo input0_info = inputs[0]; Shape input0_slice_shape = input0_info.slice_shape(); return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); @@ -232,15 +233,15 @@ double TmpIdentityCost::GetForwardMemoryCost(const std::vector&, - const std::vector&, - const int32_t&) const { +double TmpIdentityCost::GetBackwardComputationCost(const std::vector&, + const std::vector&, + const int32_t&) const { return 0.0; } -double BatchParallelCost::GetForwardMemoryCost(const std::vector& inputs, - const std::vector&, - const int32_t&) const { +double BatchParallelCost::GetForwardComputationCost(const std::vector& inputs, + const std::vector&, + const int32_t&) const { double cost = 0.0; for (size_t i = 0; i < inputs.size(); ++i) { cost += ListProduct(inputs[i].slice_shape()) * static_cast(inputs_type_lengths_[i]); @@ -248,9 +249,9 @@ double BatchParallelCost::GetForwardMemoryCost(const std::vector&, - const std::vector&, - const int32_t&) const { +double BatchParallelCost::GetBackwardComputationCost(const std::vector&, + const std::vector&, + const int32_t&) const { return 0.0; } @@ -285,8 +286,8 @@ double PReLUCost::GetBackwardCommCost(const std::vector& inputs, con // Return the per memory cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double PReLUCost::GetForwardMemoryCost(const std::vector& inputs, const std::vector&, - const int32_t&) const { +double PReLUCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, + const int32_t&) const { // In forward phase, the memory cost = slice(A) + slice(B) Shape input0_slice_shape = inputs[0].slice_shape(); Shape input1_slice_shape = inputs[1].slice_shape(); @@ -297,9 +298,9 @@ double PReLUCost::GetForwardMemoryCost(const std::vector& inputs, co // Return the per memory cost in the backward phase. The cost is calculated according to the bytes // this operator uses -double PReLUCost::GetBackwardMemoryCost(const std::vector& inputs, - const std::vector&, - const int32_t& stage_id) const { +double PReLUCost::GetBackwardComputationCost(const std::vector& inputs, + const std::vector&, + const int32_t& stage_id) const { // In backward phase, the memory cost = (0 or 1) allreduce(slice(B)) double result = 0.0; if (is_parameter_[1]) { @@ -338,8 +339,8 @@ double OneHotCost::GetBackwardCommCost(const std::vector&, const std // Return the per memory cost in the forward phase. 
The cost is calculated according to the bytes // this operator uses -double OneHotCost::GetForwardMemoryCost(const std::vector& inputs, const std::vector&, - const int32_t&) const { +double OneHotCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, + const int32_t&) const { // In onehot's forward phase, the memory cost = slice(A) Shape input0_slice_shape = inputs[0].slice_shape(); return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); @@ -347,8 +348,8 @@ double OneHotCost::GetForwardMemoryCost(const std::vector& inputs, c // Return the per memory cost in the backward phase. The cost is calculated according to the bytes // this operator uses -double OneHotCost::GetBackwardMemoryCost(const std::vector&, const std::vector&, - const int32_t&) const { +double OneHotCost::GetBackwardComputationCost(const std::vector&, const std::vector&, + const int32_t&) const { return 0.0; } @@ -368,8 +369,9 @@ double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector< // Return the per memory cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double SoftmaxCrossEntropyWithLogitsCost::GetForwardMemoryCost(const std::vector& inputs, - const std::vector&, const int32_t&) const { +double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector& inputs, + const std::vector&, + const int32_t&) const { // In forward phase, the memory cost = slice(A) + slice(B) Shape input0_slice_shape = inputs[0].slice_shape(); Shape input1_slice_shape = inputs[1].slice_shape(); @@ -380,8 +382,9 @@ double SoftmaxCrossEntropyWithLogitsCost::GetForwardMemoryCost(const std::vector // Return the per memory cost in the backward phase. The cost is calculated according to the bytes // this operator uses -double SoftmaxCrossEntropyWithLogitsCost::GetBackwardMemoryCost(const std::vector&, - const std::vector&, const int32_t&) const { +double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector&, + const std::vector&, + const int32_t&) const { return 0.0; } @@ -409,8 +412,8 @@ double ReshapeCost::GetBackwardCommCost(const std::vector&, const st // Return the per memory cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double ReshapeCost::GetForwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const { +double ReshapeCost::GetForwardComputationCost(const std::vector& inputs, + const std::vector& outputs, const int32_t& stage_id) const { CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); @@ -421,26 +424,27 @@ double ReshapeCost::GetForwardMemoryCost(const std::vector& inputs, if (tensor_redistribution.ComputeCost() == FAILED) { MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed."; } - return (inputs_type_lengths_[0] * tensor_redistribution.mem_cost()); + return (inputs_type_lengths_[0] * tensor_redistribution.computation_cost()); } // Return the per memory cost in the backward phase. 
The cost is calculated according to the bytes // this operator uses -double ReshapeCost::GetBackwardMemoryCost(const std::vector&, - const std::vector&, const int32_t&) const { +double ReshapeCost::GetBackwardComputationCost(const std::vector&, + const std::vector&, + const int32_t&) const { return 0.0; } -double ArithmeticCost::GetForwardMemoryCost(const std::vector& inputs, const std::vector&, - const int32_t&) const { +double ArithmeticCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, + const int32_t&) const { double result; result = ListProduct(inputs[0].slice_shape()) * static_cast(inputs_type_lengths_[0]) + ListProduct(inputs[1].slice_shape()) * static_cast(inputs_type_lengths_[1]); return result; } -double ArithmeticCost::GetBackwardMemoryCost(const std::vector& inputs, const std::vector&, - const int32_t& stage_id) const { +double ArithmeticCost::GetBackwardComputationCost(const std::vector& inputs, const std::vector&, + const int32_t& stage_id) const { double result = 0.0; CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); @@ -533,15 +537,15 @@ double L2NormalizeCost::GetBackwardCommCost(const std::vector& input return result; } -double L2NormalizeCost::GetForwardMemoryCost(const std::vector& inputs, const std::vector&, - const int32_t&) const { +double L2NormalizeCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, + const int32_t&) const { TensorInfo input0_info = inputs[0]; Shape input0_slice_shape = input0_info.slice_shape(); return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); } -double L2NormalizeCost::GetBackwardMemoryCost(const std::vector& inputs, const std::vector&, - const int32_t& stage_id) const { +double L2NormalizeCost::GetBackwardComputationCost(const std::vector& inputs, + const std::vector&, const int32_t& stage_id) const { double result = 0.0; if (is_parameter_[0]) { @@ -618,8 +622,9 @@ double ReduceMethodCost::GetBackwardCommCost(const std::vector& inpu return result; } -double ReduceMethodCost::GetForwardMemoryCost(const std::vector& inputs, - const std::vector& outputs, const int32_t& stage_id) const { +double ReduceMethodCost::GetForwardComputationCost(const std::vector& inputs, + const std::vector& outputs, + const int32_t& stage_id) const { double result = 0.0; TensorInfo input0 = inputs[0]; TensorInfo output0 = outputs[0]; @@ -640,8 +645,9 @@ double ReduceMethodCost::GetForwardMemoryCost(const std::vector& inp return result; } -double ReduceMeanCost::GetForwardMemoryCost(const std::vector& inputs, - const std::vector& outputs, const int32_t& stage_id) const { +double ReduceMeanCost::GetForwardComputationCost(const std::vector& inputs, + const std::vector& outputs, + const int32_t& stage_id) const { double result = 0.0; TensorInfo input0 = inputs[0]; TensorInfo output0 = outputs[0]; diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h index 9fb86d467ed72b009a5f55c0dec2b293e8b040d5..8f0099bba3ac9fccbfa29db5fccfa4e9f74c2224 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h @@ -65,12 +65,12 @@ class OperatorCost { virtual double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, const int32_t& stage_id) const = 0; // per device computation cost - virtual double GetMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const = 0; - 
virtual double GetForwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const = 0; - virtual double GetBackwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const = 0; + virtual double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const = 0; + virtual double GetForwardComputationCost(const std::vector& inputs, + const std::vector& outputs, const int32_t& stage_id) const = 0; + virtual double GetBackwardComputationCost(const std::vector& inputs, + const std::vector& outputs, const int32_t& stage_id) const = 0; protected: // for each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter @@ -96,14 +96,14 @@ class MatMulCost : public OperatorCost { const int32_t& stage_id) const override; // per device computation cost - double GetMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { - return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); - } - double GetForwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; - double GetBackwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; + double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; }; using MatMulCostPtr = std::shared_ptr; @@ -121,14 +121,14 @@ class ActivationCost : public OperatorCost { const int32_t& stage_id) const override; double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, const int32_t& stage_id) const override; - double GetMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { - return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); - } - double GetForwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; - double GetBackwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; + double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; }; using ActivationCostPtr = std::shared_ptr; @@ -146,14 +146,14 @@ class SoftmaxCost : public OperatorCost { const int32_t& stage_id) const override; double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, const int32_t& stage_id) const override; - double GetMemoryCost(const std::vector& inputs, const 
std::vector& outputs, - const int32_t& stage_id) const override { - return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); - } - double GetForwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; - double GetBackwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t&) const override; + double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; + double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t&) const override; }; using SoftmaxCostPtr = std::shared_ptr; @@ -171,14 +171,14 @@ class TmpIdentityCost : public OperatorCost { const int32_t& stage_id) const override; double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, const int32_t& stage_id) const override; - double GetMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { - return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); - } - double GetForwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; - double GetBackwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; + double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; }; using TmpIdentityCostPtr = std::shared_ptr; @@ -199,14 +199,14 @@ class BatchParallelCost : public OperatorCost { const int32_t&) const override { return 0.0; } - double GetMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { - return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); + double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; - double GetBackwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; + double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; }; using BatchParallelCostPtr = std::shared_ptr; @@ -227,16 +227,16 @@ class VirtualDatasetCost : public OperatorCost { const int32_t&) const override { return 0.0; } - 
double GetMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { - return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); + double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardMemoryCost(const std::vector&, const std::vector&, - const int32_t&) const override { + double GetForwardComputationCost(const std::vector&, const std::vector&, + const int32_t&) const override { return 0.0; } - double GetBackwardMemoryCost(const std::vector&, const std::vector&, - const int32_t&) const override { + double GetBackwardComputationCost(const std::vector&, const std::vector&, + const int32_t&) const override { return 0.0; } }; @@ -259,18 +259,18 @@ class GeneratorBaseCost : public OperatorCost { const int32_t&) const override { return 0.0; } - double GetMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { - return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); + double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } // Inputs vector is empty for generator ops. - double GetForwardMemoryCost(const std::vector&, const std::vector&, - const int32_t&) const override { + double GetForwardComputationCost(const std::vector&, const std::vector&, + const int32_t&) const override { return 0.0; } // Generator ops don't have backward steps. 
- double GetBackwardMemoryCost(const std::vector&, const std::vector&, - const int32_t&) const override { + double GetBackwardComputationCost(const std::vector&, const std::vector&, + const int32_t&) const override { return 0.0; } }; @@ -292,14 +292,14 @@ class PReLUCost : public OperatorCost { const int32_t& stage_id) const override; // per device computation cost - double GetMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { - return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); - } - double GetForwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; - double GetBackwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; + double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; }; using PReLUCostPtr = std::shared_ptr; @@ -319,14 +319,14 @@ class OneHotCost : public OperatorCost { const int32_t& stage_id) const override; // per device computation cost - double GetMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { - return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); - } - double GetForwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; - double GetBackwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; + double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; }; using OneHotCostPtr = std::shared_ptr; @@ -346,14 +346,14 @@ class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost { const int32_t& stage_id) const override; // per device computation cost - double GetMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { - return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); - } - double GetForwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; - double GetBackwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const 
std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; + double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; }; using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr; @@ -376,16 +376,16 @@ class ReshapeCost : public OperatorCost { const int32_t& stage_id) const override; // per device computation cost - double GetMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { - return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); + double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; - double GetBackwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; }; using ReshapeCostPtr = std::shared_ptr; @@ -405,14 +405,14 @@ class ArithmeticCost : public OperatorCost { double GetBackwardCommCost(const std::vector&, const std::vector&, const int32_t&) const override; - double GetMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { - return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); + double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; - double GetBackwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; + double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; }; using ArithmeticCostPtr = std::shared_ptr; @@ -431,14 +431,14 @@ class L2NormalizeCost : public OperatorCost { } double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, const int32_t& stage_id) const override; - double GetMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { - return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); - } - double GetForwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; - double GetBackwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override { + return GetForwardComputationCost(inputs, outputs, 
stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; + double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; }; using L2NormalizeCostPtr = std::shared_ptr; @@ -455,14 +455,14 @@ class ReduceMethodCost : public OperatorCost { const int32_t& stage_id) const override; double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, const int32_t& stage_id) const override; - double GetMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { - return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - double GetForwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; - double GetBackwardMemoryCost(const std::vector&, const std::vector&, - const int32_t&) const override { + double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; + double GetBackwardComputationCost(const std::vector&, const std::vector&, + const int32_t&) const override { return 0.0; } void set_cross_batch(bool cb) { cross_batch_ = cb; } @@ -477,8 +477,8 @@ class ReduceMeanCost : public ReduceMethodCost { ReduceMeanCost() = default; ~ReduceMeanCost() override = default; - double GetForwardMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override; }; using ReduceMeanCostPtr = std::shared_ptr; @@ -499,18 +499,18 @@ class GetNextCost : public OperatorCost { const int32_t&) const override { return 0.0; } - double GetMemoryCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { - return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); + double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + const int32_t& stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } // Inputs vector is empty for generator ops. - double GetForwardMemoryCost(const std::vector&, const std::vector&, - const int32_t&) const override { + double GetForwardComputationCost(const std::vector&, const std::vector&, + const int32_t&) const override { return 0.0; } // Generator ops don't have backward steps. 
diff --git a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc b/mindspore/ccsrc/parallel/ops_info/matmul_info.cc
index ad6409be0a85ab4f78c9e9465f5f07f3e5c61165..2b02dc100d420bc86d1f0279c32eae2c1e259f8d 100644
--- a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/matmul_info.cc
@@ -592,10 +592,10 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr&
   int32_t stage_id = strategy->GetInputStage();
   // Here, we use the origin outputs_, because we only use the slice size of the output tensor.
   // It does not matter whether the output tensor is transposed or not.
-  double memory_cost =
-    matmulcost_ptr->GetForwardMemoryCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
+  double computation_cost =
+    matmulcost_ptr->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
   double communication_cost = matmulcost_ptr->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
-  std::shared_ptr<Cost> result = std::make_shared<Cost>(memory_cost, communication_cost);
+  std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
   result->communication_without_parameter_ =
     matmulcost_ptr->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
   result->communication_with_partial_para_ =
@@ -604,7 +604,7 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr&
   // Breaking ties for preferring data parallelization
   BreakingTiesForPerferringDataParallel(strategy, result);
 
-  MS_LOG(DEBUG) << name_ << " : memory_cost: " << result->memory_cost_
+  MS_LOG(DEBUG) << name_ << " : computation_cost: " << result->computation_cost_
                 << ", communication_cost: " << result->communication_cost_
                 << ", communication_without_parameter_: " << result->communication_without_parameter_
                 << ", communication_with_partial_para_: " << result->communication_with_partial_para_;
diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/parallel/ops_info/operator_info.cc
index 8b96425bf751ee61766b6f31e359b0221551ce46..11c518d844b79e10c108d78717af685da96e0224 100644
--- a/mindspore/ccsrc/parallel/ops_info/operator_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/operator_info.cc
@@ -1034,9 +1034,10 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) {
     return FAILED;
   }
   int32_t stage_id = strategy->GetInputStage();
-  double memory_cost = GetOperatorCost()->GetForwardMemoryCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+  double computation_cost =
+    GetOperatorCost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
   double communication_cost = GetOperatorCost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
-  std::shared_ptr<Cost> result = std::make_shared<Cost>(memory_cost, communication_cost);
+  std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
   result->communication_without_parameter_ =
     GetOperatorCost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
   result->communication_with_partial_para_ =
@@ -1056,22 +1057,6 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) {
   return SUCCESS;
 }
 
-Status OperatorInfo::CorrectStrategyCostForMultiOutputUse(size_t input_index) {
-  for (auto& swc : strategy_cost_) {
-    double parameter_memory_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) *
-                                   static_cast<double>(GetOperatorCost()->inputs_type_lengths()[input_index]);
-    // remove the parameter memory cost
-    swc->cost_list[0]->memory_cost_ -= parameter_memory_cost;
-    if (swc->cost_list[0]->memory_cost_ < -1) {
-      MS_LOG(ERROR) << "The memory cost after correction is " << swc->cost_list[0]->memory_cost_
-                    << ", the parameter_memory_cost is " << parameter_memory_cost;
-      return FAILED;
-    }
-  }
-  corrected_input_indices_.push_back(input_index);
-  return SUCCESS;
-}
-
 int OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() {
   if (is_output_parameter_involve_ != -1) {
     return is_output_parameter_involve_;
@@ -1217,7 +1202,7 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr& stra
   CheckGlobalDeviceManager();
   auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size();
   if (IntToSize(stra->GetInputDim()[0][0]) == total_device_num) {
-    cost->memory_cost_ -= 1.0;
+    cost->computation_cost_ -= 1.0;
     cost->communication_cost_ -= 1.0;
     cost->communication_with_partial_para_ -= 1.0;
     cost->communication_without_parameter_ -= 1.0;
@@ -1226,7 +1211,7 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr& stra
 }
 
 double OperatorInfo::GetForwardMemoryCostFromCNode() {
-  return GetOperatorCost()->GetForwardMemoryCost(inputs_tensor_info_, outputs_tensor_info_, 0);
+  return GetOperatorCost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
 }
 }  // namespace parallel
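
Both hunks above follow the same recipe when a candidate strategy is evaluated: query the operator's cost model for the forward computation cost and the total communication cost, wrap them in a Cost object, and then record the forward-only communication separately. A self-contained sketch with invented numbers is shown below; ToyCost stands in for parallel::Cost, and the derivation of communication_with_partial_para_ done by the real code is omitted.

    // Illustrative only: shows how a strategy's Cost entry is assembled from the
    // renamed computation cost plus the communication figures. All values are invented.
    #include <iostream>
    #include <memory>

    struct ToyCost {
      ToyCost(double computation, double communication)
          : computation_cost_(computation), communication_cost_(communication) {}
      double computation_cost_;
      double communication_cost_;
      double communication_without_parameter_ = 0.0;
      double communication_with_partial_para_ = 0.0;
    };

    int main() {
      // Pretend these came from GetForwardComputationCost, GetCommCost and GetForwardCommCost.
      const double computation_cost = 4096.0;
      const double communication_cost = 512.0;
      const double forward_comm_cost = 384.0;

      auto result = std::make_shared<ToyCost>(computation_cost, communication_cost);
      // Forward-only communication excludes the backward synchronisation of parameters.
      result->communication_without_parameter_ = forward_comm_cost;

      std::cout << result->computation_cost_ << " " << result->communication_cost_ << " "
                << result->communication_without_parameter_ << std::endl;
      return 0;
    }
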
diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.h b/mindspore/ccsrc/parallel/ops_info/operator_info.h
index cc70f1b8702ef9314c19c03d62329498c8167238..e7b8af0a7ed7b351efd157f522eb799671ef136b 100644
--- a/mindspore/ccsrc/parallel/ops_info/operator_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/operator_info.h
@@ -87,13 +87,9 @@ class OperatorInfo {
   // is checked
   Status SetCostUnderStrategyBase(const StrategyPtr& strategy);
   std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; }
-  // In the case of a Parameter (or a output) being used by multiple operators, the memory cost induced by
-  // the parameter (or a output) should be calculated only once. This method is used to
-  // remove this part from the 'strategy_cost_'.
-  Status CorrectStrategyCostForMultiOutputUse(size_t input_index);
   // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input
   // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase.
-  Status CorrectStrategyCostForMemoryReuse() const { return SUCCESS; }
+  Status CalculateMemoryCost() const { return SUCCESS; }
   int ComputeOpAndPrevEdgeParameterInvolved();
 
   ForwardOp forward_op() const { return forward_op_; }
diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc
index 50e6a1e84e56b30abdf32a6d6946d32a8c4e3e2e..d7d48c35bb38f83f50e9fcc22a3518ff66acf2d8 100644
--- a/mindspore/ccsrc/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc
@@ -387,7 +387,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
   operator_info->set_outputs_dtype(cnode->Type());
   operator_info->set_cnode(cnode);
   // If no strategy has been configured for this operator, then candidate strategies are generated for
-  // auto-strategy searchingm if this primitive is Cast, we ignore the user-specified strategy
+  // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy
   if (!StrategyFound(attrs) || prim->name() == CAST) {
     // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
     // BatchParallelInfo operator
@@ -600,13 +600,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
     }
     MS_LOG(INFO) << "Successfully created " << edge_count
                  << " edges for: " << cnode->operator_info()->name();
   }
-  // For the case of a output being used by multiple subsequent operators, the output induced memory cost should be
-  // calculated only once. This method is for correct the operators' memory cost calculation.
-  if (entire_costgraph->CorrectOpsStrategyCostForMultiOutputUse() != SUCCESS) {
-    MS_LOG(EXCEPTION) << "Correcting strategy_cost_ for operators failed.";
-  } else {
-    MS_LOG(INFO) << "Correcting strategy_cost_ for operators succeeded.";
-  }
+  MS_LOG(INFO) << "Constructing edges for cost graph ends.";
 
 }
@@ -803,14 +797,6 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
     std::shared_ptr<Edge> edge_ptr = std::make_shared<Edge>(
       edge_name, tmp_identity_ptr, target_cnode->operator_info(), 0, input_index - 1, false, true);
 
-    // Correct the memory calculation for a parameter being used by multiple operators. The parameter is calculated
-    // only once
-    if (target_cnode->operator_info()->CorrectStrategyCostForMultiOutputUse(IntToSize(input_index - 1)) != SUCCESS) {
-      MS_LOG(EXCEPTION) << "Correcting strategy_cost_ failed : " << prim->name();
-    } else {
-      MS_LOG(INFO) << "Correcting strategy_cost_ succeeded. " << prim->name();
-    }
-
     if (edge_ptr->InitEdgeCost() != SUCCESS) {
       MS_LOG(EXCEPTION) << "Edge cost initialization failed";
     }
@@ -840,7 +826,7 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
   //           taking care for the case of a single Parameter being used by multiple operators. Create a TmpIdentity
   //           operator for this Parameter, and add an edge for the use of this Parameter by each
   //           subsequent operator;
-  // Step 3.1: Correct the memory calculation for memory reuse
+  // Step 3.1: Calculate memory usage
   // Step 4: Run the Dynamic Programming algorithm:
   //         in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge
Here, the edge // cost is caused by the redistribution of a operator's output tensor layout to the next operator's input @@ -867,14 +853,14 @@ Status ParallelStrategySearch(const std::vector &all_nodes, const Fu MS_LOG(INFO) << "After the augmenting procedure, there are " << entire_costgraph->GetOperators().size() << " operators, and " << entire_costgraph->GetNumPairs() << " edges."; - // Step 3.1: Correcting calculation for memory reuse + // Step 3.1: Calculate the memory usage if (entire_costgraph->ComputeOpsAndEdgesParameterInvolved() == SUCCESS) { - // Correcting operators' memory usage - if (entire_costgraph->CorrectOpsStrategyCostForMemoryReuse() != SUCCESS) { + // Calculate operators' memory usage + if (entire_costgraph->CalculateOpsMemoryCost() != SUCCESS) { MS_LOG(EXCEPTION) << "Correcting operators' cost for memory reuse failed."; } - // Correcting edges' memory usage - if (entire_costgraph->CorrectEdgesStrategyCostForMemoryReuse() != SUCCESS) { + // Calculate edges' memory usage + if (entire_costgraph->CalculateEdgesMemoryCost() != SUCCESS) { MS_LOG(EXCEPTION) << "Correcting edges' cost for memory reuse failed."; } } else { diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc index 93bda5da8145ee0818b5feec94b0612be90ca25c..55e6a300e055d08b396e79f4740273f6f91b4656 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc @@ -144,7 +144,7 @@ Status TensorRedistribution::ComputeCost() { MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed"; return Status::FAILED; } - // Compute redistribution communication cost and memory cost + // Compute redistribution communication cost and computation cost for (auto& op_cost : operator_list_) { OperatorR op = op_cost.first; Shape slice_shape = op_cost.second; @@ -154,14 +154,14 @@ Status TensorRedistribution::ComputeCost() { if (str == PERMUTE_BY_AXIS) { // The shape does not change after PermuteByAxis operation. // communication cost = all_to_all + all_to_all = 2 * slice_shape - // memory cost = slice_shape + // computation cost = slice_shape forward_comm_cost_ += prod; backward_comm_cost_ += prod; comm_cost_ += 2.0 * prod; - mem_cost_ += prod; + computation_cost_ += prod; } else if (str == CONCAT_BY_AXIS) { // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape - // memory cost = before_slice_shape + // computation cost = before_slice_shape if (op.second.size() < 3) { MS_LOG(ERROR) << "op.second size should not be less than 3!"; return Status::FAILED; @@ -173,22 +173,22 @@ Status TensorRedistribution::ComputeCost() { comm_cost_ += prod * (dev_num + 1.0); int32_t concat_dim = op.second[0]; if (concat_dim == 0) { - // memory cost = all_gather - mem_cost_ += prod; + // computation cost = all_gather + computation_cost_ += prod; } else { - // memory cost = all_gather + split + concat - mem_cost_ += (prod + prod * dev_num + prod * dev_num); + // computation cost = all_gather + split + concat + computation_cost_ += (prod + prod * dev_num + prod * dev_num); } } else { - // There is only memory cost in SplitByAxis. - // memory cost = before_slice_shape - mem_cost_ += prod; + // There is only computation cost in SplitByAxis. 
diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc
index 93bda5da8145ee0818b5feec94b0612be90ca25c..55e6a300e055d08b396e79f4740273f6f91b4656 100644
--- a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc
+++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc
@@ -144,7 +144,7 @@ Status TensorRedistribution::ComputeCost() {
     MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed";
     return Status::FAILED;
   }
-  // Compute redistribution communication cost and memory cost
+  // Compute redistribution communication cost and computation cost
   for (auto& op_cost : operator_list_) {
     OperatorR op = op_cost.first;
     Shape slice_shape = op_cost.second;
@@ -154,14 +154,14 @@ Status TensorRedistribution::ComputeCost() {
     if (str == PERMUTE_BY_AXIS) {
       // The shape does not change after PermuteByAxis operation.
       // communication cost = all_to_all + all_to_all = 2 * slice_shape
-      // memory cost = slice_shape
+      // computation cost = slice_shape
       forward_comm_cost_ += prod;
       backward_comm_cost_ += prod;
       comm_cost_ += 2.0 * prod;
-      mem_cost_ += prod;
+      computation_cost_ += prod;
     } else if (str == CONCAT_BY_AXIS) {
       // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
-      // memory cost = before_slice_shape
+      // computation cost = before_slice_shape
       if (op.second.size() < 3) {
         MS_LOG(ERROR) << "op.second size should not be less than 3!";
         return Status::FAILED;
@@ -173,22 +173,22 @@
       comm_cost_ += prod * (dev_num + 1.0);
       int32_t concat_dim = op.second[0];
       if (concat_dim == 0) {
-        // memory cost = all_gather
-        mem_cost_ += prod;
+        // computation cost = all_gather
+        computation_cost_ += prod;
       } else {
-        // memory cost = all_gather + split + concat
-        mem_cost_ += (prod + prod * dev_num + prod * dev_num);
+        // computation cost = all_gather + split + concat
+        computation_cost_ += (prod + prod * dev_num + prod * dev_num);
       }
     } else {
-      // There is only memory cost in SplitByAxis.
-      // memory cost = before_slice_shape
-      mem_cost_ += prod;
+      // There is only computation cost in SplitByAxis.
+      // computation cost = before_slice_shape
+      computation_cost_ += prod;
     }
   }
   if (reshape_flag()) {
     Shape prev_slice_shape = from_.slice_shape().array();
     double prev_prod = std::accumulate(prev_slice_shape.begin(), prev_slice_shape.end(), 1, std::multiplies<double>());
-    mem_cost_ += 2.0 * prev_prod;
+    computation_cost_ += 2.0 * prev_prod;
   }
   return Status::SUCCESS;
 }
diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h
index 38fb5959ad91ffe5776393244833d59402fec310..e933b9b8eb9208f8e01187028b0e175d672a1556 100644
--- a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h
+++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h
@@ -41,7 +41,7 @@ class TensorRedistribution {
         comm_cost_(0.0),
         forward_comm_cost_(0.0),
         backward_comm_cost_(0.0),
-        mem_cost_(0.0),
+        computation_cost_(0.0),
         construct_op_flag_(construct_op_flag),
         keep_reshape_(keep_reshape) {}
   Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list);
@@ -51,7 +51,7 @@ class TensorRedistribution {
   bool reshape_flag() const { return reshape_flag_; }
   Status ComputeCost();
   double comm_cost() const { return comm_cost_; }
-  double mem_cost() const { return mem_cost_; }
+  double computation_cost() const { return computation_cost_; }
   double forward_comm_cost() const { return forward_comm_cost_; }
   double backward_comm_cost() const { return backward_comm_cost_; }
 
@@ -66,10 +66,13 @@ class TensorRedistribution {
   RankList dev_list_;
   OperatorList operator_list_;
   bool reshape_flag_;
+  // communication cost
   double comm_cost_;
+  // forward communication cost
   double forward_comm_cost_;
+  // backward communication cost
   double backward_comm_cost_;
-  double mem_cost_;
+  double computation_cost_;
   bool construct_op_flag_;
   bool keep_reshape_;
 };
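
As a concrete reading of the CONCAT_BY_AXIS branch above (non-zero concat dimension), the renamed computation cost and the communication cost can be evaluated for an invented slice shape and device count; the numbers below only illustrate the formulas in the hunk and are not taken from any real model.

    // Illustrative only: plugs invented numbers into the redistribution cost formulas
    // above (CONCAT_BY_AXIS, concat_dim != 0). prod is the product of the slice shape.
    #include <iostream>

    int main() {
      const double prod = 128 * 64;  // slice of shape [128, 64] -> 8192 elements
      const double dev_num = 4.0;

      // comm_cost_ += prod * (dev_num + 1.0)
      const double comm_cost = prod * (dev_num + 1.0);               // 40960
      // computation_cost_ += prod + prod * dev_num + prod * dev_num  (all_gather + split + concat)
      const double computation_cost = prod + 2.0 * prod * dev_num;   // 73728

      std::cout << comm_cost << " " << computation_cost << std::endl;
      return 0;
    }
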
diff --git a/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc b/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc
index 83a9eceaccbd83386e16bf59a8ec7cc7efcd5e13..415a1fdd558dcf4e4eecb131bb697a53fef4a7e2 100644
--- a/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc
+++ b/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc
@@ -322,8 +322,8 @@ TEST_F(TestCostGraph, test_SelectCostListWithMinTrainingTimeMultiple) {
   auto ret_list = entire_cost_graph.SelectCostListWithMinTrainingTimeMultiple(all_list, memory);
 
   ASSERT_EQ(ret_list.size(), 2);
-  ASSERT_DOUBLE_EQ(ret_list[0]->memory_cost_, 10);
-  ASSERT_DOUBLE_EQ(ret_list[1]->memory_cost_, 1010);
+  ASSERT_DOUBLE_EQ(ret_list[0]->computation_cost_, 10);
+  ASSERT_DOUBLE_EQ(ret_list[1]->computation_cost_, 1010);
 }
 
 TEST_F(TestCostGraph, test_CheckOpElimination) {
diff --git a/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc b/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc
index 3bd65c049c546abaff1260e8df216e38b92d5725..919c5b43eca6b0b08ce2309979c2cfbfb856279f 100644
--- a/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc
+++ b/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc
@@ -76,8 +76,8 @@ TEST_F(TestMatMulCost, test_CostGeneration) {
   mmcost_.SetInputAndOutputTypeLength(inputs_length, outputs_length);
   mmcost_.GetForwardCommCost(inputs, outputs, 0);
   mmcost_.GetBackwardCommCost(inputs, outputs, 0);
-  mmcost_.GetForwardMemoryCost(inputs, outputs, 0);
-  mmcost_.GetBackwardMemoryCost(inputs, outputs, 0);
+  mmcost_.GetForwardComputationCost(inputs, outputs, 0);
+  mmcost_.GetBackwardComputationCost(inputs, outputs, 0);
 }
 
 class TestActivationCost : public UT::Common {
@@ -128,8 +128,8 @@ TEST_F(TestActivationCost, test_CostGeneration) {
   std::vector<size_t> inputs_length = {4, 4};
   std::vector<size_t> outputs_length = {4};
   ac_cost_.SetInputAndOutputTypeLength(inputs_length, outputs_length);
-  ac_cost_.GetForwardMemoryCost(inputs, outputs, 0);
-  ac_cost_.GetBackwardMemoryCost(inputs, outputs, 0);
+  ac_cost_.GetForwardComputationCost(inputs, outputs, 0);
+  ac_cost_.GetBackwardComputationCost(inputs, outputs, 0);
 }
 
 class TestPReLUCost : public UT::Common {
@@ -184,8 +184,8 @@ TEST_F(TestPReLUCost, test_CostGeneration) {
   prelu_cost_.SetInputAndOutputTypeLength(inputs_length, outputs_length);
   double BCC, FMC, GMC;
   BCC = prelu_cost_.GetBackwardCommCost(inputs, outputs, 0);
-  FMC = prelu_cost_.GetForwardMemoryCost(inputs, outputs, 0);
-  GMC = prelu_cost_.GetBackwardMemoryCost(inputs, outputs, 0);
+  FMC = prelu_cost_.GetForwardComputationCost(inputs, outputs, 0);
+  GMC = prelu_cost_.GetBackwardComputationCost(inputs, outputs, 0);
   ASSERT_EQ(BCC, 32 * 4);
   ASSERT_EQ(FMC, 8 * 32 * 8 * 8 * 4 + 32 * 4);
   ASSERT_EQ(GMC, 128);
diff --git a/tests/ut/cpp/parallel/ops_info/activation_test.cc b/tests/ut/cpp/parallel/ops_info/activation_test.cc
index 149aa9d5af7e95d17f74b5b03b80d3ec19c8a9e3..5d18c5372f709cadef7807466f10526060c361dc 100644
--- a/tests/ut/cpp/parallel/ops_info/activation_test.cc
+++ b/tests/ut/cpp/parallel/ops_info/activation_test.cc
@@ -84,8 +84,8 @@ TEST_F(TestActivation, test_activation_strategies) {
     act_ptr_->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = act_ptr_->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = act_ptr_->outputs_tensor_info();
-    ASSERT_DOUBLE_EQ(act_ptr_->GetOperatorCost()->GetMemoryCost(inputs_info, outputs_info, sp->GetInputStage()),
-                     cost.memory_cost_);
+    ASSERT_DOUBLE_EQ(act_ptr_->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+                     cost.computation_cost_);
     ASSERT_DOUBLE_EQ(act_ptr_->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.communication_cost_);
   }
@@ -109,8 +109,8 @@ TEST_F(TestActivation, test_softmax_strategies) {
     soft_ptr_->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = soft_ptr_->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = soft_ptr_->outputs_tensor_info();
-    ASSERT_DOUBLE_EQ(soft_ptr_->GetOperatorCost()->GetMemoryCost(inputs_info, outputs_info, sp->GetInputStage()),
-                     cost.memory_cost_);
+    ASSERT_DOUBLE_EQ(soft_ptr_->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+                     cost.computation_cost_);
     ASSERT_DOUBLE_EQ(soft_ptr_->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.communication_cost_);
   }
diff --git a/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc b/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc
index 978b792a0c3d282af51f1a5eadfb6231962fad26..99ca9f8e0ede3004d5a969e1b0ada1ad82875b92 100644
--- a/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc
+++ b/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc
@@ -569,8 +569,8 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies1) {
     matmul1->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = matmul1->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = matmul1->outputs_tensor_info();
-    ASSERT_DOUBLE_EQ(matmul1->GetOperatorCost()->GetMemoryCost(inputs_info, outputs_info, sp->GetInputStage()),
-                     cost.memory_cost_);
+    ASSERT_DOUBLE_EQ(matmul1->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+                     cost.computation_cost_);
     break;
   }
 }
@@ -599,8 +599,8 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies2) {
     TensorInfo replica_input1_info(tly, input1_shape, input1_slice_shape);
     replica_inputs_info.push_back(replica_input1_info);
 
-    ASSERT_DOUBLE_EQ(matmul3->GetOperatorCost()->GetMemoryCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
-                     cost.memory_cost_);
+    ASSERT_DOUBLE_EQ(matmul3->GetOperatorCost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
+                     cost.computation_cost_);
     break;
   }
 }
diff --git a/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc b/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc
index e7736a4b3ef9e05839652e71ccad3ce8e3e2404f..6cb9739b1cda37b22060e3aa39eafbba8a1ae03d 100644
--- a/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc
+++ b/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc
@@ -188,8 +188,8 @@ TEST_F(TestTensorAddInfo, GenerateStrategies) {
     tensor_add->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = tensor_add->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = tensor_add->outputs_tensor_info();
-    double memory_cost0 = tensor_add->GetOperatorCost()->GetMemoryCost(inputs_info, outputs_info, sp->GetInputStage());
-    double memory_cost1 = cost.memory_cost_;
+    double memory_cost0 = tensor_add->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
+    double memory_cost1 = cost.computation_cost_;
     bool memory = memory_cost0 - memory_cost1 <= 1.0;
 
     double comm_cost0 = tensor_add->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
@@ -210,8 +210,8 @@ TEST_F(TestTensorAddInfo, GenerateStrategies1) {
     tensor_add1->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = tensor_add1->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = tensor_add1->outputs_tensor_info();
-    double memory_cost0 = tensor_add1->GetOperatorCost()->GetMemoryCost(inputs_info, outputs_info, sp->GetInputStage());
-    double memory_cost1 = cost.memory_cost_;
+    double memory_cost0 = tensor_add1->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
+    double memory_cost1 = cost.computation_cost_;
     bool memory = memory_cost0 - memory_cost1 <= 1.0;
 
     double comm_cost0 = tensor_add1->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
diff --git a/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc b/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc
index ce1238baeba6f7dda2d56d2b7e10c0fcf34c4241..043746498f19cd7c525384f0cca41b1ae66a505f 100644
--- a/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc
+++ b/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc
@@ -145,8 +145,8 @@ TEST_F(TestTmpIdentityInfo, test_generate_strategies) {
     identity_ptr->Init(sp);
     std::vector<TensorInfo> inputs_info = identity_ptr->inputs_tensor_info();
    std::vector<TensorInfo> outputs_info = identity_ptr->outputs_tensor_info();
-    ASSERT_DOUBLE_EQ(identity_ptr->GetOperatorCost()->GetMemoryCost(inputs_info, outputs_info, sp->GetInputStage()),
-                     cost.memory_cost_);
+    ASSERT_DOUBLE_EQ(identity_ptr->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+                     cost.computation_cost_);
     ASSERT_DOUBLE_EQ(identity_ptr->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.communication_cost_);
   }