diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h
index bcd5f78532f93eaf85b5556266fb4b2eb51b9926..a60cbc04287806740403e654bdead8997a66cdfa 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h
@@ -79,6 +79,8 @@ class StrategyWithCost {
  public:
   StrategyWithCost(StrategyPtr strategy, std::vector<TensorInfo> inputs_, std::vector<TensorInfo> outputs_)
       : strategy_ptr(std::move(strategy)), inputs_ptr(std::move(inputs_)), outputs_ptr(std::move(outputs_)) {}
+  StrategyWithCost(StrategyPtr strategy, CostPtrList c_list)
+      : strategy_ptr(std::move(strategy)), cost_list(std::move(c_list)) {}
 
   StrategyWithCost(const StrategyWithCost &swc) = delete;
   StrategyWithCost(StrategyWithCost &&swc)
@@ -99,6 +101,7 @@ enum DecisionType {
   EDGE_ELIMINATION,
   MERGE_ELIMINATION,
   CONTRACT_ELIMINATION,
+  SOURCE_ELIMINATION,
   TRIANGLE_ELIMINATION,
   STAR_ELIMINATION,
   FINAL_TYPE,
@@ -199,6 +202,38 @@ struct ContractEliminationDecision : public Decision {
   MS_DECLARE_PARENT(ContractEliminationDecision, Decision);
 };
 
+/* 'SourceEliminationDecision' is for the Source Elimination in the DP algorithm:
+ *
+ *        1               1,5
+ *      /   \            //  \\
+ *     /     \          //    \\
+ *    /       \        //      \\
+ *   /         \      //        \\
+ *  2  <- 5 ->  3 ==> 2          3
+ *   \         /       \        /
+ *    \       /          \     /
+ *     \     /            \   /
+ *        4                 4
+ *
+ * In the original graph, '1' has two alive outgoing edges and no incoming edges, '5' has two alive outgoing edges and
+ * no incoming edges, and '4' has two alive incoming edges and no outgoing edges. Source Elimination merges '5' into
+ * '1', and new edges are generated to replace the old ones incident to '1' and '5'.
+ */
+struct SourceEliminationDecision : public Decision {
+  SourceEliminationDecision(StrategyPtr op1_stra, CostPtr op1_c, StrategyPtr op2_stra, CostPtr op2_c)
+      : op1_strategy_(std::move(op1_stra)),
+        op1_cost_(std::move(op1_c)),
+        op2_strategy_(std::move(op2_stra)),
+        op2_cost_(std::move(op2_c)) {
+    type_ = DecisionType::SOURCE_ELIMINATION;
+  }
+  StrategyPtr op1_strategy_;
+  CostPtr op1_cost_;
+  StrategyPtr op2_strategy_;
+  CostPtr op2_cost_;
+  MS_DECLARE_PARENT(SourceEliminationDecision, Decision);
+};
+
 /* 'TriangleEliminationDecision' is for the Triangle Elimination in DP algorithm:
  *
  *       u
@@ -296,6 +331,7 @@ using OpEliminationDecisionPtr = std::shared_ptr<OpEliminationDecision>;
 using EdgeEliminationDecisionPtr = std::shared_ptr<EdgeEliminationDecision>;
 using MergeEliminationDecisionPtr = std::shared_ptr<MergeEliminationDecision>;
 using ContractEliminationDecisionPtr = std::shared_ptr<ContractEliminationDecision>;
+using SourceEliminationDecisionPtr = std::shared_ptr<SourceEliminationDecision>;
 using TriangleEliminationDecisionPtr = std::shared_ptr<TriangleEliminationDecision>;
 using StarEliminationDecisionPtr = std::shared_ptr<StarEliminationDecision>;
 using FinalDecisionPtr = std::shared_ptr<FinalDecision>;
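For orientation, here is a minimal sketch (not part of the patch) of how a SourceEliminationDecision would be recorded once a (strategy, cost) pair has been chosen for each of the two merged sources. 'pick_op1' and 'pick_op2' are hypothetical OperatorInfoPtr values; the [0] indexing is only an illustrative choice of candidate:

  // Sketch only: combining one (strategy, cost) pair from each source operator.
  StrategyPtr s1 = pick_op1->GetStrategyCost()[0]->strategy_ptr;  // hypothetical candidate of op1
  CostPtr c1 = pick_op1->GetStrategyCost()[0]->cost_list[0];
  StrategyPtr s2 = pick_op2->GetStrategyCost()[0]->strategy_ptr;  // hypothetical candidate of op2
  CostPtr c2 = pick_op2->GetStrategyCost()[0]->cost_list[0];
  // The decision remembers which strategy/cost was chosen for each merged
  // source, so the DP replay phase can later recover both operators' strategies.
  auto decision = std::make_shared<SourceEliminationDecision>(s1, c1, s2, c2);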
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc
index 9408596111df2af8692e127fd4056e8ff08b652b..b49ca05f3e1668bd672cb0ef03c6b53ee5796fef 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc
@@ -42,66 +42,76 @@ Status GetStrategy(const CostGraphPtr &graph) {
       auto elimi = std::make_shared<OpElimination>(n_edge, l_edge, node, r_edge);
       eliminations.emplace_back(std::move(elimi));
     }
-    auto edges = graph->CheckEdgeElimination();
-    if ((!flag) && (!edges.empty())) {
-      // Applying the Edge Elimination
-      flag = true;
-      auto n_edge = graph->EliminationEdges(edges);
-      auto elimi = std::make_shared<EdgeElimination>(n_edge, edges);
-      eliminations.emplace_back(std::move(elimi));
+    if (!flag) {
+      auto edges = graph->CheckEdgeElimination();
+      if (!edges.empty()) {
+        // Applying the Edge Elimination
+        flag = true;
+        auto n_edge = graph->EliminationEdges(edges);
+        auto elimi = std::make_shared<EdgeElimination>(n_edge, edges);
+        eliminations.emplace_back(std::move(elimi));
+      }
     }
-    auto merge_node = graph->CheckMergeElimination();
-    if ((!flag) && (merge_node != nullptr)) {
-      // Applying the Merge Elimination
-      flag = true;
-      auto succ_edge = merge_node->GetAliveSuccEdges()[0];
-      auto target_node = graph->EliminationMerge(merge_node);
-      auto elimi = std::make_shared<MergeElimination>(merge_node, succ_edge, target_node);
-      eliminations.emplace_back(std::move(elimi));
+    if (!flag) {
+      auto merge_node = graph->CheckMergeElimination();
+      if (merge_node != nullptr) {
+        // Applying the Merge Elimination
+        flag = true;
+        auto succ_edge = merge_node->GetAliveSuccEdges()[0];
+        auto target_node = graph->EliminationMerge(merge_node);
+        auto elimi = std::make_shared<MergeElimination>(merge_node, succ_edge, target_node);
+        eliminations.emplace_back(std::move(elimi));
+      }
     }
-    auto contracted_node = graph->CheckContractElimination();
-    if ((!flag) && (contracted_node != nullptr)) {
-      // Applying the Contract Elimination
-      flag = true;
-      auto prev_edge = contracted_node->GetAlivePrevEdges()[0];
-      auto target_node = graph->EliminationContract(contracted_node);
-      auto elimi = std::make_shared<ContractElimination>(target_node, prev_edge, contracted_node);
-      eliminations.emplace_back(std::move(elimi));
+    if (!flag) {
+      auto contracted_node = graph->CheckContractElimination();
+      if (contracted_node != nullptr) {
+        // Applying the Contract Elimination
+        flag = true;
+        auto prev_edge = contracted_node->GetAlivePrevEdges()[0];
+        auto target_node = graph->EliminationContract(contracted_node);
+        auto elimi = std::make_shared<ContractElimination>(target_node, prev_edge, contracted_node);
+        eliminations.emplace_back(std::move(elimi));
+      }
     }
-    auto triangle_pair = graph->CheckTriangleElimination();
-    if ((!flag) && (triangle_pair.first != nullptr)) {
-      // Applying the Triangle Elimination
-      flag = true;
-      auto eliminated_node = triangle_pair.first;
-      auto l_r_edge = triangle_pair.second;
+    if (!flag) {
+      auto triangle_pair = graph->CheckTriangleElimination();
+      if (triangle_pair.first != nullptr) {
+        // Applying the Triangle Elimination
+        flag = true;
+        auto eliminated_node = triangle_pair.first;
+        auto l_r_edge = triangle_pair.second;
 
-      auto left_node = l_r_edge->prev_operator();
-      auto left_edge = eliminated_node->GetAliveSuccEdges()[0];
-      auto right_edge = eliminated_node->GetAliveSuccEdges()[1];
-      MS_EXCEPTION_IF_NULL(left_edge);
-      if (left_edge->next_operator() != left_node) {
-        auto tmp = left_edge;
-        left_edge = right_edge;
-        right_edge = tmp;
+        auto left_node = l_r_edge->prev_operator();
+        auto left_edge = eliminated_node->GetAliveSuccEdges()[0];
+        auto right_edge = eliminated_node->GetAliveSuccEdges()[1];
+        MS_EXCEPTION_IF_NULL(left_edge);
+        if (left_edge->next_operator() != left_node) {
+          auto tmp = left_edge;
+          left_edge = right_edge;
+          right_edge = tmp;
+        }
+        auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge);
+        auto right_node = l_r_edge->next_operator();
+        auto elimi =
+          std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node);
+        eliminations.emplace_back(std::move(elimi));
       }
-      auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge);
-      auto right_node = l_r_edge->next_operator();
-      auto elimi =
-        std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node);
-      eliminations.emplace_back(std::move(elimi));
     }
-    auto star_center = graph->CheckStarElimination();
-    if ((!flag) && (star_center != nullptr)) {
-      // Applying the Star Elimination
-      flag = true;
-      auto succ_edges = graph->EliminationStar(star_center);
-      std::vector<OperatorInfoPtr> succ_nodes;
-      for (size_t i = 0; i < succ_edges.size(); ++i) {
-        MS_EXCEPTION_IF_NULL(succ_edges[i]);
-        succ_nodes.push_back(succ_edges[i]->next_operator());
+    if (!flag) {
+      auto star_center = graph->CheckStarElimination();
+      if (star_center != nullptr) {
+        // Applying the Star Elimination
+        flag = true;
+        auto succ_edges = graph->EliminationStar(star_center);
+        std::vector<OperatorInfoPtr> succ_nodes;
+        for (size_t i = 0; i < succ_edges.size(); ++i) {
+          MS_EXCEPTION_IF_NULL(succ_edges[i]);
+          succ_nodes.push_back(succ_edges[i]->next_operator());
+        }
+        auto elimi = std::make_shared<StarElimination>(star_center, succ_edges, succ_nodes);
+        eliminations.emplace_back(std::move(elimi));
       }
-      auto elimi = std::make_shared<StarElimination>(star_center, succ_edges, succ_nodes);
-      eliminations.emplace_back(std::move(elimi));
     }
   }
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.h
index 4281bbd06a476ad7a7229a7de72ca0a634743632..ec131e519f6746b48d5da22657c7e1055c707c10 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.h
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.h
@@ -42,7 +42,7 @@ namespace parallel {
 // the operators' strategies can be all determined.
 
 struct Elimination : public Base {
-  enum EliminationType { OPERA, EDGE, MERGE, CONTRACT, TRIANGLE, STAR };
+  enum EliminationType { OPERA, EDGE, MERGE, CONTRACT, SOURCE, TRIANGLE, STAR };
   Elimination(EdgePtr n_edge, EliminationType ty) : new_edge_(std::move(n_edge)), type_(ty) {}
 
   EdgePtr new_edge_;
@@ -100,6 +100,26 @@ struct ContractElimination : public Elimination {
   MS_DECLARE_PARENT(ContractElimination, Elimination);
 };
 
+// Source Elimination
+struct SourceElimination : public Elimination {
+  SourceElimination(OperatorInfoPtr p_source, std::vector<EdgePtr> p_succ_edges, std::vector<EdgePtr> p_new_succ_edges,
+                    OperatorInfoPtr s_source, std::vector<EdgePtr> s_succ_edges, std::vector<EdgePtr> s_new_succ_edges)
+      : Elimination(nullptr, Elimination::EliminationType::SOURCE),
+        primary_source_(std::move(p_source)),
+        primary_succ_edges_(std::move(p_succ_edges)),
+        primary_new_succ_edges_(std::move(p_new_succ_edges)),
+        secondary_source_(std::move(s_source)),
+        secondary_succ_edges_(std::move(s_succ_edges)),
+        secondary_new_succ_edges_(std::move(s_new_succ_edges)) {}
+  OperatorInfoPtr primary_source_;
+  std::vector<EdgePtr> primary_succ_edges_;
+  std::vector<EdgePtr> primary_new_succ_edges_;
+  OperatorInfoPtr secondary_source_;
+  std::vector<EdgePtr> secondary_succ_edges_;
+  std::vector<EdgePtr> secondary_new_succ_edges_;
+  MS_DECLARE_PARENT(SourceElimination, Elimination);
+};
+
 // Triangle Elimination
 struct TriangleElimination : public Elimination {
   TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge,
@@ -138,6 +158,7 @@ using OpEliminationPtr = std::shared_ptr<OpElimination>;
 using EdgeEliminationPtr = std::shared_ptr<EdgeElimination>;
 using MergeEliminationPtr = std::shared_ptr<MergeElimination>;
 using ContractEliminationPtr = std::shared_ptr<ContractElimination>;
+using SourceEliminationPtr = std::shared_ptr<SourceElimination>;
 using TriangleEliminationPtr = std::shared_ptr<TriangleElimination>;
 using StarEliminationPtr = std::shared_ptr<StarElimination>;
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc
index e3f1de72077ddb47191baf97295994f71b5ab48f..59be4918527da74f77c7e044ab9fb05af3c4703b 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc
@@ -320,5 +320,17 @@ Status Edge::CalculateMemoryCostForInference() {
   }
   return SUCCESS;
 }
+
+void Edge::SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &cost_map) {
+  cost_map_ = cost_map;
+  pre_op_output_.clear();
+  next_op_input_.clear();
+
+  for (auto &key_value : cost_map_) {
+    auto &key_pair = key_value.first;
+    pre_op_output_.emplace_back(std::pair<StrategyPtr, std::vector<TensorInfo>>(key_pair.first, {}));
+    next_op_input_.emplace_back(std::pair<StrategyPtr, std::vector<TensorInfo>>(key_pair.second, {}));
+  }
+}
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h
index 3fffd1b86d2ff8fdb247d0173b79d88a2002ce93..d67b7e714a4f9c04e1ac0cef5253c4a9b4390b86 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h
@@ -80,6 +80,8 @@ class Edge {
   std::string edge_name() const { return edge_name_; }
   // Init cost_map_: for each output layout and input layout, calculate the cost
   Status InitEdgeCost();
+  std::map<CostPtrKey, CostPtrList> GetCostMap() { return cost_map_; }
+  void SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &);
   // For two operators u--->v, given the output tensor layout of u,
   // and the input tensor layout of v, return the redistribution cost,
   // and the op_list to carry out the redistribution.
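The new accessor pair above is what lets source elimination transplant a recomputed cost map onto a freshly created edge. A minimal usage sketch, where 'old_edge' and 'new_edge' are hypothetical std::shared_ptr<Edge> values:

  // Sketch only: rebuild an edge's cost bookkeeping from another edge's map.
  std::map<CostPtrKey, CostPtrList> cost_map = old_edge->GetCostMap();
  // SetCostMapAndInputOutput re-derives the pre_op_output_/next_op_input_
  // entries from the (output-strategy, input-strategy) keys; the TensorInfo
  // vectors stay empty because the edge costs are already fixed at this point.
  new_edge->SetCostMapAndInputOutput(cost_map);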
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc
index 1c1fc3a700ba386131b0698212147089ae0464d0..d30d5e33ac8fc805e468878bc41c37cd3aca7d2f 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc
@@ -794,6 +794,191 @@ OperatorInfoPtr CostGraph::CheckContractElimination() const {
   return nullptr;
 }
 
+std::pair<OperatorInfoPtr, OperatorInfoPtr> CostGraph::CheckSourceElimination() const {
+  size_t source_count = 0;
+  std::vector<OperatorInfoPtr> op_vector(2, nullptr);
+  for (auto &op : ops_) {
+    MS_EXCEPTION_IF_NULL(op);
+    bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && !op->GetAliveSuccEdges().empty();
+    if (bool_test) {
+      op_vector[source_count++] = op;
+      if (source_count == 2) {
+        return std::make_pair(op_vector[0], op_vector[1]);
+      }
+    }
+  }
+  return std::make_pair(nullptr, nullptr);
+}
+
+void CostGraph::CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, const CostPtrList &op1_old_clist,
+                                                   StrategyPtr op2_old_stra, const CostPtrList &op2_old_clist,
+                                                   CostPtrList *op1_new_clist) {
+  for (auto &op1_cost : op1_old_clist) {
+    for (auto &op2_cost : op2_old_clist) {
+      double computation = op1_cost->computation_cost_ + op2_cost->computation_cost_;
+      double memory = op1_cost->memory_with_reuse_ + op2_cost->memory_with_reuse_;
+      double communication = op1_cost->communication_cost_ + op2_cost->communication_cost_;
+      double communication_forward = op1_cost->communication_forward_ + op2_cost->communication_forward_;
+      double communication_without_para =
+        op1_cost->communication_without_parameter_ + op2_cost->communication_without_parameter_;
+      auto decision = std::make_shared<SourceEliminationDecision>(op1_old_stra, op1_cost, op2_old_stra, op2_cost);
+      auto new_cost = std::make_shared<Cost>(computation, communication, decision);
+      MS_EXCEPTION_IF_NULL(new_cost);
+      new_cost->communication_without_parameter_ = communication_without_para;
+      new_cost->communication_with_partial_para_ =
+        communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
+      new_cost->memory_with_reuse_ = memory;
+      new_cost->communication_forward_ = communication_forward;
+      MS_EXCEPTION_IF_NULL(op1_new_clist);
+      op1_new_clist->emplace_back(std::move(new_cost));
+    }
+  }
+}
+std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>> CostGraph::EliminationSources(
+  OperatorInfoPtr op1, OperatorInfoPtr op2) {
+  MS_EXCEPTION_IF_NULL(op1);
+  MS_EXCEPTION_IF_NULL(op2);
+  MS_LOG(INFO) << "Now source eliminating node: " << op2->name() << " to node: " << op1->name();
+
+  auto op1_old_succ_edges = op1->GetAliveSuccEdges();
+  std::vector<std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>>> op1_edges_reorganised_cost(
+    op1_old_succ_edges.size());
+  std::vector<std::map<CostPtrKey, CostPtrList>> op1_new_edges_cost(op1_old_succ_edges.size());
+  std::vector<std::shared_ptr<Edge>> op1_new_succ_edges(op1_old_succ_edges.size());
+
+  auto op2_old_succ_edges = op2->GetAliveSuccEdges();
+  std::vector<std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>>> op2_edges_reorganised_cost(
+    op2_old_succ_edges.size());
+  std::vector<std::map<CostPtrKey, CostPtrList>> op2_new_edges_cost(op2_old_succ_edges.size());
+  std::vector<std::shared_ptr<Edge>> op2_new_succ_edges(op2_old_succ_edges.size());
+
+  // Construct cost_map for the data structures 'op1_edges_reorganised_cost' and 'op2_edges_reorganised_cost'
+  for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) {
+    const auto &op1_cost_map = op1_old_succ_edges[i]->GetCostMap();
+    std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>> from_tocost;
+    for (const auto &key_value : op1_cost_map) {
+      const auto &from_to_strategies = key_value.first;
+      const auto &costlist = key_value.second;
+      from_tocost[from_to_strategies.first].emplace_back(std::make_pair(from_to_strategies.second, costlist));
+    }
+    op1_edges_reorganised_cost[i] = from_tocost;
+  }
+
+  for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) {
+    const auto &op2_cost_map = op2_old_succ_edges[i]->GetCostMap();
+    std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>> from_tocost;
+    for (const auto &key_value : op2_cost_map) {
+      const auto &from_to_strategies = key_value.first;
+      const auto &costlist = key_value.second;
+      from_tocost[from_to_strategies.first].emplace_back(std::make_pair(from_to_strategies.second, costlist));
+    }
+    op2_edges_reorganised_cost[i] = from_tocost;
+  }
+
+  // Merge op2 into op1
+  const auto &op1_old_stra_cost = op1->GetStrategyCost();
+  const auto &op2_old_stra_cost = op2->GetStrategyCost();
+  std::vector<std::shared_ptr<StrategyWithCost>> op1_new_stra_cost;
+
+  for (auto &op1_stra_cost : op1_old_stra_cost) {
+    auto op1_old_stra = op1_stra_cost->strategy_ptr;
+    auto op1_old_costlist = op1_stra_cost->cost_list;
+
+    for (auto &op2_stra_cost : op2_old_stra_cost) {
+      auto op2_stra = op2_stra_cost->strategy_ptr;
+      auto op2_costlist = op2_stra_cost->cost_list;
+
+      StrategyPtr op1_new_stra = std::make_shared<Strategy>(*op1_old_stra);
+      op1_new_stra->CoverStrategy(op2_stra);
+      CostPtrList op1_new_costlist;
+      // Calculate new cost for 'op1_new_costlist'
+      CreateSourceEliminationSubCostList(op1_old_stra, op1_old_costlist, op2_stra, op2_costlist, &op1_new_costlist);
+      std::shared_ptr<StrategyWithCost> swc = std::make_shared<StrategyWithCost>(op1_new_stra, op1_new_costlist);
+      op1_new_stra_cost.emplace_back(swc);
+
+      // Set cost for new successive edges of op1 and op2
+      for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) {
+        auto &from_tocost = op1_edges_reorganised_cost[i];
+        auto &to_cost = from_tocost[op1_old_stra];
+        auto &new_cost_map = op1_new_edges_cost[i];
+        for (auto &stra_costlist : to_cost) {
+          auto &to_strategy = stra_costlist.first;
+          auto &edge_costlist = stra_costlist.second;
+          CostPtrKey new_key = {op1_new_stra, to_strategy};
+          new_cost_map[new_key] = edge_costlist;
+        }
+      }
+      for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) {
+        auto &from_tocost = op2_edges_reorganised_cost[i];
+        auto &to_cost = from_tocost[op2_stra];
+        auto &new_cost_map = op2_new_edges_cost[i];
+        for (auto &stra_costlist : to_cost) {
+          auto &to_strategy = stra_costlist.first;
+          auto &edge_costlist = stra_costlist.second;
+          CostPtrKey new_key = {op1_new_stra, to_strategy};
+          new_cost_map[new_key] = edge_costlist;
+        }
+      }
+    }
+  }
+  op1->SetStrategyCost(op1_new_stra_cost);
+  op2->SetNotAlive();
+
+  // Update the edges incident to op1, and the edges incident to op2
+  for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) {
+    auto &new_cost_map = op1_new_edges_cost[i];
+    auto &ith_edge = op1_old_succ_edges[i];
+
+    std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + ith_edge->next_operator()->name();
+    std::shared_ptr<Edge> new_edge;
+    if (ith_edge->is_combined()) {
+      std::vector<size_t> output_indexs, input_indexs;
+      output_indexs = ith_edge->prev_op_output_indexs();
+      input_indexs = ith_edge->next_op_input_indexs();
+      new_edge =
+        std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_indexs, input_indexs, true);
+    } else {
+      size_t output_index, input_index;
+      output_index = ith_edge->prev_op_output_index();
+      input_index = ith_edge->next_op_input_index();
+      new_edge =
+        std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_index, input_index, false);
+    }
+    new_edge->SetCostMapAndInputOutput(new_cost_map);
+    // Replace the old successive edges with the new ones.
+    op1->ReplaceSuccEdge(ith_edge->next_operator(), new_edge);
+    ith_edge->next_operator()->ReplacePreEdge(op1, new_edge);
+    op1_new_succ_edges[i] = new_edge;
+  }
+  for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) {
+    auto &new_cost_map = op2_new_edges_cost[i];
+    auto &ith_edge = op2_old_succ_edges[i];
+    const auto &destination = ith_edge->next_operator();
+
+    std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + destination->name();
+    std::shared_ptr<Edge> new_edge;
+    if (ith_edge->is_combined()) {
+      std::vector<size_t> output_indexs, input_indexs;
+      output_indexs = ith_edge->prev_op_output_indexs();
+      input_indexs = ith_edge->next_op_input_indexs();
+      new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_indexs, input_indexs, true);
+    } else {
+      size_t output_index, input_index;
+      output_index = ith_edge->prev_op_output_index();
+      input_index = ith_edge->next_op_input_index();
+      new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_index, input_index, false);
+    }
+    new_edge->SetCostMapAndInputOutput(new_cost_map);
+    // Replace the old predecessor edges with the new ones.
+    destination->ReplacePreEdge(op2, new_edge);
+    op1->AddSuccEdge(new_edge);
+    op2_new_succ_edges[i] = new_edge;
+  }
+  MS_LOG(INFO) << "Source eliminating node: " << op2->name() << " to node: " << op1->name() << " succeeded.";
+  return {op1_new_succ_edges, op2_new_succ_edges};
+}
+
 // Check the graph whether a TriangleElimination can be performed
 std::pair<OperatorInfoPtr, std::shared_ptr<Edge>> CostGraph::CheckTriangleElimination() const {
   for (auto &op : ops_) {
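The combined cost built in CreateSourceEliminationSubCostList is a plain component-wise sum, except for communication_with_partial_para_, which is re-derived from the merged totals. A standalone numeric check of that formula (all values hypothetical; COST_MODEL_GAMMA is a runtime-configured constant in the real code):

  #include <iostream>

  int main() {
    double communication = 10.0;              // summed communication_cost_
    double communication_without_para = 6.0;  // summed communication_without_parameter_
    double gamma = 0.5;                       // stands in for COST_MODEL_GAMMA
    // partial = non-parameter part + gamma * parameter-related share
    double with_partial_para =
      communication_without_para + gamma * (communication - communication_without_para);
    std::cout << with_partial_para << std::endl;  // prints 8: 6 + 0.5 * (10 - 6)
    return 0;
  }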
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h
index 3ec328352c9acc4aa07d896a76d317c6d03530f9..5ef979f4dc5b6a808fe08e12c15e7bab76078051 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h
@@ -180,6 +180,14 @@ class CostGraph {
   void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &,
                                         const StrategyPtr &, const CostPtrList &, std::vector<StrategyPtr>,
                                         CostPtrList &, CostPtrList &, CostPtrList *);
+  // Return <op1, op2>; we merge 'op2' into 'op1'.
+  std::pair<OperatorInfoPtr, OperatorInfoPtr> CheckSourceElimination() const;
+  void CreateSourceEliminationSubCostList(StrategyPtr, const CostPtrList &, StrategyPtr, const CostPtrList &,
+                                          CostPtrList *);
+  // We merge 'op2' into 'op1'. The returned value is <Edges1, Edges2>: 'Edges1' are the newly updated edges for
+  // 'op1', and 'Edges2' are the newly updated edges for 'op2'.
+  std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>> EliminationSources(
+    OperatorInfoPtr op1, OperatorInfoPtr op2);
   // Calculate memory cost for training phase or inference phase.
   Status CalculateMemoryCost();
   // When the input of an operator is neither a WEIGHT, nor an output of a subsequent operator involving WEIGHT, then
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
index 3dd47b1de69361d057c4ed1e070246b1cc3339c5..6121a2d34a24d342b589a7ae232e16d09fd2de77 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
@@ -1330,5 +1330,9 @@ void OperatorInfo::CheckSelectedStrategy(const StrategyPtr &s_strategy) {
     PrintStrategy(s_strategy);
   }
 }
+
+void OperatorInfo::SetStrategyCost(const std::vector<std::shared_ptr<StrategyWithCost>> &stra_cost) {
+  strategy_cost_ = stra_cost;
+}
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
index 49e44ef347c0161b551317b0bfd7b3ab33c533d5..4f5f6d4ab889a0ee074f667df0fa4db251630a4b 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
@@ -97,6 +97,7 @@ class OperatorInfo {
   // is checked
   Status SetCostUnderStrategyBase(const StrategyPtr &strategy);
   std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; }
+  void SetStrategyCost(const std::vector<std::shared_ptr<StrategyWithCost>> &);
   // In the training phase, when the input of an operator contains WEIGHT or an output from other operators involving
   // WEIGHT, then these inputs should stay in memory until they are used in the backward phase, which is kept in
   // memory at the end of the forward phase.
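Taken together, these pieces would be driven from GetStrategy roughly as follows. This is a hedged sketch of the expected call pattern, since the wiring of source elimination into the GetStrategy loop is not shown in this patch; 'graph', 'flag', and 'eliminations' are the variables from GetStrategy above:

  // Sketch only: how a source-elimination branch would slot into the loop.
  if (!flag) {
    auto source_pair = graph->CheckSourceElimination();
    if (source_pair.first != nullptr) {
      // Applying the Source Elimination
      flag = true;
      auto op1 = source_pair.first;
      auto op2 = source_pair.second;
      // Capture the old outgoing edges before they are replaced.
      auto op1_old_edges = op1->GetAliveSuccEdges();
      auto op2_old_edges = op2->GetAliveSuccEdges();
      auto new_edges = graph->EliminationSources(op1, op2);
      auto elimi = std::make_shared<SourceElimination>(op1, op1_old_edges, new_edges.first, op2, op2_old_edges,
                                                       new_edges.second);
      eliminations.emplace_back(std::move(elimi));
    }
  }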
diff --git a/mindspore/ccsrc/frontend/parallel/strategy.h b/mindspore/ccsrc/frontend/parallel/strategy.h
index dfcf56de2ff5342975e0df3ab1f4ebf9a442cc54..1d2877e061fd2614e95d68ee1f7b38aa90c36a96 100644
--- a/mindspore/ccsrc/frontend/parallel/strategy.h
+++ b/mindspore/ccsrc/frontend/parallel/strategy.h
@@ -36,7 +36,19 @@ using StrategyPtr = std::shared_ptr<Strategy>;
 
 class Strategy {
  public:
-  Strategy(int32_t stage, std::vector<Dimensions> inputs) : stage_(stage), inputs_(std::move(inputs)) {}
+  Strategy(int32_t stage, std::vector<Dimensions> inputs)
+      : stage_(stage), inputs_(std::move(inputs)), internal_size_(0), internal_strategies_() {}
+
+  Strategy(const Strategy &another_stra) : stage_(another_stra.GetInputStage()) {
+    inputs_ = another_stra.GetInputDim();
+    internal_size_ = another_stra.GetInternalSize();
+    if (internal_size_ != 0) {
+      internal_strategies_ = another_stra.GetInternalStrategies();
+    } else {
+      internal_strategies_ = {};
+    }
+  }
+  ~Strategy() = default;
 
   size_t GetInputNumber() const { return inputs_.size(); }
   std::vector<Dimensions> GetInputDim() const { return inputs_; }
@@ -47,7 +59,10 @@ class Strategy {
     }
   }
   void ResetInputs(const std::vector<Dimensions> &input) { inputs_ = input; }
+  std::vector<StrategyPtr> GetInternalStrategies() const { return internal_strategies_; }
+  size_t GetInternalSize() const { return internal_size_; }
 
+  // TODO(Xiaoda): needs a fix to adapt to 'CoverStrategy'
   bool IsEqual(const StrategyPtr &another_stra) {
     if (another_stra == nullptr) {
       return false;
@@ -58,11 +73,19 @@ class Strategy {
     return true;
   }
 
+  // Include 'another_stra' into this strategy
+  void CoverStrategy(const StrategyPtr &another_stra) {
+    internal_strategies_.push_back(another_stra);
+    internal_size_++;
+  }
+
  private:
   const int32_t stage_;
 
   // The size of Dimensions must be equal to the inputs_ tensor dimension.
   std::vector<Dimensions> inputs_;
+  size_t internal_size_ = 0;
+  std::vector<StrategyPtr> internal_strategies_;
 };
 
 inline StrategyPtr NewStrategy(const int32_t stage, const std::vector<Dimensions> &inputs) {
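CoverStrategy simply records the absorbed operator's strategy alongside op1's own, rather than rewriting op1's input dimensions. A minimal sketch of the semantics, using made-up strategies for two merged sources (NewStrategy and Dimensions are the helpers from strategy.h):

  // Sketch only: op1's strategy "covers" op2's after source elimination.
  StrategyPtr op1_stra = NewStrategy(0, {{2, 4}});  // hypothetical strategy of op1
  StrategyPtr op2_stra = NewStrategy(0, {{4, 2}});  // hypothetical strategy of op2
  Strategy merged(*op1_stra);    // the copy constructor keeps op1's own inputs_
  merged.CoverStrategy(op2_stra);  // remembers op2's choice internally
  // merged.GetInternalSize() == 1; the DP replay can later recover op2's strategy.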
diff --git a/tests/ut/python/parallel/test_auto_parallel_double_sources.py b/tests/ut/python/parallel/test_auto_parallel_double_sources.py
new file mode 100644
index 0000000000000000000000000000000000000000..188c962f26a88961f3c8cf445851e2786b351f56
--- /dev/null
+++ b/tests/ut/python/parallel/test_auto_parallel_double_sources.py
@@ -0,0 +1,114 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.common.api import _executor
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+from tests.ut.python.ops.test_math_ops import VirtualLoss
+
+
+class NetWithLoss(nn.Cell):
+    def __init__(self, network):
+        super(NetWithLoss, self).__init__()
+        self.loss = VirtualLoss()
+        self.network = network
+
+    def construct(self, x, y, z, w, a):
+        predict = self.network(x, y, z, w, a)
+        return self.loss(predict)
+
+
+class GradWrap(nn.Cell):
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+
+    def construct(self, x, y, z, w, a):
+        return C.grad_all(self.network)(x, y, z, w, a)
+
+
+# model_parallel test
+def test_double_source_graph():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.matmul1 = P.MatMul()
+            self.matmul2 = P.MatMul()
+            self.matmul3 = P.MatMul()
+            self.matmul4 = P.MatMul()
+            self.matmul5 = P.MatMul()
+
+        def construct(self, x, y, z, w, a):
+            m1_result = self.matmul1(x, y)
+            m2_result = self.matmul2(z, w)
+            m3_result = self.matmul3(m2_result, m1_result)
+            m4_result = self.matmul4(m2_result, m1_result)
+            out = self.matmul5(m3_result, m4_result)
+
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    z = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    w = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    a = Tensor(np.ones([32, 32]), dtype=ms.float32)
+
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
+    _executor.compile(net, x, y, z, w, a)
+
+
+def test_double_source_complex_graph():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.matmul1 = P.MatMul()
+            self.matmul2 = P.MatMul()
+            self.matmul3 = P.MatMul()
+            self.matmul4 = P.MatMul()
+            self.matmul5 = P.MatMul()
+            self.matmul6 = P.MatMul()
+
+        def construct(self, x, y, z, w, a):
+            m1_result = self.matmul1(x, y)
+            m6_result = self.matmul6(m1_result, a)
+            m2_result = self.matmul2(z, w)
+            m3_result = self.matmul3(m2_result, m6_result)
+            m4_result = self.matmul4(m2_result, m1_result)
+            out = self.matmul5(m3_result, m4_result)
+
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    z = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    w = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    a = Tensor(np.ones([32, 32]), dtype=ms.float32)
+
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
+    _executor.compile(net, x, y, z, w, a)