Commit 0ac50a19 authored by Xiaoda Zhang

Model the memory cost in auto-parallel. It is calculated from the outputs of operators, plus the parameters. Additionally, modify the graph operations in auto_parallel to include memory_cost.
Parent c9fba7f0
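The rule behind that message: the modeled per-device memory of an operator whose output is parameter-involved is the byte size of its output slices plus the byte size of its parameter inputs. A minimal standalone sketch with hypothetical slice shapes (the helper `SliceBytes` is illustrative and not part of this change):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical helper: bytes occupied by one tensor slice.
static double SliceBytes(const std::vector<int>& slice_shape, size_t type_len) {
  double elems = 1.0;
  for (int d : slice_shape) elems *= d;
  return elems * static_cast<double>(type_len);
}

int main() {
  // Assumed example: a MatMul whose output slice is [128, 256] and whose
  // second input (the weight B) has slice [256, 256], both float32.
  double output = SliceBytes({128, 256}, 4);  // output slice stays in memory
  double weight = SliceBytes({256, 256}, 4);  // parameter input stays in memory
  std::printf("memory_with_reuse ~ %.0f bytes\n", output + weight);
  return 0;
}
```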
......@@ -207,15 +207,13 @@ struct ContractEliminationDecision : public Decision {
*/
struct TriangleEliminationDecision : public Decision {
TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost,
StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra, CostPtr r_node_cost)
StrategyPtr left_stra, CostPtr l_node_cost)
: eliminated_op_strategy_(std::move(elimi_stra)),
eliminated_op_cost_(std::move(elimi_op_cost)),
left_edge_cost_(std::move(l_edge_cost)),
right_edge_cost_(std::move(r_edge_cost)),
left_node_strategy_(std::move(left_stra)),
left_node_cost_(std::move(l_node_cost)),
right_node_strategy_(std::move(right_stra)),
right_node_cost_(std::move(r_node_cost)) {
left_node_cost_(std::move(l_node_cost)) {
type_ = DecisionType::TRIANGLE_ELIMINATION;
}
......@@ -225,8 +223,6 @@ struct TriangleEliminationDecision : public Decision {
CostPtr right_edge_cost_;
StrategyPtr left_node_strategy_;
CostPtr left_node_cost_;
StrategyPtr right_node_strategy_;
CostPtr right_node_cost_;
MS_DECLARE_PARENT(TriangleEliminationDecision, Decision);
};
......
......@@ -76,7 +76,6 @@ Status GetStrategy(const CostGraphPtr& graph) {
auto l_r_edge = triangle_pair.second;
auto left_node = l_r_edge->prev_operator();
auto right_node = l_r_edge->next_operator();
auto left_edge = eliminated_node->GetAliveSuccEdges()[0];
auto right_edge = eliminated_node->GetAliveSuccEdges()[1];
MS_EXCEPTION_IF_NULL(left_edge);
......@@ -86,8 +85,7 @@ Status GetStrategy(const CostGraphPtr& graph) {
right_edge = tmp;
}
auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge);
auto elimi =
std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node);
auto elimi = std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge);
eliminations.emplace_back(std::move(elimi));
}
auto star_center = graph->CheckStarElimination();
......@@ -183,14 +181,13 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
auto left_edge = elimination->left_edge_;
auto eliminated_node = elimination->eliminated_node_;
auto right_edge = elimination->right_edge_;
auto right_node = elimination->right_node_;
auto decision = left_node->selected_cost()->decision_ptr_->cast<TriangleEliminationDecisionPtr>();
eliminated_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_);
left_edge->set_selected_cost(decision->left_edge_cost_);
right_edge->set_selected_cost(decision->right_edge_cost_);
// Since Triangle is eliminated into 'left_node', only 'left_node' is needed to recover the strategy.
left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_);
right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_);
MS_LOG(INFO) << "Recover triangleElimination succeeded.";
} else if ((*rit)->isa<StarElimination>()) {
auto elimination = (*rit)->cast<StarEliminationPtr>();
......@@ -204,9 +201,11 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
for (size_t i = 0; i < succ_edges.size(); ++i) {
succ_edges[i]->set_selected_cost(decision->succ_edges_cost_list_[i]);
}
for (size_t j = 0; j < succ_nodes.size(); ++j) {
succ_nodes[j]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[j], decision->succ_ops_cost_list_[j]);
}
MS_EXCEPTION_IF_NULL(succ_nodes[0]);
MS_EXCEPTION_IF_NULL(decision->succ_ops_stra_list_[0]);
MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]);
// Since Star is eliminated into 'succ_nodes[0]', only 'succ_nodes[0]' is needed to recover the strategy.
succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]);
MS_LOG(INFO) << "Recover starElimination succeeded.";
} else {
MS_LOG(ERROR) << "Unknown Elimination type.";
......
......@@ -102,20 +102,17 @@ struct ContractElimination : public Elimination {
// Triangle Elimination
struct TriangleElimination : public Elimination {
TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge,
OperatorInfoPtr r_node)
TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge)
: Elimination(nullptr, Elimination::EliminationType::TRIANGLE),
eliminated_node_(std::move(elim_node)),
left_edge_(std::move(l_edge)),
left_node_(std::move(l_node)),
right_edge_(std::move(r_edge)),
right_node_(std::move(r_node)) {}
right_edge_(std::move(r_edge)) {}
OperatorInfoPtr eliminated_node_;
EdgePtr left_edge_;
OperatorInfoPtr left_node_;
EdgePtr right_edge_;
OperatorInfoPtr right_node_;
MS_DECLARE_PARENT(TriangleElimination, Elimination);
};
......
......@@ -119,6 +119,7 @@ Status Edge::GetRedistributionCost(const TensorLayout& prev_op_output_layout, co
double forward_comm_cost = tensor_redistribution.forward_comm_cost();
double backward_comm_cost = tensor_redistribution.backward_comm_cost();
double computation_cost = tensor_redistribution.computation_cost();
double mem_cost = tensor_redistribution.memory_cost();
// Now AllGather, ReduceScatter, AlltoAll don't support bool type
MS_EXCEPTION_IF_NULL(type);
......@@ -134,6 +135,7 @@ Status Edge::GetRedistributionCost(const TensorLayout& prev_op_output_layout, co
COST_MODEL_GAMMA * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_);
(*cost)->communication_redis_forward_ = type_length * forward_comm_cost;
(*cost)->communication_redis_backward_ = type_length * backward_comm_cost;
(*cost)->memory_with_reuse_ = mem_cost;
return Status::SUCCESS;
}
......@@ -158,8 +160,8 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr
(void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);
CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
std::function<void(size_t, double, double, double)> recursive =
[&](size_t k, double computation, double communication, double communication_without_para) {
std::function<void(size_t, double, double, double, double)> recursive =
[&](size_t k, double computation, double memory, double communication, double communication_without_para) {
if (k == edges.size()) {
auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
CostPtr new_cost = std::make_shared<Cost>(computation, communication);
......@@ -167,6 +169,7 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr
new_cost->communication_without_parameter_ = communication_without_para;
new_cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
new_cost->memory_with_reuse_ = memory;
new_cost->decision_ptr_ = decision;
result.push_back(new_cost);
return;
......@@ -174,11 +177,12 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr
for (auto& c : all_cost_list[k]) {
MS_EXCEPTION_IF_NULL(c);
selected_cost_list[k] = c;
recursive(k + 1, computation + c->computation_cost_, communication + c->communication_cost_,
recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
communication + c->communication_cost_,
communication_without_para + c->communication_without_parameter_);
}
};
recursive(0, 0, 0, 0);
recursive(0, 0.0, 0.0, 0.0, 0.0);
SimplifyForDreasingCommunicationWithPartialPara(&result);
return result;
}
......@@ -218,6 +222,8 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
double communication_without_para = left_cost->communication_without_parameter_ +
middle_cost->communication_without_parameter_ +
right_cost->communication_without_parameter_;
double memory_cost =
left_cost->memory_with_reuse_ + middle_cost->memory_with_reuse_ + right_cost->memory_with_reuse_;
auto decision = std::make_shared<OpEliminationDecision>(op_strategy, left_cost, middle_cost, right_cost);
auto cost = std::make_shared<Cost>(computation, communication, decision);
......@@ -225,6 +231,7 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
cost->communication_without_parameter_ = communication_without_para;
cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
cost->memory_with_reuse_ = memory_cost;
ret_cost_list->emplace_back(std::move(cost));
}
}
......@@ -267,5 +274,24 @@ void Edge::OpEliminationSetNewCost(const EdgePtr& e1, const OperatorInfoPtr& op,
MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
}
}
Status Edge::CalculateMemoryCost() {
if (is_output_parameter_involve_ == -1) {
MS_LOG(ERROR) << "is_output_parameter_involve_ is unset.";
return FAILED;
}
if (is_output_parameter_involve_ == 0) {
// In this case, the tensor redistribution along this edge is definitely NOT parameter-involved, so it is
// unnecessary to keep the redistributed tensors in memory.
for (auto& cost_kv : cost_map_) {
auto& cost_v = cost_kv.second;
if (!cost_v.empty()) {
cost_v[0]->memory_with_reuse_ = 0;
}
}
}
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore
......@@ -133,7 +133,7 @@ class Edge {
void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; }
// When the inputs of an operator contain WEIGHT, or an output of a preceding operator that involves WEIGHT, these
// inputs should stay in memory until they are used in the backward phase, i.e. they are kept in memory at the end of
// the forward phase.
Status CalculateMemoryCost() const { return SUCCESS; }
Status CalculateMemoryCost();
private:
std::string edge_name_;
......
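The pattern used in `CreateEdgeEliminationCostList` above, where `memory_with_reuse_` is threaded through the recursion alongside computation and communication, can be sketched with simplified stand-in types (these are not the MindSpore classes):

```cpp
#include <cstdio>
#include <functional>
#include <memory>
#include <vector>

struct Cost {  // simplified stand-in for parallel::Cost
  double computation_cost_ = 0, communication_cost_ = 0, memory_with_reuse_ = 0;
};
using CostPtr = std::shared_ptr<Cost>;
using CostPtrList = std::vector<CostPtr>;

// Enumerate one cost from each list and accumulate all three quantities, so
// the memory contribution of every eliminated edge survives in the merged Cost.
CostPtrList CombineCostLists(const std::vector<CostPtrList>& all_cost_list) {
  CostPtrList result;
  std::function<void(size_t, double, double, double)> recursive =
      [&](size_t k, double computation, double memory, double communication) {
        if (k == all_cost_list.size()) {
          auto merged = std::make_shared<Cost>();
          merged->computation_cost_ = computation;
          merged->memory_with_reuse_ = memory;
          merged->communication_cost_ = communication;
          result.push_back(merged);
          return;
        }
        for (const auto& c : all_cost_list[k]) {
          recursive(k + 1, computation + c->computation_cost_,
                    memory + c->memory_with_reuse_,
                    communication + c->communication_cost_);
        }
      };
  recursive(0, 0.0, 0.0, 0.0);
  return result;
}

int main() {
  CostPtrList a{std::make_shared<Cost>()}, b{std::make_shared<Cost>()};
  a[0]->memory_with_reuse_ = 2.0;
  b[0]->memory_with_reuse_ = 3.0;
  auto merged = CombineCostLists({a, b});
  std::printf("merged memory = %.1f\n", merged[0]->memory_with_reuse_);  // 5.0
  return 0;
}
```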
......@@ -248,6 +248,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std::
MS_EXCEPTION_IF_NULL(cost2);
MS_EXCEPTION_IF_NULL(cost3);
double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_;
double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_;
double commmunication =
cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
double communication_without_para = cost1->communication_without_parameter_ +
......@@ -260,6 +261,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std::
cost->communication_without_parameter_ = communication_without_para;
cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (commmunication - communication_without_para);
cost->memory_with_reuse_ = memory;
ret.push_back(cost);
}
}
......@@ -288,6 +290,7 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr& u) {
new_cost->communication_with_partial_para_ =
cost1->communication_without_parameter_ +
COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_);
new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
ret.push_back(new_cost);
}
}
......@@ -297,9 +300,14 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr& u) {
}
CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list, double memory) {
if (cost_list.empty() || cost_list[0]->computation_cost_ >= memory) {
return nullptr;
CostPtrList after_mem_filter;
// Keep only the costs that fit within the memory constraint
for (auto& a_cost : cost_list) {
if (a_cost->memory_with_reuse_ <= memory) {
after_mem_filter.emplace_back(std::move(a_cost));
}
}
std::function<CostPtr(CostPtr, const CostPtr&)> LocalCompare = [&](CostPtr init, const CostPtr& cost_x) {
MS_EXCEPTION_IF_NULL(cost_x);
if (init == nullptr || cost_x->computation_cost_ < memory) {
......@@ -308,7 +316,7 @@ CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list,
return init;
};
CostPtr ret = nullptr;
return std::accumulate(cost_list.begin(), cost_list.end(), ret, LocalCompare);
return std::accumulate(after_mem_filter.begin(), after_mem_filter.end(), ret, LocalCompare);
}
CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, double memory) {
......@@ -318,36 +326,46 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, d
MS_LOG(ERROR) << "Final cost list is null.";
return nullptr;
}
CostPtr ret = cost_list[0];
MS_EXCEPTION_IF_NULL(ret);
if (ret->computation_cost_ >= memory) {
MS_LOG(ERROR) << "No available cost; the minimum cost is " << ret->computation_cost_
CostPtrList after_mem_filter;
double minimum_memory = DBL_MAX;
// Keep only the costs that fit within the memory constraint.
for (auto& a_cost : cost_list) {
if (a_cost->memory_with_reuse_ <= memory) {
after_mem_filter.emplace_back(std::move(a_cost));
} else if (a_cost->memory_with_reuse_ < minimum_memory) {
minimum_memory = a_cost->memory_with_reuse_;
}
}
if (after_mem_filter.empty()) {
MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
<< ", the memory capacity is: " << memory << ".";
return nullptr;
}
// Initialize the returned value with the first cost.
CostPtr ret = after_mem_filter[0];
double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_with_partial_para_;
MS_LOG(INFO) << "minimum: " << minimum << ", computation_cost_: " << ret->computation_cost_
MS_LOG(INFO) << "Cost 0: "
<< "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
<< ", communication_with_partial_para_: " << ret->communication_with_partial_para_
<< ", communication_cost_: " << ret->communication_cost_
<< ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
for (size_t i = 1; i < cost_list.size(); ++i) {
MS_EXCEPTION_IF_NULL(cost_list[i]);
if (cost_list[i]->computation_cost_ >= memory) {
MS_LOG(INFO) << "cost_list " << i << " computation_cost_: " << cost_list[i]->computation_cost_
<< ", is larger than the memory capacity: " << memory << ".";
break;
}
MS_LOG(INFO) << "cost_list " << i << " computation_cost_: " << cost_list[i]->computation_cost_
<< ", communication_with_partial_para_: " << cost_list[i]->communication_with_partial_para_
<< ", communication_cost_: " << cost_list[i]->communication_cost_
<< ", communication_without_parameter_: " << cost_list[i]->communication_without_parameter_ << ".";
auto tmp = costmodel_alpha_ * cost_list[i]->computation_cost_ +
costmodel_beta_ * cost_list[i]->communication_with_partial_para_;
MS_LOG(INFO) << "tmp: " << tmp;
MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum;
for (size_t i = 1; i < after_mem_filter.size(); ++i) {
MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
<< ", computation_cost_: " << after_mem_filter[i]->computation_cost_
<< ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
<< ", communication_cost_: " << after_mem_filter[i]->communication_cost_
<< ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
<< ".";
auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
costmodel_beta_ * after_mem_filter[i]->communication_with_partial_para_;
MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
if (minimum > tmp) {
minimum = tmp;
ret = cost_list[i];
MS_LOG(INFO) << "selected: " << i;
ret = after_mem_filter[i];
MS_LOG(INFO) << "Selected: " << i;
}
}
return ret;
......@@ -356,17 +374,21 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, d
CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList>& all_cost_list,
double available_memory) {
CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
double minimum = 0.0, total_memory = 0.0;
double minimum = DBL_MAX, total_memory = 0.0;
CostPtrList ret(all_cost_list.size(), nullptr);
// Check whether valid costs exist.
for (size_t i = 0; i < all_cost_list.size(); ++i) {
if (all_cost_list[i][0] == nullptr) {
MS_LOG(ERROR) << "The cost list " << i << " is empty.";
return ret;
} else {
total_memory += all_cost_list[i][0]->computation_cost_;
minimum += costmodel_alpha_ * all_cost_list[i][0]->computation_cost_ +
costmodel_beta_ * all_cost_list[i][0]->communication_with_partial_para_;
ret[i] = all_cost_list[i][0];
double memory_i_cost = DBL_MAX;
for (size_t j = 0; j < all_cost_list[i].size(); ++j) {
if (all_cost_list[i][j]->memory_with_reuse_ < memory_i_cost) {
memory_i_cost = all_cost_list[i][j]->memory_with_reuse_;
}
}
total_memory += memory_i_cost;
}
}
if (total_memory >= available_memory) {
......@@ -381,7 +403,7 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect
double tmp_memory = 0.0, tmp_minimum = 0.0;
for (size_t i = 0; i < selected_cost_list.size(); ++i) {
MS_EXCEPTION_IF_NULL(selected_cost_list[i]);
tmp_memory += selected_cost_list[i]->computation_cost_;
tmp_memory += selected_cost_list[i]->memory_with_reuse_;
tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->computation_cost_ +
costmodel_beta_ * selected_cost_list[i]->communication_with_partial_para_;
}
......@@ -816,6 +838,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
auto& tar_cost = tar_cost_list[k];
MS_EXCEPTION_IF_NULL(tar_cost);
double computation = op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
double communication =
op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
double communication_without_para = op_cost->communication_without_parameter_ +
......@@ -829,6 +852,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
new_cost->communication_without_parameter_ = communication_without_para;
new_cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
new_cost->memory_with_reuse_ = memory;
MS_EXCEPTION_IF_NULL(tar_cost_list_new);
tar_cost_list_new->emplace_back(std::move(new_cost));
}
......@@ -894,6 +918,8 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
MS_EXCEPTION_IF_NULL(tar_cost);
double computation =
contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
double memory =
contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
double communication =
contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
double communication_without_para = contract_op_cost->communication_without_parameter_ +
......@@ -906,6 +932,7 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
new_cost->communication_without_parameter_ = communication_without_para;
new_cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
new_cost->memory_with_reuse_ = memory;
tar_cost_list_new->emplace_back(std::move(new_cost));
}
}
......@@ -966,23 +993,22 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
for (auto& left_node_cost : left_node_clist_origin) {
MS_EXCEPTION_IF_NULL(left_node_cost);
double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ +
left_node_cost->computation_cost_ + right_edge_cost->computation_cost_ +
right_op_cost->computation_cost_;
left_node_cost->computation_cost_ + right_edge_cost->computation_cost_;
double new_memory = elimi_op_cost->memory_with_reuse_ + left_edge_cost->memory_with_reuse_ +
left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_;
double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ +
left_node_cost->communication_cost_ + right_edge_cost->communication_cost_ +
right_op_cost->communication_cost_;
left_node_cost->communication_cost_ + right_edge_cost->communication_cost_;
double new_commu_without =
elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_ +
right_op_cost->communication_without_parameter_;
left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
auto decision =
std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost,
left_op_stra, left_node_cost, right_op_stra, right_op_cost);
auto decision = std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost,
right_edge_cost, left_op_stra, left_node_cost);
auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
new_cost->communication_without_parameter_ = new_commu_without;
new_cost->communication_with_partial_para_ =
new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without);
new_cost->memory_with_reuse_ = new_memory;
left_node_clist_new->emplace_back(std::move(new_cost));
}
}
......@@ -1085,14 +1111,22 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_n
succ_nodes_costs[0] = first_succ_node_cost;
double computation_cost = merged_node_cost->computation_cost_,
commu_cost = merged_node_cost->communication_cost_,
memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_,
commu_without = merged_node_cost->communication_without_parameter_;
for (size_t i = 0; i < succ_nodes_stras.size(); ++i) {
MS_EXCEPTION_IF_NULL(succ_edges_costs[i]);
computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
commu_without += succ_edges_costs[i]->communication_without_parameter_ +
succ_nodes_costs[i]->communication_without_parameter_;
if (i == 0) {
computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_;
commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
commu_without += succ_edges_costs[i]->communication_without_parameter_ +
succ_nodes_costs[i]->communication_without_parameter_;
} else {
computation_cost += succ_edges_costs[i]->computation_cost_;
memory_cost += succ_edges_costs[i]->memory_with_reuse_;
commu_cost += succ_edges_costs[i]->communication_cost_;
commu_without += succ_edges_costs[i]->communication_without_parameter_;
}
}
auto decision = std::make_shared<StarEliminationDecision>(merged_op_stra, merged_node_cost, succ_edges_costs,
......@@ -1100,6 +1134,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_n
auto new_cost = std::make_shared<Cost>(computation_cost, commu_cost, decision);
new_cost->communication_without_parameter_ = commu_without;
new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without);
new_cost->memory_with_reuse_ = memory_cost;
first_succ_node_clist_new->emplace_back(std::move(new_cost));
}
}
......@@ -1259,5 +1294,35 @@ OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string& p_name) c
}
return nullptr;
}
Status CostGraph::CorrectOpsMemoryCost() {
for (auto& one_op : ops_) {
if ((one_op->name().find(IDENTITY_INFO) != std::string::npos) && (one_op->is_output_parameter_involve() == 1)) {
if (one_op->GetAliveSuccEdges().size() > 1) {
// Handle the case where the TmpIdentity is used by multiple operators
std::map<size_t, int> output_count;
for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
output_count[output_index]++;
}
for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
if (output_count[output_index] <= 1) {
continue;
}
auto next_op = one_op->GetAliveSuccEdges()[i]->next_operator();
MS_EXCEPTION_IF_NULL(next_op);
auto input_index = one_op->GetAliveSuccEdges()[i]->next_op_input_index();
if (next_op->CorrectMemoryCost(input_index) != SUCCESS) {
MS_LOG(ERROR) << "The operator name: " << one_op->name() << ", the next operator name: " << next_op->name()
<< ", the output_index: " << output_index << ", the input_index: " << input_index << ".";
return FAILED;
}
output_count[output_index]--;
}
}
}
}
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore
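The strategy-selection change above amounts to: discard any candidate whose `memory_with_reuse_` exceeds the device memory, then minimize `costmodel_alpha_ * computation_cost_ + costmodel_beta_ * communication_with_partial_para_` over what remains. A simplified sketch with stand-in types:

```cpp
#include <cfloat>
#include <memory>
#include <vector>

struct Cost {  // simplified stand-in for parallel::Cost
  double computation_cost_ = 0;
  double communication_with_partial_para_ = 0;
  double memory_with_reuse_ = 0;
};
using CostPtr = std::shared_ptr<Cost>;

CostPtr SelectUnderMemory(const std::vector<CostPtr>& costs, double device_memory,
                          double alpha, double beta) {
  CostPtr best = nullptr;
  double best_value = DBL_MAX;
  for (const auto& c : costs) {
    if (c == nullptr || c->memory_with_reuse_ > device_memory) {
      continue;  // violates the memory constraint
    }
    double value = alpha * c->computation_cost_ + beta * c->communication_with_partial_para_;
    if (value < best_value) {
      best_value = value;
      best = c;
    }
  }
  return best;  // nullptr if nothing fits in memory
}
```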
......@@ -187,6 +187,9 @@ class CostGraph {
size_t GetNumPairs() const { return edges_.size(); }
Status InitSelectedStrategy();
OperatorInfoPtr FindTmpIdentityByParameterName(std::string&) const;
// When TmpIdentity is used by multiple operators, the corresponding parameter's memory cost should be calculated
// only once (instead of multiple times); this method corrects for that.
Status CorrectOpsMemoryCost();
// Needed by rec_parser
void add_inputs_tensor_name(const std::vector<std::string>& inputs_tensor_name) {
inputs_tensor_name_list_.push_back(inputs_tensor_name);
......
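The correction declared by `CorrectOpsMemoryCost()` above exists because a parameter shared through one TmpIdentity output would otherwise be counted once per consumer. A worked example with hypothetical sizes:

```cpp
#include <cstdio>

int main() {
  // Assumed example: a 4 MB parameter slice feeds three operators through a
  // single TmpIdentity output. Summing each consumer's memory cost would count
  // the parameter three times; after correction it contributes only once.
  const double param_bytes = 4.0 * 1024 * 1024;  // shared parameter slice
  const int consumers = 3;                       // operators using it
  double naive = consumers * param_bytes;        // double counting
  double corrected = param_bytes;                // counted only once
  std::printf("naive=%.0f corrected=%.0f saved=%.0f\n", naive, corrected,
              naive - corrected);
  return 0;
}
```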
......@@ -17,6 +17,7 @@
#include "parallel/auto_parallel/operator_costmodel.h"
#include <random>
#include <algorithm>
#include "parallel/device_matrix.h"
#include "parallel/tensor_layout/tensor_redistribution.h"
......@@ -24,12 +25,44 @@ namespace mindspore {
namespace parallel {
void OperatorCost::set_is_parameter(const std::vector<bool>& is_parameter) { is_parameter_ = is_parameter; }
void OperatorCost::set_is_parameter_involve(const std::vector<bool>& is_parameter_inv) {
is_parameter_involve_ = is_parameter_inv;
}
void OperatorCost::set_output_parameter_involve(int output_para) { output_parameter_involve_ = output_para; }
void OperatorCost::SetInputAndOutputTypeLength(const std::vector<size_t>& input_lengths,
const std::vector<size_t>& output_lengths) {
inputs_type_lengths_ = input_lengths;
outputs_type_lengths_ = output_lengths;
}
double OperatorCost::GetMemoryCost(const std::vector<TensorInfo>& inputs,
const std::vector<TensorInfo>& outputs) const {
double result = 0.0;
if (output_parameter_involve_ == 1) {
// When this operator has multiple outputs, they all contribute to the memory.
for (size_t i = 0; i < outputs.size(); ++i) {
result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
}
bool is_any_para_inv =
std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; });
if (is_any_para_inv) {
for (size_t i = 0; i < inputs.size(); ++i) {
if (is_parameter_[i]) {
result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
} else if (inputs_related_ && (!is_parameter_involve_[i])) {
// When the inputs of this operator are related and not parameter-involved, they are still included
// in the memory cost.
result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
}
}
}
}
return result;
}
// return the per device communication cost in the forward phase.
double MatMulCost::GetForwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
const int32_t&) const {
......@@ -72,11 +105,11 @@ double MatMulCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, co
return result;
}
// Return the per device memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs,
const std::vector<TensorInfo>& outputs, const int32_t&) const {
// In forward phase, the memory cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
// In forward phase, the computation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
double result = 0.0;
TensorInfo output0 = outputs[0];
Shape input0_slice_shape = inputs[0].slice_shape();
......@@ -91,11 +124,11 @@ double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo>& inpu
return result;
}
// Return the per device memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double MatMulCost::GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
const int32_t& stage_id) const {
// In backward phase, the memory cost = (0 or 1) allreduce(slice(B))
// In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
double result = 0.0;
if (is_parameter_[1]) {
TensorInfo input1 = inputs[1]; // tensor B
......@@ -145,7 +178,7 @@ double ActivationCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs
return result;
}
// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
const int32_t&) const {
......@@ -154,7 +187,7 @@ double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo>&
return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
}
// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double ActivationCost::GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
const int32_t&) const {
......@@ -189,17 +222,17 @@ double SoftmaxCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, c
return result;
}
// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
const int32_t&) const {
// In the forward phase, the memory cost = slice(A)
// In the forward phase, the computation cost = slice(A)
TensorInfo input0 = inputs[0];
Shape input0_slice_shape = input0.slice_shape();
return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
}
// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
const std::vector<mindspore::parallel::TensorInfo>&,
......@@ -221,17 +254,15 @@ double TmpIdentityCost::GetBackwardCommCost(const std::vector<mindspore::paralle
return 0.0;
}
// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double TmpIdentityCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
double TmpIdentityCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
const std::vector<mindspore::parallel::TensorInfo>&,
const int32_t&) const {
TensorInfo input0_info = inputs[0];
Shape input0_slice_shape = input0_info.slice_shape();
return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
return 0.0;
}
// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double TmpIdentityCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
const std::vector<mindspore::parallel::TensorInfo>&,
......@@ -239,6 +270,11 @@ double TmpIdentityCost::GetBackwardComputationCost(const std::vector<mindspore::
return 0.0;
}
// Return the per device PEAK memory cost contributed by this operator in a training iteration.
double TmpIdentityCost::GetMemoryCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&) const {
return 0.0;
}
double BatchParallelCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
const std::vector<mindspore::parallel::TensorInfo>&,
const int32_t&) const {
......@@ -284,11 +320,11 @@ double PReLUCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, con
return result;
}
// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double PReLUCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
const int32_t&) const {
// In forward phase, the memory cost = slice(A) + slice(B)
// In forward phase, the computation cost = slice(A) + slice(B)
Shape input0_slice_shape = inputs[0].slice_shape();
Shape input1_slice_shape = inputs[1].slice_shape();
double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
......@@ -296,12 +332,12 @@ double PReLUCost::GetForwardComputationCost(const std::vector<TensorInfo>& input
return result;
}
// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double PReLUCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
const std::vector<mindspore::parallel::TensorInfo>&,
const int32_t& stage_id) const {
// In backward phase, the memory cost = (0 or 1) allreduce(slice(B))
// In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
double result = 0.0;
if (is_parameter_[1]) {
TensorInfo input1 = inputs[1]; // tensor B
......@@ -337,16 +373,16 @@ double OneHotCost::GetBackwardCommCost(const std::vector<TensorInfo>&, const std
return 0.0;
}
// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double OneHotCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
const int32_t&) const {
// In onehot's forward phase, the memory cost = slice(A)
// In onehot's forward phase, the computation cost = slice(A)
Shape input0_slice_shape = inputs[0].slice_shape();
return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
}
// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double OneHotCost::GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
const int32_t&) const {
......@@ -367,12 +403,12 @@ double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector<
return 0.0;
}
// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs,
const std::vector<TensorInfo>&,
const int32_t&) const {
// In forward phase, the memory cost = slice(A) + slice(B)
// In forward phase, the computation cost = slice(A) + slice(B)
Shape input0_slice_shape = inputs[0].slice_shape();
Shape input1_slice_shape = inputs[1].slice_shape();
double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
......@@ -380,7 +416,7 @@ double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::v
return result;
}
// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector<TensorInfo>&,
const std::vector<TensorInfo>&,
......@@ -410,7 +446,7 @@ double ReshapeCost::GetBackwardCommCost(const std::vector<TensorInfo>&, const st
return 0.0;
}
// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs,
const std::vector<TensorInfo>& outputs, const int32_t& stage_id) const {
......@@ -427,7 +463,7 @@ double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo>& inp
return (inputs_type_lengths_[0] * tensor_redistribution.computation_cost());
}
// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double ReshapeCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
const std::vector<mindspore::parallel::TensorInfo>&,
......
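`OperatorCost::GetMemoryCost` above combines three flags: whether the output is parameter-involved, whether each input is a parameter, and whether the operator's inputs are related. A simplified mirror of that rule with assumed plain types (it omits the any-of guard on `is_parameter_involve_` that the real code applies before scanning the inputs):

```cpp
#include <cstddef>
#include <vector>

// Per-device memory modeled for one operator: output slices are counted when
// the output is parameter-involved; an input slice is counted if it is itself
// a parameter, or if the inputs are "related" (needed for backward) and that
// input is not already parameter-involved (i.e. not accounted for upstream).
double ModelOperatorMemory(const std::vector<double>& output_bytes,
                           const std::vector<double>& input_bytes,
                           const std::vector<bool>& is_parameter,
                           const std::vector<bool>& is_parameter_involve,
                           bool output_parameter_involve, bool inputs_related) {
  if (!output_parameter_involve) {
    return 0.0;  // nothing from this operator needs to persist
  }
  double result = 0.0;
  for (double out : output_bytes) result += out;
  for (std::size_t i = 0; i < input_bytes.size(); ++i) {
    if (is_parameter[i]) {
      result += input_bytes[i];
    } else if (inputs_related && !is_parameter_involve[i]) {
      result += input_bytes[i];
    }
  }
  return result;
}
```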
......@@ -43,10 +43,20 @@ double ListProduct(std::vector<T> vec) {
// entries timing the length of each entry's data type
class OperatorCost {
public:
OperatorCost() {
explicit OperatorCost(bool is_inputs_related) : inputs_related_(is_inputs_related) {
// this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
is_parameter_.push_back(false);
is_parameter_involve_.push_back(false);
inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
}
}
OperatorCost() : inputs_related_(false) {
// this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
is_parameter_.push_back(false);
is_parameter_involve_.push_back(false);
inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
}
......@@ -54,6 +64,8 @@ class OperatorCost {
virtual ~OperatorCost() = default;
void set_is_parameter(const std::vector<bool>& is_parameter);
void set_is_parameter_involve(const std::vector<bool>&);
void set_output_parameter_involve(int);
void SetInputAndOutputTypeLength(const std::vector<size_t>& input_lengths, const std::vector<size_t>& output_lengths);
std::vector<size_t> inputs_type_lengths() const { return inputs_type_lengths_; }
std::vector<size_t> outputs_type_lengths() const { return outputs_type_lengths_; }
......@@ -72,8 +84,19 @@ class OperatorCost {
const std::vector<TensorInfo>& outputs, const int32_t& stage_id) const = 0;
virtual double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs,
const std::vector<TensorInfo>& outputs, const int32_t& stage_id) const = 0;
// per device PEAK memory cost in a training iteration
// Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-involved),
// plus necessary inputs.
virtual double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs) const;
protected:
// For each input in 'inputs_', the flag is true if the corresponding input is a parameter or an output of a
// preceding operator that has parameters as input.
std::vector<bool> is_parameter_involve_;
int output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved
// Whether the inputs are related. For example, TensorAdd's two inputs are independent (not related), while
// Mul's two inputs are dependent (related).
bool inputs_related_;
// for each input in 'inputs_', there is a bool variable indicating whether the corresponding input is a parameter
std::vector<bool> is_parameter_;
// for each input and output, the followings record the number of bytes of each element
......@@ -85,7 +108,8 @@ using OperatorCostPtr = std::shared_ptr<OperatorCost>;
class MatMulCost : public OperatorCost {
public:
MatMulCost() = default;
explicit MatMulCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
MatMulCost() : OperatorCost(true) {}
~MatMulCost() override = default;
// per device communication cost
......@@ -108,12 +132,12 @@ class MatMulCost : public OperatorCost {
double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
const int32_t& stage_id) const override;
};
using MatMulCostPtr = std::shared_ptr<MatMulCost>;
class ActivationCost : public OperatorCost {
public:
ActivationCost() = default;
explicit ActivationCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
ActivationCost() : OperatorCost(false) {}
~ActivationCost() override = default;
double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
......@@ -133,14 +157,14 @@ class ActivationCost : public OperatorCost {
double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
const int32_t& stage_id) const override;
};
using ActivationCostPtr = std::shared_ptr<ActivationCost>;
using TransposeCost = ActivationCost;
using TransposeCostPtr = std::shared_ptr<TransposeCost>;
class SoftmaxCost : public OperatorCost {
public:
SoftmaxCost() = default;
explicit SoftmaxCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
SoftmaxCost() : OperatorCost(false) {}
~SoftmaxCost() override = default;
double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
......@@ -160,12 +184,12 @@ class SoftmaxCost : public OperatorCost {
double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
const int32_t&) const override;
};
using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>;
class TmpIdentityCost : public OperatorCost {
public:
TmpIdentityCost() = default;
explicit TmpIdentityCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
TmpIdentityCost() : OperatorCost(false) {}
~TmpIdentityCost() override = default;
double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
......@@ -184,12 +208,15 @@ class TmpIdentityCost : public OperatorCost {
const int32_t& stage_id) const override;
double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
const int32_t& stage_id) const override;
// per device PEAK memory cost in a training iteration
double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs) const override;
};
using TmpIdentityCostPtr = std::shared_ptr<TmpIdentityCost>;
class BatchParallelCost : public OperatorCost {
public:
BatchParallelCost() = default;
explicit BatchParallelCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
BatchParallelCost() : OperatorCost(false) {}
~BatchParallelCost() override = default;
double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
......@@ -217,7 +244,8 @@ using BatchParallelCostPtr = std::shared_ptr<BatchParallelCost>;
class VirtualDatasetCost : public OperatorCost {
public:
VirtualDatasetCost() = default;
explicit VirtualDatasetCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
VirtualDatasetCost() : OperatorCost(false) {}
~VirtualDatasetCost() override = default;
double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
......@@ -244,12 +272,17 @@ class VirtualDatasetCost : public OperatorCost {
const int32_t&) const override {
return 0.0;
}
// per device PEAK memory cost in a training iteration
double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs) const override {
return 0.0;
}
};
using VirtualDatasetCostPtr = std::shared_ptr<VirtualDatasetCost>;
class GeneratorBaseCost : public OperatorCost {
public:
GeneratorBaseCost() = default;
explicit GeneratorBaseCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
GeneratorBaseCost() : OperatorCost(false) {}
~GeneratorBaseCost() override = default;
double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
......@@ -283,7 +316,8 @@ using GeneratorBaseCostPtr = std::shared_ptr<GeneratorBaseCost>;
class PReLUCost : public OperatorCost {
public:
PReLUCost() = default;
explicit PReLUCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
PReLUCost() : OperatorCost(true) {}
~PReLUCost() override = default;
// per device communication cost
......@@ -310,7 +344,8 @@ using PReLUCostPtr = std::shared_ptr<PReLUCost>;
class OneHotCost : public OperatorCost {
public:
OneHotCost() = default;
explicit OneHotCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
OneHotCost() : OperatorCost(true) {}
~OneHotCost() override = default;
// per device communication cost
......@@ -337,7 +372,8 @@ using OneHotCostPtr = std::shared_ptr<OneHotCost>;
class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost {
public:
SoftmaxCrossEntropyWithLogitsCost() = default;
explicit SoftmaxCrossEntropyWithLogitsCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
SoftmaxCrossEntropyWithLogitsCost() : OperatorCost(false) {}
~SoftmaxCrossEntropyWithLogitsCost() override = default;
// per device communication cost
......@@ -364,7 +400,8 @@ using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr<SoftmaxCrossEntropy
class ReshapeCost : public OperatorCost {
public:
ReshapeCost() = default;
explicit ReshapeCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
ReshapeCost() : OperatorCost(true) {}
~ReshapeCost() override = default;
......@@ -396,7 +433,8 @@ using ReshapeCostPtr = std::shared_ptr<ReshapeCost>;
class ArithmeticCost : public OperatorCost {
public:
ArithmeticCost() = default;
explicit ArithmeticCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
ArithmeticCost() : OperatorCost(false) {}
~ArithmeticCost() override = default;
double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
......@@ -425,7 +463,8 @@ using BiasAddCostPtr = std::shared_ptr<BiasAddCost>;
class ReduceMethodCost : public OperatorCost {
public:
ReduceMethodCost() = default;
explicit ReduceMethodCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
ReduceMethodCost() : OperatorCost(true) {}
~ReduceMethodCost() override = default;
double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
......@@ -455,7 +494,8 @@ using ReduceMethodCostPtr = std::shared_ptr<ReduceMethodCost>;
class ReduceMeanCost : public ReduceMethodCost {
public:
ReduceMeanCost() = default;
explicit ReduceMeanCost(bool is_inputs_related) : ReduceMethodCost(is_inputs_related) {}
ReduceMeanCost() : ReduceMethodCost(true) {}
~ReduceMeanCost() override = default;
double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
......@@ -465,7 +505,8 @@ using ReduceMeanCostPtr = std::shared_ptr<ReduceMeanCost>;
class GetNextCost : public OperatorCost {
public:
GetNextCost() = default;
explicit GetNextCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
GetNextCost() : OperatorCost(false) {}
~GetNextCost() override = default;
double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
......@@ -499,7 +540,8 @@ using GetNextCostPtr = std::shared_ptr<GetNextCost>;
class DropOutCost : public OperatorCost {
public:
DropOutCost() = default;
explicit DropOutCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
DropOutCost() : OperatorCost(true) {}
~DropOutCost() override = default;
double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
......@@ -530,7 +572,8 @@ using DropOutCostPtr = std::shared_ptr<DropOutCost>;
class GatherV2Cost : public OperatorCost {
public:
GatherV2Cost() = default;
explicit GatherV2Cost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
GatherV2Cost() : OperatorCost(true) {}
~GatherV2Cost() override = default;
double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
......
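With the constructor now taking `is_inputs_related`, each operator declares whether its inputs must be kept for the backward pass when it builds its cost object, as the ops_info changes below show. A small usage sketch (assuming the header path introduced in this commit):

```cpp
#include <memory>
#include "parallel/auto_parallel/operator_costmodel.h"

// Mul's backward pass needs both of its inputs (related); TensorAdd's does not.
void BuildExampleCosts() {
  auto mul_cost = std::make_shared<mindspore::parallel::ArithmeticCost>(true);
  auto add_cost = std::make_shared<mindspore::parallel::ArithmeticCost>(false);
  (void)mul_cost;
  (void)add_cost;
}
```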
......@@ -51,7 +51,7 @@ class Activation : public ActivationBase {
public:
Activation(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>()) {}
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(false)) {}
~Activation() override = default;
Status GenerateStrategies(int32_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
......@@ -102,7 +102,7 @@ class Softmax : public ActivationBase {
public:
explicit Softmax(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>()) {}
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>(false)) {}
~Softmax() override = default;
Status GenerateStrategies(int32_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
......
......@@ -32,8 +32,8 @@ namespace parallel {
class ArithmeticBase : public OperatorInfo {
public:
ArithmeticBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>()) {}
const PrimitiveAttrs& attrs, OperatorCostPtr cost)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {}
~ArithmeticBase() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;
......@@ -56,7 +56,7 @@ class ArithmeticBase : public OperatorInfo {
class SubInfo : public ArithmeticBase {
public:
SubInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~SubInfo() override = default;
};
......@@ -64,21 +64,21 @@ class TensorAddInfo : public ArithmeticBase {
public:
TensorAddInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~TensorAddInfo() override = default;
};
class MulInfo : public ArithmeticBase {
public:
MulInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~MulInfo() override = default;
};
class DivInfo : public ArithmeticBase {
public:
DivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~DivInfo() override = default;
};
......@@ -86,7 +86,7 @@ class RealDivInfo : public ArithmeticBase {
public:
RealDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~RealDivInfo() override = default;
};
......@@ -94,14 +94,14 @@ class FloorDivInfo : public ArithmeticBase {
public:
FloorDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~FloorDivInfo() override = default;
};
class PowInfo : public ArithmeticBase {
public:
PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~PowInfo() override = default;
};
......@@ -109,7 +109,7 @@ class GreaterInfo : public ArithmeticBase {
public:
GreaterInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~GreaterInfo() override = default;
};
......@@ -117,7 +117,7 @@ class AssignSubInfo : public ArithmeticBase {
public:
AssignSubInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~AssignSubInfo() override = default;
};
} // namespace parallel
......
......@@ -29,9 +29,13 @@ namespace mindspore {
namespace parallel {
class BatchParallelInfo : public OperatorInfo {
public:
BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs, OperatorCostPtr cost)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {}
BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()), dev_num_(1) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>(false)),
dev_num_(1) {}
~BatchParallelInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
......@@ -58,7 +62,7 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
public:
SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape,
const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: BatchParallelInfo(name, inputs_shape, outputs_shape, attrs) {}
: BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>(true)) {}
~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default;
void ReComputeBatchSplitFlagList() override;
};
......
......@@ -34,7 +34,7 @@ class BiasAddInfo : public OperatorInfo {
public:
BiasAddInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>()) {}
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>(false)) {}
~BiasAddInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
......
......@@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_
#include <string>
#include <memory>
#include <unordered_map>
#include <vector>
#include "ir/value.h"
......@@ -31,7 +32,7 @@ class EqualInfo : public ArithmeticBase {
public:
EqualInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~EqualInfo() override = default;
};
......@@ -39,7 +40,7 @@ class NotEqualInfo : public ArithmeticBase {
public:
NotEqualInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~NotEqualInfo() override = default;
};
......@@ -47,7 +48,7 @@ class MaximumInfo : public ArithmeticBase {
public:
MaximumInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~MaximumInfo() override = default;
};
......@@ -55,7 +56,7 @@ class MinimumInfo : public ArithmeticBase {
public:
MinimumInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~MinimumInfo() override = default;
};
} // namespace parallel
......
......@@ -33,7 +33,7 @@ class DropoutDoMaskInfo : public OperatorInfo {
public:
DropoutDoMaskInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>(true)) {}
~DropoutDoMaskInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
......
......@@ -32,7 +32,7 @@ class GetNextInfo : public OperatorInfo {
public:
GetNextInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>()) {}
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>(false)) {}
~GetNextInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
......
......@@ -36,7 +36,8 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
public:
SoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCrossEntropyWithLogitsCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs,
std::make_shared<SoftmaxCrossEntropyWithLogitsCost>(false)) {}
~SoftmaxCrossEntropyWithLogitsInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;
......
......@@ -593,11 +593,11 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr&
// Here, we use the origin outputs_, because we only use the slice size of the output tensor.
// It does not matter whether the output tensor is transposed or not.
double computation_cost =
cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
double communication_cost = cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
result->communication_without_parameter_ =
cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
result->communication_with_partial_para_ =
result->communication_without_parameter_ +
COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
......
......@@ -34,7 +34,7 @@ class MatMulBase : public OperatorInfo {
public:
MatMulBase(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>(true)) {}
~MatMulBase() override = default;
Status Init(const StrategyPtr& strategy) override;
......
......@@ -33,7 +33,7 @@ class OneHotInfo : public OperatorInfo {
public:
OneHotInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>(false)) {}
~OneHotInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;
......
......@@ -1035,11 +1035,12 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) {
return FAILED;
}
int32_t stage_id = strategy->GetInputStage();
double computation_cost = cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
double communication_cost = cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
double computation_cost =
operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
result->communication_without_parameter_ =
cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
result->communication_with_partial_para_ =
result->communication_without_parameter_ +
COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
......@@ -1096,7 +1097,38 @@ Status OperatorInfo::set_is_parameter(const std::vector<bool>& is_parameter) {
return FAILED;
}
is_parameter_ = is_parameter;
cost()->set_is_parameter(is_parameter);
operator_cost()->set_is_parameter(is_parameter);
return SUCCESS;
}
Status OperatorInfo::CalculateMemoryCost() {
// First, pass 'is_parameter_involve_' and 'is_output_parameter_involve_' to OperatorCost; both are needed to
// calculate the memory cost.
if (is_parameter_involve_.size() != is_parameter_.size()) {
MS_LOG(ERROR) << "The size of 'is_parameter_involve_' does not match the size of 'is_parameter_'.";
return FAILED;
}
operator_cost()->set_is_parameter_involve(is_parameter_involve_);
operator_cost()->set_output_parameter_involve(is_output_parameter_involve_);
// Set the memory cost in the 'strategy_cost_'
for (auto& swc : strategy_cost_) {
auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr);
swc->cost_list[0]->memory_with_reuse_ = mem_cost;
}
return SUCCESS;
}
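For orientation, the following is a minimal, editor-written sketch of the memory model this commit introduces (the output tensors of an operator, plus the inputs that are parameter-involved). The free function, its parameters, and the local list_product helper are illustrative assumptions, not the actual OperatorCost::GetMemoryCost implementation:
#include <cstddef>
#include <vector>
// Editor's sketch only (assumed helper, not MindSpore code).
// Memory cost = sum of output slice sizes, plus the slice sizes of inputs that are
// parameter-involved (they stay resident until the backward phase), each weighted
// by the byte length of its element type.
double SketchMemoryCost(const std::vector<std::vector<int>>& output_slice_shapes,
                        const std::vector<size_t>& output_type_lengths,
                        const std::vector<std::vector<int>>& input_slice_shapes,
                        const std::vector<size_t>& input_type_lengths,
                        const std::vector<bool>& is_parameter_involve) {
  auto list_product = [](const std::vector<int>& shape) {
    double prod = 1.0;
    for (int dim : shape) {
      prod *= static_cast<double>(dim);
    }
    return prod;
  };
  double result = 0.0;
  for (size_t i = 0; i < output_slice_shapes.size(); ++i) {
    result += list_product(output_slice_shapes[i]) * static_cast<double>(output_type_lengths[i]);
  }
  for (size_t i = 0; i < input_slice_shapes.size(); ++i) {
    if (is_parameter_involve[i]) {
      result += list_product(input_slice_shapes[i]) * static_cast<double>(input_type_lengths[i]);
    }
  }
  return result;
}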
Status OperatorInfo::CorrectMemoryCost(size_t input_index) {
for (auto& swc : strategy_cost_) {
double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) *
static_cast<double>(operator_cost()->inputs_type_lengths()[input_index]);
swc->cost_list[0]->memory_with_reuse_ -= parameter_mem_cost;
if (swc->cost_list[0]->memory_with_reuse_ < 0) {
MS_LOG(ERROR) << "The memory cost after correction is: " << swc->cost_list[0]->memory_with_reuse_
<< ", the parameter memory cost is: " << parameter_mem_cost;
return FAILED;
}
}
return SUCCESS;
}
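A hedged worked example for CorrectMemoryCost (shapes and type length assumed by the editor): for a parameter input whose slice shape is [1024, 512] with a 4-byte element type, parameter_mem_cost = 1024 * 512 * 4 = 2097152. If that parameter (for example, one routed through TmpIdentity) feeds several operators, a correction pass such as CorrectOpsMemoryCost in the hunk further below can subtract this amount from memory_with_reuse_ wherever it would otherwise be double-counted, so the parameter is charged only once overall.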
......@@ -1193,7 +1225,7 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t>& inpu
}
inputs_type_lengths_ = input_lengths;
outputs_type_lengths_ = output_lengths;
cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
operator_cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
return SUCCESS;
}
......@@ -1221,7 +1253,7 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr& stra
}
double OperatorInfo::GetForwardMemoryCostFromCNode() {
return cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
}
} // namespace parallel
......
......@@ -60,7 +60,7 @@ class OperatorInfo {
outputs_shape_(std::move(outputs_shape)),
attrs_(std::move(attrs)),
is_alive_(true),
cost_(cost),
operator_cost_(cost),
outputs_type_() {
std::vector<bool> not_parameter(inputs_shape_.size(), false);
is_parameter_ = not_parameter;
......@@ -83,8 +83,8 @@ class OperatorInfo {
// Given the stage_id (which indicates the number of devices),
// generate all strategies for this operator
virtual Status GenerateStrategies(int32_t stage_id) = 0;
const OperatorCostPtr& cost() const { return cost_; }
void set_cost(const OperatorCostPtr& cost) { cost_ = cost; }
const OperatorCostPtr& operator_cost() const { return operator_cost_; }
void set_cost(const OperatorCostPtr& cost) { operator_cost_ = cost; }
virtual Status SetCostUnderStrategy(const StrategyPtr& strategy) = 0;
virtual std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies();
......@@ -98,7 +98,7 @@ class OperatorInfo {
std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; }
// When an input of an operator is a WEIGHT, or is an output of other operators involving a WEIGHT, that input
// must stay in memory until it is used in the backward phase, i.e. it is still held at the end of the forward phase.
Status CalculateMemoryCost() const { return SUCCESS; }
Status CalculateMemoryCost();
int ComputeOpAndPrevEdgeParameterInvolved();
ForwardOp forward_op() const { return forward_op_; }
......@@ -125,7 +125,7 @@ class OperatorInfo {
void ReplaceSuccEdge(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
void ReplacePreEdges(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
void ReplaceSuccEdges(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
std::vector<size_t> GetOutputTypeLengths() const { return cost()->outputs_type_lengths(); }
std::vector<size_t> GetOutputTypeLengths() const { return operator_cost()->outputs_type_lengths(); }
void SetSelectedStrategyAndCost(const StrategyPtr& s_strategy, const CostPtr& cost) {
selected_strategy_ = s_strategy;
selected_cost_ = cost;
......@@ -142,6 +142,10 @@ class OperatorInfo {
void set_strategy(const StrategyPtr& strategy) { strategy_ = strategy; }
void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); }
const std::string& refkey_parameter_name() const { return refkey_parameter_name_; }
// When the output of a Parameter (requires_grad) is used by multiple operators, the Parameter's cost is calculated
// multiple times. This method corrects that, so the cost is calculated only once.
Status CorrectMemoryCost(size_t input_index);
int is_output_parameter_involve() const { return is_output_parameter_involve_; }
int used_devices() const { return used_devices_; }
// needed by rec_parser
void set_type(const std::string& type) { type_ = type; }
......@@ -234,7 +238,7 @@ class OperatorInfo {
int32_t used_devices_ = -1;
private:
OperatorCostPtr cost_;
OperatorCostPtr operator_cost_;
std::vector<TypePtr> outputs_type_;
};
......
......@@ -35,7 +35,7 @@ class PReLUInfo : public OperatorInfo {
public:
PReLUInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>(true)) {}
~PReLUInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;
......
......@@ -109,7 +109,7 @@ Status ReduceMethod::GetAttrs() {
}
cross_batch_ = cross_batch_iter->second->cast<BoolImmPtr>()->value();
}
auto reducemethodcost = std::dynamic_pointer_cast<ReduceMethodCost>(cost());
auto reducemethodcost = std::dynamic_pointer_cast<ReduceMethodCost>(operator_cost());
if (reducemethodcost == nullptr) {
MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!";
return FAILED;
......
......@@ -34,7 +34,7 @@ class ReduceMethod : public OperatorInfo {
public:
ReduceMethod(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMethodCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMethodCost>(true)) {}
~ReduceMethod() override = default;
Status Init(const StrategyPtr &strategy) override;
......
......@@ -36,7 +36,7 @@ class ReshapeInfo : public OperatorInfo {
public:
ReshapeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>()),
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>(false)),
dev_num_(0),
input_layout_set_flag_(false),
output_layout_set_flag_(false) {}
......
......@@ -34,7 +34,7 @@ class TmpIdentityInfo : public OperatorInfo {
public:
TmpIdentityInfo(const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs,
const std::string& name = IDENTITY_INFO)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>(false)) {}
~TmpIdentityInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
......
......@@ -35,7 +35,7 @@ class TransposeInfo : public OperatorInfo {
public:
TransposeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>(false)) {}
~TransposeInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;
......
......@@ -32,7 +32,7 @@ class VirtualDatasetInfo : public OperatorInfo {
public:
VirtualDatasetInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>(false)) {}
~VirtualDatasetInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;
......
......@@ -874,11 +874,15 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
if (entire_costgraph->ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
// Calculate operators' memory usage
if (entire_costgraph->CalculateOpsMemoryCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Correcting operators' cost for memory reuse failed.";
MS_LOG(EXCEPTION) << "Calculating operators' cost for memory cost failed.";
}
// Calculate edges' memory usage
if (entire_costgraph->CalculateEdgesMemoryCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Correcting edges' cost for memory reuse failed.";
MS_LOG(EXCEPTION) << "Calculating edges' cost for memory cost failed.";
}
// Correct memory usage caused by TmpIdentity
if (entire_costgraph->CorrectOpsMemoryCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Correcting operators' cost for memory cost failed.";
}
} else {
MS_LOG(EXCEPTION) << "Computing operators' parameter_involved failed.";
......
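To make the call order in the hunk above easier to follow, here is an editor's condensed sketch of the memory-cost pass added by this commit; MemoryCostPass is a hypothetical wrapper, while the four graph methods are the ones shown in the diff:
// Editor's sketch of the pass order (control flow simplified, error logging omitted).
Status MemoryCostPass(const CostGraphPtr& entire_costgraph) {
  // 1. Mark which operators and edges involve parameters (needed by the memory model).
  if (entire_costgraph->ComputeOpsAndEdgesParameterInvolved() != SUCCESS) {
    return FAILED;
  }
  // 2. Per-operator memory usage (outputs plus parameter-involved inputs).
  if (entire_costgraph->CalculateOpsMemoryCost() != SUCCESS) {
    return FAILED;
  }
  // 3. Per-edge memory usage contributed by tensor redistribution.
  if (entire_costgraph->CalculateEdgesMemoryCost() != SUCCESS) {
    return FAILED;
  }
  // 4. Remove double counting caused by TmpIdentity (a parameter feeding several operators).
  return entire_costgraph->CorrectOpsMemoryCost();
}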
......@@ -159,6 +159,7 @@ Status TensorRedistribution::ComputeCost() {
backward_comm_cost_ += prod;
comm_cost_ += 2.0 * prod;
computation_cost_ += prod;
memory_cost_ += prod;
} else if (str == CONCAT_BY_AXIS) {
// communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
// computation cost = before_slice_shape
......@@ -175,20 +176,25 @@ Status TensorRedistribution::ComputeCost() {
if (concat_dim == 0) {
// computation cost = all_gather
computation_cost_ += prod;
memory_cost_ += prod * dev_num;
} else {
// computation cost = all_gather + split + concat
computation_cost_ += (prod + prod * dev_num + prod * dev_num);
memory_cost_ += (prod * dev_num + prod * dev_num + prod);
}
} else {
// There is only computation cost in SplitByAxis.
// computation cost = before_slice_shape
computation_cost_ += prod;
// This addition may be erroneous
memory_cost_ += prod;
}
}
if (reshape_flag()) {
Shape prev_slice_shape = from_.slice_shape().array();
double prev_prod = std::accumulate(prev_slice_shape.begin(), prev_slice_shape.end(), 1, std::multiplies<int>());
computation_cost_ += 2.0 * prev_prod;
memory_cost_ += 2.0 * prev_prod;
}
return Status::SUCCESS;
}
......
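A hedged numeric illustration of the accumulation in TensorRedistribution::ComputeCost above (values assumed by the editor): with a slice of 4096 elements (prod = 4096) redistributed across dev_num = 8 devices via CONCAT_BY_AXIS with concat_dim != 0, that branch adds prod * dev_num + prod * dev_num + prod = 32768 + 32768 + 4096 = 69632 to memory_cost_, whereas the concat_dim == 0 branch would add only prod * dev_num = 32768.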
......@@ -42,6 +42,7 @@ class TensorRedistribution {
forward_comm_cost_(0.0),
backward_comm_cost_(0.0),
computation_cost_(0.0),
memory_cost_(0.0),
construct_op_flag_(construct_op_flag),
keep_reshape_(keep_reshape) {}
Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list);
......@@ -54,6 +55,7 @@ class TensorRedistribution {
double computation_cost() const { return computation_cost_; }
double forward_comm_cost() const { return forward_comm_cost_; }
double backward_comm_cost() const { return backward_comm_cost_; }
double memory_cost() const { return memory_cost_; }
private:
Status InferReshape(const TensorLayout& from_layout, const TensorLayout& to_layout,
......@@ -72,7 +74,12 @@ class TensorRedistribution {
double forward_comm_cost_;
// backward communication cost
double backward_comm_cost_;
// computation_cost_ models the time spent on computation in this tensor redistribution, which is calculated from
// the inputs.
double computation_cost_;
// memory_cost_ models the PEAK memory cost contributed by this tensor redistribution in a training iteration, which
// is calculated from the outputs.
double memory_cost_;
bool construct_op_flag_;
bool keep_reshape_;
};
......
......@@ -84,9 +84,9 @@ TEST_F(TestActivation, test_activation_strategies) {
act_ptr_->InitForCostModel(sp);
std::vector<TensorInfo> inputs_info = act_ptr_->inputs_tensor_info();
std::vector<TensorInfo> outputs_info = act_ptr_->outputs_tensor_info();
ASSERT_DOUBLE_EQ(act_ptr_->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(act_ptr_->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.computation_cost_);
ASSERT_DOUBLE_EQ(act_ptr_->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(act_ptr_->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.communication_cost_);
}
}
......@@ -109,9 +109,9 @@ TEST_F(TestActivation, test_softmax_strategies) {
soft_ptr_->InitForCostModel(sp);
std::vector<TensorInfo> inputs_info = soft_ptr_->inputs_tensor_info();
std::vector<TensorInfo> outputs_info = soft_ptr_->outputs_tensor_info();
ASSERT_DOUBLE_EQ(soft_ptr_->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(soft_ptr_->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.computation_cost_);
ASSERT_DOUBLE_EQ(soft_ptr_->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(soft_ptr_->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.communication_cost_);
}
}
......
......@@ -569,7 +569,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies1) {
matmul1->InitForCostModel(sp);
std::vector<TensorInfo> inputs_info = matmul1->inputs_tensor_info();
std::vector<TensorInfo> outputs_info = matmul1->outputs_tensor_info();
ASSERT_DOUBLE_EQ(matmul1->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(matmul1->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.computation_cost_);
break;
}
......@@ -599,7 +599,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies2) {
TensorInfo replica_input1_info(tly, input1_shape, input1_slice_shape);
replica_inputs_info.push_back(replica_input1_info);
ASSERT_DOUBLE_EQ(matmul3->cost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(matmul3->operator_cost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
cost.computation_cost_);
break;
}
......
......@@ -188,11 +188,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies) {
tensor_add->InitForCostModel(sp);
std::vector<TensorInfo> inputs_info = tensor_add->inputs_tensor_info();
std::vector<TensorInfo> outputs_info = tensor_add->outputs_tensor_info();
double memory_cost0 = tensor_add->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
double memory_cost0 = tensor_add->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
double memory_cost1 = cost.computation_cost_;
bool memory = memory_cost0 - memory_cost1 <= 1.0;
double comm_cost0 = tensor_add->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
double comm_cost0 = tensor_add->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
double comm_cost1 = cost.communication_cost_;
bool comm = comm_cost0 - comm_cost1 <= 1.0;
......@@ -210,11 +210,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies1) {
tensor_add1->InitForCostModel(sp);
std::vector<TensorInfo> inputs_info = tensor_add1->inputs_tensor_info();
std::vector<TensorInfo> outputs_info = tensor_add1->outputs_tensor_info();
double memory_cost0 = tensor_add1->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
double memory_cost0 = tensor_add1->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
double memory_cost1 = cost.computation_cost_;
bool memory = memory_cost0 - memory_cost1 <= 1.0;
double comm_cost0 = tensor_add1->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
double comm_cost0 = tensor_add1->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
double comm_cost1 = cost.communication_cost_;
bool comm = comm_cost0 - comm_cost1 <= 1.0;
......
......@@ -145,9 +145,9 @@ TEST_F(TestTmpIdentityInfo, test_generate_strategies) {
identity_ptr->Init(sp);
std::vector<TensorInfo> inputs_info = identity_ptr->inputs_tensor_info();
std::vector<TensorInfo> outputs_info = identity_ptr->outputs_tensor_info();
ASSERT_DOUBLE_EQ(identity_ptr->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(identity_ptr->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.computation_cost_);
ASSERT_DOUBLE_EQ(identity_ptr->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(identity_ptr->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.communication_cost_);
}
}
......