提交 9a717aa1 编写于 作者: X Xiaoda Zhang 提交者: 高东海

change_star_elimination: make the non-identity triangle_eliminatin exact

上级 29ab2c10
......@@ -948,10 +948,12 @@ OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr& op) {
return target_op;
}
void CostGraph::CreateTriangleEliminationSubCostListForIdentity(
StrategyPtr elimi_op_stra, StrategyPtr left_op_stra, StrategyPtr right_op_stra, const CostPtr& right_op_cost,
const CostPtrList& elimi_op_clist, const CostPtrList& left_edge_clist, const CostPtr& right_edge_cost,
const CostPtrList& left_node_clist_origin, CostPtrList* left_node_clist_new) {
void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra,
StrategyPtr right_op_stra, const CostPtr& right_op_cost,
const CostPtrList& elimi_op_clist,
const CostPtrList& left_edge_clist, const CostPtr& right_edge_cost,
const CostPtrList& left_node_clist_origin,
CostPtrList* left_node_clist_new) {
MS_EXCEPTION_IF_NULL(right_edge_cost);
MS_EXCEPTION_IF_NULL(right_op_cost);
MS_EXCEPTION_IF_NULL(left_node_clist_new);
......@@ -985,93 +987,20 @@ void CostGraph::CreateTriangleEliminationSubCostListForIdentity(
}
}
void CostGraph::CreateTriangleEliminationSubCostListForOthers(
StrategyPtr elimi_op_stra, StrategyPtr left_node_stra, StrategyPtr right_node_stra, const CostPtr& right_op_cost,
const CostPtrList& elimi_op_clist, const CostPtrList& left_edge_clist, const CostPtr& right_edge_cost,
const CostPtrList& left_node_clist_origin, CostPtrList* left_node_clist_new) {
CostPtr elimi_op_determined = nullptr, left_edge_determined = nullptr, init_ele = nullptr;
std::function<CostPtr(CostPtr, const CostPtr&)> LocalCompare = [&](CostPtr init, const CostPtr& cost_x) {
MS_EXCEPTION_IF_NULL(cost_x);
if ((init == nullptr) || (cost_x->memory_cost_ < DEVICE_MEMORY_CAPACITY)) {
init = cost_x;
}
return init;
};
// Find a feasible elimi_op_clist
elimi_op_determined = std::accumulate(elimi_op_clist.begin(), elimi_op_clist.end(), init_ele, LocalCompare);
init_ele = nullptr;
// Find a feasible left_edge_cost
left_edge_determined = std::accumulate(left_edge_clist.begin(), left_edge_clist.end(), init_ele, LocalCompare);
if ((elimi_op_determined == nullptr) || (left_edge_determined == nullptr)) {
return;
}
if ((elimi_op_determined->memory_cost_ >= DEVICE_MEMORY_CAPACITY) ||
(left_edge_determined->memory_cost_ >= DEVICE_MEMORY_CAPACITY)) {
return;
}
for (auto& left_node_cost : left_node_clist_origin) {
MS_EXCEPTION_IF_NULL(left_node_cost);
MS_EXCEPTION_IF_NULL(right_op_cost);
double new_memory_cost = left_node_cost->memory_cost_ + elimi_op_determined->memory_cost_ +
left_edge_determined->memory_cost_ + right_edge_cost->memory_cost_ +
right_op_cost->memory_cost_;
double commu_cost = left_node_cost->communication_cost_ + elimi_op_determined->communication_cost_ +
left_edge_determined->communication_cost_ + right_edge_cost->communication_cost_ +
right_op_cost->communication_cost_;
double commu_without =
left_node_cost->communication_without_parameter_ + elimi_op_determined->communication_without_parameter_ +
left_edge_determined->communication_without_parameter_ + right_edge_cost->communication_without_parameter_ +
right_op_cost->communication_without_parameter_;
auto decision = std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_determined,
left_edge_determined, right_edge_cost, left_node_stra,
left_node_cost, right_node_stra, right_op_cost);
auto new_cost = std::make_shared<Cost>(new_memory_cost, commu_cost, decision);
new_cost->communication_without_parameter_ = commu_without;
new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without);
left_node_clist_new->emplace_back(std::move(new_cost));
}
}
void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr& elimi_op, const CostPtrList& right_node_clist,
const CostPtrList& right_edge_clist, const StrategyPtr& elimi_op_stra,
const StrategyPtr& left_node_stra, const StrategyPtr& right_node_stra,
const CostPtrList& elimi_op_clist, const CostPtrList& left_edge_clist,
const CostPtrList& left_node_clist_origin,
CostPtrList* left_node_clist_new) {
// The reason for separately dealing with when the 'elimi_op' is 'TMPIDENTITY_INFO' or others is that
// when 'elimi_op' is TMPIDENTITY_INFO, the computation is limited, while 'elimi_op' is others, the computation
// may be huge
MS_EXCEPTION_IF_NULL(elimi_op);
if (elimi_op->name().find(TMPIDENTITY_INFO_NAME) != std::string::npos) {
for (auto& right_node_cost : right_node_clist) {
MS_EXCEPTION_IF_NULL(right_node_cost);
for (auto& right_edge_cost : right_edge_clist) {
MS_EXCEPTION_IF_NULL(right_edge_cost);
if ((right_node_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY) &&
(right_edge_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY)) {
// Exact computation for TMPIDENTITY_INFO_NAME case
CreateTriangleEliminationSubCostListForIdentity(elimi_op_stra, left_node_stra, right_node_stra,
right_node_cost, elimi_op_clist, left_edge_clist,
right_edge_cost, left_node_clist_origin, left_node_clist_new);
}
}
}
} else {
for (auto& right_node_cost : right_node_clist) {
MS_EXCEPTION_IF_NULL(right_node_cost);
for (auto& right_edge_cost : right_edge_clist) {
MS_EXCEPTION_IF_NULL(right_edge_cost);
if ((right_node_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY) &&
(right_edge_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY)) {
// Approximate computation for other case
CreateTriangleEliminationSubCostListForOthers(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost,
elimi_op_clist, left_edge_clist, right_edge_cost,
left_node_clist_origin, left_node_clist_new);
}
}
for (auto& right_node_cost : right_node_clist) {
MS_EXCEPTION_IF_NULL(right_node_cost);
for (auto& right_edge_cost : right_edge_clist) {
MS_EXCEPTION_IF_NULL(right_edge_cost);
CreateTriangleEliminationSubCostList(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost,
elimi_op_clist, left_edge_clist, right_edge_cost, left_node_clist_origin,
left_node_clist_new);
}
}
}
......
......@@ -163,14 +163,9 @@ class CostGraph {
void CreateTriangleEliminationCostList(const OperatorInfoPtr&, const CostPtrList&, const CostPtrList&,
const StrategyPtr&, const StrategyPtr&, const StrategyPtr&, const CostPtrList&,
const CostPtrList&, const CostPtrList&, CostPtrList*);
// Given the relevant costlist, create the TriangleElimination cost for eliminating TmpIdentityInfo
void CreateTriangleEliminationSubCostListForIdentity(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr&,
const CostPtrList&, const CostPtrList&, const CostPtr&,
const CostPtrList&, CostPtrList*);
// Given the relevant costlist, create the TriangleElimination cost for eliminating other operators
void CreateTriangleEliminationSubCostListForOthers(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr&,
const CostPtrList&, const CostPtrList&, const CostPtr&,
const CostPtrList&, CostPtrList*);
// Given the relevant costlist, create the TriangleElimination cost
void CreateTriangleEliminationSubCostList(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr&, const CostPtrList&,
const CostPtrList&, const CostPtr&, const CostPtrList&, CostPtrList*);
// Applying the Star Elimination in DP algorithm. Return the successive edges of this merged_op
// NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册