提交 9f4b8a3c 编写于 作者: X Xiaoda Zhang

changing the succive edges order in GetAliveSuccEdges() so that Triangle and...

changing the succive edges order in GetAliveSuccEdges() so that Triangle and Star Elimination can be merged into particular node; adding some check information
上级 d9c74e0a
......@@ -211,13 +211,14 @@ struct ContractEliminationDecision : public Decision {
*/
struct TriangleEliminationDecision : public Decision {
TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost,
StrategyPtr left_stra, CostPtr l_node_cost)
StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra)
: eliminated_op_strategy_(std::move(elimi_stra)),
eliminated_op_cost_(std::move(elimi_op_cost)),
left_edge_cost_(std::move(l_edge_cost)),
right_edge_cost_(std::move(r_edge_cost)),
left_node_strategy_(std::move(left_stra)),
left_node_cost_(std::move(l_node_cost)) {
left_node_cost_(std::move(l_node_cost)),
right_node_strategy_(std::move(right_stra)) {
type_ = DecisionType::TRIANGLE_ELIMINATION;
}
......@@ -227,6 +228,7 @@ struct TriangleEliminationDecision : public Decision {
CostPtr right_edge_cost_;
StrategyPtr left_node_strategy_;
CostPtr left_node_cost_;
StrategyPtr right_node_strategy_;
MS_DECLARE_PARENT(TriangleEliminationDecision, Decision);
};
......
......@@ -85,7 +85,9 @@ Status GetStrategy(const CostGraphPtr &graph) {
right_edge = tmp;
}
auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge);
auto elimi = std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge);
auto right_node = l_r_edge->next_operator();
auto elimi =
std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node);
eliminations.emplace_back(std::move(elimi));
}
auto star_center = graph->CheckStarElimination();
......@@ -181,6 +183,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
auto left_edge = elimination->left_edge_;
auto eliminated_node = elimination->eliminated_node_;
auto right_edge = elimination->right_edge_;
auto right_node = elimination->right_node_;
auto decision = left_node->selected_cost()->decision_ptr_->cast<TriangleEliminationDecisionPtr>();
eliminated_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_);
......@@ -188,6 +191,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
right_edge->set_selected_cost(decision->right_edge_cost_);
// Since Triangle is eliminated into 'left_node', only 'left_node' is needed to recover the strategy.
left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_);
right_node->CheckSelectedStrategy(decision->right_node_strategy_);
MS_LOG(INFO) << "Recover triangleElimination succeeded.";
} else if ((*rit)->isa<StarElimination>()) {
auto elimination = (*rit)->cast<StarEliminationPtr>();
......@@ -206,6 +210,9 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]);
// Since Star is eliminated into 'succ_nodes[0]', only 'succ_nodes[0]' is needed to recover the strategy.
succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]);
for (size_t k = 1; k < succ_nodes.size(); ++k) {
succ_nodes[k]->CheckSelectedStrategy(decision->succ_ops_stra_list_[k]);
}
MS_LOG(INFO) << "Recover starElimination succeeded.";
} else {
MS_LOG(ERROR) << "Unknown Elimination type.";
......
......@@ -102,17 +102,20 @@ struct ContractElimination : public Elimination {
// Triangle Elimination
struct TriangleElimination : public Elimination {
TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge)
TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge,
OperatorInfoPtr r_node)
: Elimination(nullptr, Elimination::EliminationType::TRIANGLE),
eliminated_node_(std::move(elim_node)),
left_edge_(std::move(l_edge)),
left_node_(std::move(l_node)),
right_edge_(std::move(r_edge)) {}
right_edge_(std::move(r_edge)),
right_node_(std::move(r_node)) {}
OperatorInfoPtr eliminated_node_;
EdgePtr left_edge_;
OperatorInfoPtr left_node_;
EdgePtr right_edge_;
OperatorInfoPtr right_node_;
MS_DECLARE_PARENT(TriangleElimination, Elimination);
};
......
......@@ -1111,8 +1111,8 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
auto decision = std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost,
right_edge_cost, left_op_stra, left_node_cost);
auto decision = std::make_shared<TriangleEliminationDecision>(
elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost, left_op_stra, left_node_cost, right_op_stra);
auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
new_cost->communication_without_parameter_ = new_commu_without;
new_cost->communication_with_partial_para_ =
......
......@@ -546,10 +546,14 @@ std::vector<std::shared_ptr<Edge>> OperatorInfo::GetAliveSuccEdges() {
for (auto &edge : succ_edges_) {
if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) != std::string::npos)) {
ret.push_back(edge);
} else if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(CAST) != std::string::npos)) {
// CAST is ordered in front of L2NORMALIZE
ret.push_back(edge);
}
}
for (auto &edge : succ_edges_) {
if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos)) {
if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos) &&
(edge->next_operator()->name().find(CAST) == std::string::npos)) {
ret.push_back(edge);
}
}
......@@ -1279,10 +1283,18 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra
CheckGlobalDeviceManager();
auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size();
if (IntToSize(stra->GetInputDim()[0][0]) == total_device_num) {
cost->computation_cost_ -= 1.0;
cost->communication_cost_ -= 1.0;
cost->communication_with_partial_para_ -= 1.0;
cost->communication_without_parameter_ -= 1.0;
if (cost->computation_cost_ > 1.0) {
cost->computation_cost_ -= 1.0;
}
if (cost->communication_cost_ > 1.0) {
cost->communication_cost_ -= 1.0;
}
if (cost->communication_with_partial_para_ > 1.0) {
cost->communication_with_partial_para_ -= 1.0;
}
if (cost->communication_without_parameter_ > 1.0) {
cost->communication_without_parameter_ -= 1.0;
}
}
}
}
......@@ -1290,5 +1302,15 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra
double OperatorInfo::GetForwardMemoryCostFromCNode() {
return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
}
void OperatorInfo::CheckSelectedStrategy(const StrategyPtr &s_strategy) {
MS_EXCEPTION_IF_NULL(s_strategy);
if (!s_strategy->IsEqual(selected_strategy_)) {
MS_LOG(INFO) << name() << "'s strategy may cause suboptimal, the determined strategy:";
PrintStrategy(selected_strategy_);
MS_LOG(INFO) << "The minimal strategy:";
PrintStrategy(s_strategy);
}
}
} // namespace parallel
} // namespace mindspore
......@@ -138,6 +138,7 @@ class OperatorInfo {
}
StrategyPtr selected_strategy() const { return selected_strategy_; }
CostPtr selected_cost() const { return selected_cost_; }
void CheckSelectedStrategy(const StrategyPtr &);
Status InitSelectedStrategy(const StrategyPtr &s_strategy) { return Init(s_strategy); }
void set_input_value(const std::vector<ValuePtr> &input_value) { input_value_ = input_value; }
const std::vector<ValuePtr> &input_value() const { return input_value_; }
......
......@@ -48,6 +48,16 @@ class Strategy {
}
void ResetInputs(const std::vector<Dimensions> &input) { inputs_ = input; }
bool IsEqual(const StrategyPtr &another_stra) {
if (another_stra == nullptr) {
return false;
}
if ((stage_ != another_stra->GetInputStage()) || (inputs_ != another_stra->GetInputDim())) {
return false;
}
return true;
}
private:
const int32_t stage_;
......
......@@ -64,5 +64,23 @@ TEST_F(TestStrategy, GetInputDim) {
ASSERT_EQ(inputs, inputs_test);
}
TEST_F(TestStrategy, IsEqual) {
int32_t stage1 = 0, stage2 = 0, stage3 = 1, stage4 = 0;
std::vector<int32_t> dimension1 = {8, 1};
std::vector<int32_t> dimension2 = {1, 8};
std::vector<std::vector<int32_t>> inputs1 = {dimension1};
std::vector<std::vector<int32_t>> inputs2 = {dimension1};
std::vector<std::vector<int32_t>> inputs3 = {dimension2};
std::vector<std::vector<int32_t>> inputs4 = {dimension1, dimension2};
StrategyPtr stra1 = std::make_shared<Strategy>(stage1, inputs1);
StrategyPtr stra2 = std::make_shared<Strategy>(stage2, inputs2);
StrategyPtr stra3 = std::make_shared<Strategy>(stage3, inputs3);
StrategyPtr stra4 = std::make_shared<Strategy>(stage4, inputs4);
ASSERT_EQ(stra1->IsEqual(stra2), true);
ASSERT_EQ(stra1->IsEqual(stra3), false);
ASSERT_EQ(stra1->IsEqual(stra4), false);
}
} // namespace parallel
} // namespace mindspore
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册