diff --git a/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc index ad3a3a1298f04a9b53c54f914871831af64f80c0..65e9acf71474034ed2abd9335db2c41f56168fbd 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc @@ -23,8 +23,17 @@ namespace mindspore { namespace parallel { void Simplify(CostPtrList *clist_ptrs) { - // Sort the cost_list with the computation_cost_ increasing, and communication_cost decreasing order. This method - // excludes the cost with greater computation_cost_ and greater communication_cost. + if (RUN_PHASE == TRAINING_PHASE) { + // training phase + SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs); + } else { + // inference phase + SimplifyForDecreasingCommunicationForward(clist_ptrs); + } +} +void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) { + // Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method + // excludes the cost with greater computation_cost_ and greater communication_forward. // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>} if (!COST_MODEL_SIMPLIFY_CALCULATION) { return; @@ -37,14 +46,15 @@ void Simplify(CostPtrList *clist_ptrs) { }); CostPtrList ret; for (size_t i = 0; i < clist_ptrs->size(); ++i) { - if ((ret.size() == size_t(0)) || (clist_ptrs->at(id[i])->communication_cost_ < ret.back()->communication_cost_)) { + if ((ret.size() == size_t(0)) || + (clist_ptrs->at(id[i])->communication_forward_ < ret.back()->communication_forward_)) { ret.emplace_back(std::move(clist_ptrs->at(id[i]))); } } *clist_ptrs = std::move(ret); } -void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) { +void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) { // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost. if (!COST_MODEL_SIMPLIFY_CALCULATION) { diff --git a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/costmodel.h index 2cb24dd7f36bf199f88370856ae674f4ca3e6fe7..adae0688d61e8fee84aa781a6a20eb43d8f9cbd5 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/costmodel.h @@ -51,18 +51,22 @@ struct Cost { communication_with_partial_para_ = 0.0; communication_redis_forward_ = 0.0; communication_redis_backward_ = 0.0; + communication_forward_ = 0.0; } // 'memory_with_reuse_' calculates the peak memory usage in a training phase double memory_with_reuse_; - // 'computation_cost_' models the training time of an iteration in a training phase + // 'computation_cost_' models the training time of an iteration in a training phase. 
Currently, this is calculated
+  // by ONLY forward phase
   double computation_cost_;
-  // 'communication_cost_' includes communications from operators (forward and backward) and edges
+  // 'communication_cost_' includes communications from operators (forward and backward) and edges (redistribution)
   double communication_cost_;
   // communication_without_parameter_ = communication_cost_ - (backward communication from operators)
   double communication_without_parameter_;
   // communication_with_partial_para_ =
   // communication_without_parameter_ + COST_MODEL_GAMMA * (communication_cost_ - communication_without_parameter_ )
   double communication_with_partial_para_;
+  // communication_forward_ = communication cost from operators (only forward phase) and forward redistribution.
+  double communication_forward_;
   double communication_redis_forward_;
   double communication_redis_backward_;
   std::shared_ptr<Decision> decision_ptr_;
@@ -296,7 +300,8 @@ using FinalDecisionPtr = std::shared_ptr<FinalDecision>;
 using FinalSingleDecisionPtr = std::shared_ptr<FinalSingleDecision>;

 void Simplify(CostPtrList *clist);
-void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist);
+void SimplifyForDecreasingCommunicationForward(CostPtrList *clist);
+void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist);
 void RefineForPracticalCost(const CostPtr &, bool is_redistribution);
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc
index 6973830779b640dd06e17d0efdceebbcdf2b0537..9189689f5252d00b2b228da78a92401f8e917a00 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc
@@ -76,6 +76,7 @@ Status Edge::InitEdgeCost() {
                     << ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << ".";
       // refine communication cost calculation for practice
       RefineForPracticalCost(cost, true);
+      cost->communication_forward_ = cost->communication_redis_forward_;
       CostPtrKey ck = {target_output_str, target_input_str};
       CostPtrList cl;
       cl.push_back(cost);
@@ -160,8 +161,9 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
   (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);
   CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
-  std::function<void(size_t, double, double, double, double)> recursive =
-    [&](size_t k, double computation, double memory, double communication, double communication_without_para) {
+  std::function<void(size_t, double, double, double, double, double)> recursive =
+    [&](size_t k, double computation, double memory, double communication, double communication_without_para,
+        double communication_forward) {
       if (k == edges.size()) {
         auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
         CostPtr new_cost = std::make_shared<Cost>(computation, communication);
@@ -170,6 +172,7 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
         new_cost->communication_with_partial_para_ =
           communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
         new_cost->memory_with_reuse_ = memory;
+        new_cost->communication_forward_ = communication_forward;
         new_cost->decision_ptr_ = decision;
         result.push_back(new_cost);
         return;
@@ -179,11 +182,12 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
       selected_cost_list[k] = c;
       recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
                 communication + c->communication_cost_,
-                communication_without_para +
c->communication_without_parameter_); + communication_without_para + c->communication_without_parameter_, + communication_forward + c->communication_forward_); } }; - recursive(0, 0.0, 0.0, 0.0, 0.0); - SimplifyForDreasingCommunicationWithPartialPara(&result); + recursive(0, 0.0, 0.0, 0.0, 0.0, 0.0); + Simplify(&result); return result; } @@ -219,6 +223,8 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_; double communication = left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_; + double communication_forward = + left_cost->communication_forward_ + middle_cost->communication_forward_ + right_cost->communication_forward_; double communication_without_para = left_cost->communication_without_parameter_ + middle_cost->communication_without_parameter_ + right_cost->communication_without_parameter_; @@ -232,6 +238,7 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr cost->communication_with_partial_para_ = communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); cost->memory_with_reuse_ = memory_cost; + cost->communication_forward_ = communication_forward; ret_cost_list->emplace_back(std::move(cost)); } } @@ -251,7 +258,7 @@ CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyP CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy), op_strategy->cost_list, e2->GetCostList(middle_strategy, input_st_ptr), &result); } - SimplifyForDreasingCommunicationWithPartialPara(&result); + Simplify(&result); return result; } diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc index 501a983a957f3381d07ff72f39d1017498eb626a..1255d79bdc9e20b6489aa5a560d9e39848a460cc 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc @@ -38,6 +38,8 @@ bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE; size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES; bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; +bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; +int32_t RUN_PHASE = DEFAULT_RUN_PHASE; void CostGraph::SetDeviceMemoryAndCostParameter() { MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); @@ -142,6 +144,23 @@ void CostGraph::SetDeviceMemoryAndCostParameter() { } else { MS_LOG(INFO) << "elementwise_op_strategy_follow: false."; } + + // MULTI_SUBGRAPHS + auto multi_subgraphs = CostModelContext::GetInstance()->is_multi_subgraphs(); + MULTI_SUBGRAPHS = multi_subgraphs; + if (MULTI_SUBGRAPHS) { + MS_LOG(INFO) << "multi_subgraphs: true."; + } else { + MS_LOG(INFO) << "multi_subgraphs: false."; + } + + // RUN_PHASE + auto phase = CostModelContext::GetInstance()->run_phase(); + if (phase != 0 && phase != 1) { + MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}"; + } + RUN_PHASE = phase; + MS_LOG(INFO) << "run_phase: " << RUN_PHASE << "."; } void CostGraph::RemoveOperator(const OperatorInfoPtr &op) { @@ -249,19 +268,21 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std:: MS_EXCEPTION_IF_NULL(cost3); double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_; double memory = 
cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_; - double commmunication = - cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_; + double communication = cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_; + double communication_forward = + cost1->communication_forward_ + cost2->communication_forward_ + cost3->communication_forward_; double communication_without_para = cost1->communication_without_parameter_ + cost2->communication_without_parameter_ + cost3->communication_without_parameter_; auto decision = std::make_shared(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3); - auto cost = std::make_shared(computation, commmunication, decision); + auto cost = std::make_shared(computation, communication, decision); MS_EXCEPTION_IF_NULL(cost); cost->communication_without_parameter_ = communication_without_para; cost->communication_with_partial_para_ = - communication_without_para + COST_MODEL_GAMMA * (commmunication - communication_without_para); + communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); cost->memory_with_reuse_ = memory; + cost->communication_forward_ = communication_forward; ret.push_back(cost); } } @@ -269,7 +290,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std:: } } - SimplifyForDreasingCommunicationWithPartialPara(&ret); + Simplify(&ret); return ret; } @@ -291,32 +312,67 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) { cost1->communication_without_parameter_ + COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_); new_cost->memory_with_reuse_ = cost1->memory_with_reuse_; + new_cost->communication_forward_ = cost1->communication_forward_; ret.push_back(new_cost); } } - SimplifyForDreasingCommunicationWithPartialPara(&ret); + Simplify(&ret); return ret; } -CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory) { +CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory) { + // Select the cost with minimum inference time. Currently, the inference time is modeled as = + // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_forward_ + if (cost_list.empty()) { + MS_LOG(ERROR) << "Final cost list is null."; + return nullptr; + } CostPtrList after_mem_filter; - // Filter out the valid costs + double minimum_memory = DBL_MAX; + // Filter out the valid costs. for (auto &a_cost : cost_list) { if (a_cost->memory_with_reuse_ <= memory) { after_mem_filter.emplace_back(std::move(a_cost)); + } else if (a_cost->memory_with_reuse_ < minimum_memory) { + minimum_memory = a_cost->memory_with_reuse_; } } + if (after_mem_filter.empty()) { + MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory + << ", the memory capacity is: " << memory << "."; + return nullptr; + } + // Init the returned value with first cost. 
+  CostPtr ret = after_mem_filter[0];
-  std::function<CostPtr(CostPtr, const CostPtr &)> LocalCompare = [&](CostPtr init, const CostPtr &cost_x) {
-    MS_EXCEPTION_IF_NULL(cost_x);
-    if (init == nullptr || cost_x->computation_cost_ < memory) {
-      init = cost_x;
+  double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_forward_;
+  MS_LOG(INFO) << "Cost 0: "
+               << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
+               << ", communication_forward_: " << ret->communication_forward_
+               << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
+               << ", communication_cost_: " << ret->communication_cost_
+               << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
+  MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
+  for (size_t i = 1; i < after_mem_filter.size(); ++i) {
+    MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
+    MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
+                 << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
+                 << ", communication_forward_: " << after_mem_filter[i]->communication_forward_
+                 << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
+                 << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
+                 << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
+                 << ".";
+    auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
+               costmodel_beta_ * after_mem_filter[i]->communication_forward_;
+    MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
+    if (minimum > tmp) {
+      minimum = tmp;
+      ret = after_mem_filter[i];
+      MS_LOG(INFO) << "Selected: " << i;
     }
-    return init;
-  };
-  CostPtr ret = nullptr;
-  return std::accumulate(after_mem_filter.begin(), after_mem_filter.end(), ret, LocalCompare);
+  }
+  return ret;
 }

 CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) {
@@ -524,12 +580,26 @@ Status CostGraph::SearchStrategy() {
   });

   if (alive_ops.size() > 2) {
-    return SearchStrategyForMultiNodeFinalGraph(alive_ops);
+    if (RUN_PHASE == TRAINING_PHASE) {
+      // training phase
+      return SearchStrategyForMultiNodeFinalGraph(alive_ops);
+    } else {
+      // inference phase
+      MS_LOG(EXCEPTION)
+        << "Currently, searching strategy for the multi-node final graph in inference phase is not supported.";
+    }
   } else if (alive_ops.size() == 1) {
     MS_LOG(INFO) << "There are 1 single node in the final graph.";
     OperatorInfoPtr u = alive_ops[0];
     auto cost_list = CreateFinalSingleCostList(u);
-    auto cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
+    CostPtr cost = nullptr;
+    if (RUN_PHASE == TRAINING_PHASE) {
+      // training phase
+      cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
+    } else {
+      // inference phase
+      cost = SelectCostWithMinInferenceTime(cost_list, dev_memory_);
+    }
     if (cost == nullptr) {
       MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
       return FAILED;
@@ -575,7 +645,15 @@ Status CostGraph::SearchStrategy() {
     auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
     all_list.push_back(cost_list);
   }
-  auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
+  CostPtrList selected_cost_list;
+  if (RUN_PHASE == TRAINING_PHASE) {
+    // training phase
+    selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
+  } else {
// inference phase + MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference " + "phase is not supported."; + } for (size_t k = 0; k < selected_cost_list.size(); ++k) { auto selected_cost = selected_cost_list[k]; if (selected_cost == nullptr) { @@ -601,7 +679,14 @@ Status CostGraph::SearchStrategy() { auto e = u->GetAliveSuccEdges()[0]; MS_EXCEPTION_IF_NULL(e); auto cost_list = CreateFinalCostList(u, e, v); - auto cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_); + CostPtr cost = nullptr; + if (RUN_PHASE == TRAINING_PHASE) { + // training phase + cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_); + } else { + MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference " + "phase is not supported."; + } if (cost == nullptr) { MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; return FAILED; @@ -841,6 +926,8 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_; double communication = op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_; + double communication_forward = + op_cost->communication_forward_ + edge_cost->communication_forward_ + tar_cost->communication_forward_; double communication_without_para = op_cost->communication_without_parameter_ + edge_cost->communication_without_parameter_ + tar_cost->communication_without_parameter_; @@ -853,6 +940,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const new_cost->communication_with_partial_para_ = communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); new_cost->memory_with_reuse_ = memory; + new_cost->communication_forward_ = communication_forward; MS_EXCEPTION_IF_NULL(tar_cost_list_new); tar_cost_list_new->emplace_back(std::move(new_cost)); } @@ -885,7 +973,7 @@ OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) { CreateMergeEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new); } - SimplifyForDreasingCommunicationWithPartialPara(&tar_clist_new); + Simplify(&tar_clist_new); // Set the new costlist w.r.t the strategy tar_stra_cost->cost_list = tar_clist_new; if ((!valid) && (!tar_clist_new.empty())) { @@ -922,6 +1010,8 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_; double communication = contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_; + double communication_forward = contract_op_cost->communication_forward_ + edge_cost->communication_forward_ + + tar_cost->communication_forward_; double communication_without_para = contract_op_cost->communication_without_parameter_ + edge_cost->communication_without_parameter_ + tar_cost->communication_without_parameter_; @@ -933,6 +1023,7 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str new_cost->communication_with_partial_para_ = communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); new_cost->memory_with_reuse_ = memory; + new_cost->communication_forward_ = communication_forward; tar_cost_list_new->emplace_back(std::move(new_cost)); } } @@ -962,7 +1053,7 
@@ OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) { CreateContractEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new); } - SimplifyForDreasingCommunicationWithPartialPara(&tar_clist_new); + Simplify(&tar_clist_new); // Set the new costlist w.r.t the strategy tar_stra_cost->cost_list = tar_clist_new; if ((!valid) && (!tar_clist_new.empty())) { @@ -998,6 +1089,8 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_; double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ + left_node_cost->communication_cost_ + right_edge_cost->communication_cost_; + double new_commu_forward = elimi_op_cost->communication_forward_ + left_edge_cost->communication_forward_ + + left_node_cost->communication_forward_ + right_edge_cost->communication_forward_; double new_commu_without = elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ + left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_; @@ -1009,6 +1102,7 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, new_cost->communication_with_partial_para_ = new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without); new_cost->memory_with_reuse_ = new_memory; + new_cost->communication_forward_ = new_commu_forward; left_node_clist_new->emplace_back(std::move(new_cost)); } } @@ -1079,7 +1173,7 @@ OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op, &left_node_clist_new); } } - SimplifyForDreasingCommunicationWithPartialPara(&left_node_clist_new); + Simplify(&left_node_clist_new); // Set the new costlist w.r.t the strategy left_node_stra_cost->cost_list = left_node_clist_new; if ((!valid) && (!left_node_clist_new.empty())) { @@ -1112,19 +1206,22 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n double computation_cost = merged_node_cost->computation_cost_, memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_, - commu_without = merged_node_cost->communication_without_parameter_; + commu_without = merged_node_cost->communication_without_parameter_, + commu_forward = merged_node_cost->communication_forward_; for (size_t i = 0; i < succ_nodes_stras.size(); ++i) { MS_EXCEPTION_IF_NULL(succ_edges_costs[i]); if (i == 0) { computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_; memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_; commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_; + commu_forward += succ_edges_costs[i]->communication_forward_ + succ_nodes_costs[i]->communication_forward_; commu_without += succ_edges_costs[i]->communication_without_parameter_ + succ_nodes_costs[i]->communication_without_parameter_; } else { computation_cost += succ_edges_costs[i]->computation_cost_; memory_cost += succ_edges_costs[i]->memory_with_reuse_; commu_cost += succ_edges_costs[i]->communication_cost_; + commu_forward += succ_edges_costs[i]->communication_forward_; commu_without += succ_edges_costs[i]->communication_without_parameter_; } } @@ -1135,6 +1232,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n new_cost->communication_without_parameter_ = 
commu_without; new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without); new_cost->memory_with_reuse_ = memory_cost; + new_cost->communication_forward_ = commu_forward; first_succ_node_clist_new->emplace_back(std::move(new_cost)); } } @@ -1220,7 +1318,7 @@ std::vector> CostGraph::EliminationStar(const OperatorInfo CreateStarEliminationCostList(succ_edges, first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist, merged_op_stra, merged_op_clist, &first_succ_node_clist_new); } - SimplifyForDreasingCommunicationWithPartialPara(&first_succ_node_clist_new); + Simplify(&first_succ_node_clist_new); // Set the new costlist w.r.t the strategy first_succ_node_stra_cost->cost_list = first_succ_node_clist_new; if ((!valid) && (!first_succ_node_clist_new.empty())) { diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h index 31de9f4456f0d39d37351c0fa6a8c07d160ae03f..5077459695c50dea4bde713ae0c45317142e1e08 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h @@ -45,6 +45,9 @@ namespace parallel { #define DEFAULT_FULLY_USE_DEVICES true #define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false #define DEFAULT_IS_MULTI_SUBGRAPHS false +#define DEFAULT_RUN_PHASE 0 +#define TRAINING_PHASE 0 +#define INFERENCE_PHASE 1 class CostGraph; using CostGraphPtr = std::shared_ptr; @@ -60,6 +63,8 @@ extern bool TENSOR_SLICE_ALIGNMENT_ENABLE; extern size_t TENSOR_SLICE_ALIGNMENT_SIZE; extern bool FULLY_USE_DEVICES; extern bool ELEMENTWISE_OP_STRA_FOLLOW; +extern bool MULTI_SUBGRAPHS; +extern int32_t RUN_PHASE; class CostGraph { // 'CostGraph' consists of Operators and edges between them. 
An edge is created between two Operators if they have @@ -98,7 +103,7 @@ class CostGraph { CostPtrList CreateFinalCostList(const OperatorInfoPtr &u, const EdgePtr &e, const OperatorInfoPtr &v); CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr &u); - CostPtr SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory); + CostPtr SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory); CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory); CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector &all_costlist, double memory); Status SearchStrategyForMultiNodeFinalGraph(const std::vector &); diff --git a/mindspore/ccsrc/parallel/costmodel_context.cc b/mindspore/ccsrc/parallel/costmodel_context.cc index 591fa737aa1cd9481ff111ec49abc51305a40e09..92aff295575a727ecdd2c26403ce79721a4b0386 100644 --- a/mindspore/ccsrc/parallel/costmodel_context.cc +++ b/mindspore/ccsrc/parallel/costmodel_context.cc @@ -47,6 +47,7 @@ void CostModelContext::ResetCostModel() { costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST; costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS; is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS; + run_phase_ = DEFAULT_RUN_PHASE; costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM; costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES; costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT; @@ -125,5 +126,7 @@ void CostModelContext::set_fully_use_device(bool fully_use) { fully_use_device_ void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) { elementwise_stra_follow_ = elementwise_follow; } + +void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; } } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/costmodel_context.h b/mindspore/ccsrc/parallel/costmodel_context.h index ebb0d00008ad382490e60d42c29246ed5d694803..bddab683ff860e5c99327a67f469623a3310ba56 100644 --- a/mindspore/ccsrc/parallel/costmodel_context.h +++ b/mindspore/ccsrc/parallel/costmodel_context.h @@ -113,6 +113,9 @@ class CostModelContext { void set_elementwise_stra_follow(bool); bool elementwise_stra_follow() const { return elementwise_stra_follow_; } + void set_run_phase(int32_t); + int32_t run_phase() const { return run_phase_; } + private: CostModelContext(); static std::shared_ptr cm_context_inst_; @@ -141,8 +144,11 @@ class CostModelContext { // COST_MODEL_COMMUNI_BIAS double costmodel_communi_bias_; + // MULTI_SUBGRAPHS bool is_multi_subgraphs_; + int32_t run_phase_; // 0: 'training', 1: 'inference' + int32_t costmodel_allreduce_fusion_algorithm_; int32_t costmodel_allreduce_fusion_times_; diff --git a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc b/mindspore/ccsrc/parallel/ops_info/matmul_info.cc index 3f55efb66c7327e530deecb9750b90d6a1ae6e72..7752148b7d1e371357cb737b634368754d9116d2 100644 --- a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/matmul_info.cc @@ -610,6 +610,7 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr & << ", communication_with_partial_para_: " << result->communication_with_partial_para_; // refine communication cost calculation for practice RefineForPracticalCost(result, false); + result->communication_forward_ = result->communication_without_parameter_; std::shared_ptr swc = std::make_shared(strategy, inputs_tensor_info_, 
outputs_tensor_info_); diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/parallel/ops_info/operator_info.cc index 8074f2a32ef3d2c3d23aa670013861080ea20a84..b1db6cda4db44219948fc020522c8d1520b3a775 100644 --- a/mindspore/ccsrc/parallel/ops_info/operator_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/operator_info.cc @@ -1049,6 +1049,7 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) { BreakingTiesForPerferringDataParallel(strategy, result); // refine communication cost calculation for practice RefineForPracticalCost(result, false); + result->communication_forward_ = result->communication_without_parameter_; std::shared_ptr swc = std::make_shared(strategy, inputs_tensor_info_, outputs_tensor_info_); diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h index e7800909c5f45e5e4455e789188390e213d594b2..d1f46108bbc2b581317192f8ddf7aa4874e54cc0 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h @@ -69,16 +69,16 @@ class TensorRedistribution { RankList dev_list_; OperatorList operator_list_; bool reshape_flag_; - // communication cost + // communication cost, which is the sum of forward communication cost and backward communication cost double comm_cost_; // forward communication cost double forward_comm_cost_; // backward communication cost double backward_comm_cost_; // computation_cost models the time spending on computing in this tensor redistribution, which is calculated by the - // inputs. + // inputs. This is calculated ONLY for forward phase. double computation_cost_; - // memory_cost models the PEAK memory cost in a traning iteration contributed by this tensor redistribution, which is + // memory_cost models the PEAK memory cost in a training iteration contributed by this tensor redistribution, which is // calculated by the outputs. 
double memory_cost_; bool construct_op_flag_; diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc index f1feedb64ff3315711bf1a874e179f981e80701f..dc59d117c5f26717e81510b653f46a1562791608 100644 --- a/mindspore/ccsrc/pipeline/init.cc +++ b/mindspore/ccsrc/pipeline/init.cc @@ -228,6 +228,8 @@ PYBIND11_MODULE(_c_expression, m) { "Get the parameter cost_model_communi_bias of the DP algorithm.") .def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.") .def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.") + .def("set_run_phase", &CostModelContext::set_run_phase, "Set the flag run_phase.") + .def("get_run_phase", &CostModelContext::run_phase, "Get the flag run_phase.") .def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm, "Set the parameter gradient AllReduce fusion algorithm.") .def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm, diff --git a/mindspore/parallel/_cost_model_context.py b/mindspore/parallel/_cost_model_context.py index 2790aed855d8737171796ae03351cf006b3e9095..dda68e4f2d99442e1feb9e2101a6054e297f18c2 100644 --- a/mindspore/parallel/_cost_model_context.py +++ b/mindspore/parallel/_cost_model_context.py @@ -239,6 +239,33 @@ class _CostModelContext: raise ValueError("Context handle is none in context!!!") return self._context_handle.get_multi_subgraphs() + def set_run_phase(self, phase): + """ + Set the flag of running phase: training (0) or inference (1) + + Args: + phase (int): A parameter indicating which phase is running. + + Raises: + ValueError: If context handle is none, or phase is not in {0, 1}. + """ + if self._context_handle is None: + raise ValueError("Context handle is none in context!!!") + if phase not in (0, 1): + raise ValueError("The argument of set_run_phase() must be '0' or '1', but got {}".format(phase)) + self._context_handle.set_run_phase(phase) + + def get_run_phase(self): + """ + Get the flag of running phase. + + Raises: + ValueError: If context handle is none. + """ + if self._context_handle is None: + raise ValueError("Context handle is none in context!!!") + return self._context_handle.get_run_phase() + def set_costmodel_allreduce_fusion_algorithm(self, algorithm): """ Set costmodel allreduce fusion algorithm. 
@@ -453,6 +480,7 @@ set_cost_model_context_func_map = { "costmodel_communi_const": cost_model_context().set_costmodel_communi_const, "costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias, "multi_subgraphs": cost_model_context().set_multi_subgraphs, + "run_phase": cost_model_context().set_run_phase, "costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm, "costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times, "costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent, @@ -473,7 +501,8 @@ get_cost_model_context_func_map = { "costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold, "costmodel_communi_const": cost_model_context().get_costmodel_communi_const, "costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias, - "multi_subgraphs": cost_model_context().get_multi_subgraphs(), + "multi_subgraphs": cost_model_context().get_multi_subgraphs, + "run_phase": cost_model_context().get_run_phase, "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm, "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times, "costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent, @@ -488,7 +517,7 @@ get_cost_model_context_func_map = { @args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float, costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float, - multi_subgraphs=bool, + multi_subgraphs=bool, run_phase=int, costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int, costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float, costmodel_allreduce_fusion_allreduce_inherent_time=float, @@ -510,6 +539,7 @@ def set_cost_model_context(**kwargs): costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice. costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice. multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs. + run_phase (int): A parameter indicating which phase is running: training (0) or inference (1). Default: 0. costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm. 
0: bypass allreduce fusion; 1: only use backward computation time to group allreduce; diff --git a/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc b/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc index 81b017a28dc7d169250c39155d469d3715c1f211..78d05c7235e8bc5273d841a4dc719326d48fb40d 100644 --- a/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc +++ b/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc @@ -371,7 +371,7 @@ TEST_F(TestCostGraph, test_CreateFinalCostList_AND_Select) { ASSERT_EQ(edge_m1_m2->InitEdgeCost(), SUCCESS); cost_graph.AddEdge(matmul1, matmul2, edge_m1_m2); auto cost_list = cost_graph.CreateFinalCostList(matmul1, edge_m1_m2, matmul2); - cost_graph.SelectCostWithMemoryConstraint(cost_list, cost_graph.GetDeviceMemory()); + cost_graph.SelectCostWithMinInferenceTime(cost_list, cost_graph.GetDeviceMemory()); } TEST_F(TestCostGraph, test_EliminationOp) { diff --git a/tests/ut/python/parallel/__init__.py b/tests/ut/python/parallel/__init__.py index b26962bc3af9f7f0df438d3d65220f7366e828b1..da0f047f70fbf8fc33bd8054c09c4036d3f24acb 100644 --- a/tests/ut/python/parallel/__init__.py +++ b/tests/ut/python/parallel/__init__.py @@ -14,15 +14,21 @@ import mindspore.context as context from mindspore.parallel._auto_parallel_context import auto_parallel_context +from mindspore.parallel._cost_model_context import reset_cost_model_context +from mindspore.parallel.algo_parameter_config import reset_algo_parameters from mindspore.parallel._utils import _reset_op_id def setup_module(module): auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) + reset_cost_model_context() + reset_algo_parameters() _reset_op_id() def teardown_module(): context.reset_auto_parallel_context() + reset_cost_model_context() + reset_algo_parameters() _reset_op_id() diff --git a/tests/ut/python/parallel/test_auto_parallel_inference.py b/tests/ut/python/parallel/test_auto_parallel_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..3aad5f62c6c80a4712efce623e5066de95b1075d --- /dev/null +++ b/tests/ut/python/parallel/test_auto_parallel_inference.py @@ -0,0 +1,36 @@ +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor, context +from mindspore.ops import operations as P +from mindspore.nn import WithLossCell, TrainOneStepCell +from mindspore.nn import Momentum +from mindspore.parallel._cost_model_context import set_cost_model_context + +class Net(nn.Cell): + def __init__(self, input_ch, out_ch): + super(Net, self).__init__() + self.dense = nn.Dense(input_ch, out_ch) + self.relu = P.ReLU() + + def construct(self, x): + x = self.dense(x) + x = self.relu(x) + return x + +def test_inference_phase(): + context.set_auto_parallel_context(device_num=8, global_rank=0) + context.set_auto_parallel_context(parallel_mode="auto_parallel") + set_cost_model_context(run_phase=1) + + net = Net(512, 128) + predict = Tensor(np.ones([64, 512]).astype(np.float32) * 0.001) + label = Tensor(np.ones([64, 128]).astype(np.float32)) + + loss = nn.SoftmaxCrossEntropyWithLogits() + optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) + net_with_loss = WithLossCell(net, loss) + train_network = TrainOneStepCell(net_with_loss, optimizer) + train_network.set_train() + + output = train_network(predict, label) \ No newline at end of file
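
The new flag is plumbed end to end: mindspore.parallel._cost_model_context.set_cost_model_context(run_phase=...) forwards to CostModelContext::set_run_phase through the pybind11 binding added in init.cc, and CostGraph::SetDeviceMemoryAndCostParameter copies it into RUN_PHASE before strategy search. A minimal sketch of the intended call pattern, not part of the patch (device_num and global_rank mirror the values used in the test above):

    from mindspore import context
    from mindspore.parallel._cost_model_context import set_cost_model_context

    # 0: training phase (default), 1: inference phase
    set_cost_model_context(run_phase=1)
    # run_phase only influences the auto-parallel strategy search, so it is set
    # together with the auto_parallel mode before the network is compiled.
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")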
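
The inference-phase pruning and selection introduced in costmodel.cc and graph_costmodel.cc work as follows: SimplifyForDecreasingCommunicationForward sorts candidate costs by computation_cost_ and keeps only those whose communication_forward_ strictly decreases, and SelectCostWithMinInferenceTime filters by memory capacity and then minimizes costmodel_alpha_ * computation_cost_ + costmodel_beta_ * communication_forward_. The sketch below is an illustrative Python model of that logic, not the implementation; the Cost tuple and function names are made up for the example:

    from collections import namedtuple

    Cost = namedtuple("Cost", ["computation", "communication_forward", "memory"])

    def simplify_forward(costs):
        # Sort by computation cost; keep a cost only if its forward communication is
        # strictly lower than the last kept one (mirrors SimplifyForDecreasingCommunicationForward).
        kept = []
        for c in sorted(costs, key=lambda c: c.computation):
            if not kept or c.communication_forward < kept[-1].communication_forward:
                kept.append(c)
        return kept

    def select_min_inference_time(costs, memory_capacity, alpha, beta):
        # Mirrors SelectCostWithMinInferenceTime: drop costs that exceed the memory
        # capacity, then pick the minimum of alpha * computation + beta * communication_forward.
        feasible = [c for c in costs if c.memory <= memory_capacity]
        if not feasible:
            return None  # the C++ code logs an error and returns nullptr
        return min(feasible, key=lambda c: alpha * c.computation + beta * c.communication_forward)

    # The <computation, communication> example from the costmodel.cc comment, with a dummy memory field:
    candidates = [Cost(100, 20, 0), Cost(200, 10, 0), Cost(300, 50, 0)]
    print(simplify_forward(candidates))                    # <300, 50> is excluded
    print(select_min_inference_time(candidates, 1e9, 1.0, 1.0))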
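
Relation between the new communication_forward_ field and the existing Cost fields: operators seed it with communication_without_parameter_ after RefineForPracticalCost (operator_info.cc, matmul_info.cc), edges seed it with communication_redis_forward_ (edge_costmodel.cc), and each elimination in graph_costmodel.cc simply sums the field over the merged costs. A small numeric illustration, with made-up communication volumes and a made-up gamma:

    # Hypothetical communication volumes for one operator plus one redistribution edge.
    COST_MODEL_GAMMA = 0.1
    forward_operator_comm = 3.0        # forward communication produced by the operator
    backward_operator_comm = 5.0       # backward (parameter-related) communication
    forward_redistribution_comm = 2.0  # communication_redis_forward_ of the edge

    communication_cost = forward_operator_comm + backward_operator_comm + forward_redistribution_comm
    communication_without_parameter = communication_cost - backward_operator_comm
    communication_with_partial_para = communication_without_parameter + COST_MODEL_GAMMA * (
        communication_cost - communication_without_parameter)
    # New field: only forward traffic is charged when searching an inference-phase strategy.
    communication_forward = forward_operator_comm + forward_redistribution_comm

    print(communication_cost, communication_without_parameter, communication_with_partial_para, communication_forward)
    # 10.0 5.0 5.5 5.0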