From 716def7c0a64385977e78464e73035726a8770d3 Mon Sep 17 00:00:00 2001 From: yao_yf Date: Mon, 11 May 2020 10:39:52 +0800 Subject: [PATCH] move InferStraByTensorInfo to tensor_info.h --- .../parallel/auto_parallel/graph_costmodel.cc | 6 +- .../ccsrc/parallel/ops_info/reshape_info.cc | 52 +++++++++++++++++ .../ccsrc/parallel/ops_info/reshape_info.h | 3 + .../ccsrc/parallel/step_auto_parallel.cc | 56 ++----------------- mindspore/ccsrc/parallel/step_auto_parallel.h | 2 - .../parallel/tensor_layout/tensor_info.h | 11 ++++ .../parallel/test_auto_parallel_reshape.py | 6 ++ 7 files changed, 80 insertions(+), 56 deletions(-) diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc index d376e3221..9930cf704 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc @@ -1377,7 +1377,6 @@ Status CostGraph::InitSelectedStrategy() { if (pre_iter != in_edges.end()) { MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name(); int32_t pre_index = reshape_info->pre_operator_index(); - Dimensions stra; TensorInfo pre_info; if (ops_[i]->name() == (*pre_iter)->prev_operator()->name()) { pre_info = (*pre_iter)->prev_operator()->inputs_tensor_info()[pre_index]; @@ -1385,7 +1384,10 @@ Status CostGraph::InitSelectedStrategy() { pre_info = (*pre_iter)->prev_operator()->outputs_tensor_info()[pre_index]; } reshape_info->SetInputLayout(pre_info.tensor_layout()); - InferStraByTensorInfo(pre_info, &stra); + Dimensions stra = pre_info.InferStrategy(); + if (stra.empty()) { + MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed"; + } std::vector stra_inputs = {stra}; StrategyPtr reshape_stra = std::make_shared((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs); diff --git a/mindspore/ccsrc/parallel/ops_info/reshape_info.cc b/mindspore/ccsrc/parallel/ops_info/reshape_info.cc index b191e2219..40b8b79c4 100644 --- a/mindspore/ccsrc/parallel/ops_info/reshape_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/reshape_info.cc @@ -440,5 +440,57 @@ Status ReshapeInfo::GenerateStrategies(int32_t stage_id) { } return SUCCESS; } + +Status ReshapeInfo::GenetateStrategyCosts(const std::vector> &pre_stra_costs, + const std::vector> &next_stra_costs, + int32_t out_index, int32_t in_index, bool is_prev_param) { + for (auto pre_stra_cost : pre_stra_costs) { + std::vector pre_out_tensor_infos; + if (is_prev_param) { + pre_out_tensor_infos = pre_stra_cost->inputs_ptr; + } else { + pre_out_tensor_infos = pre_stra_cost->outputs_ptr; + } + if (pre_out_tensor_infos.size() <= IntToSize(out_index)) { + MS_LOG(ERROR) << "out_index is out of range of the tensor_infos in setting reshape's input_layout"; + return FAILED; + } + TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index]; + TensorLayout pre_out_tensor_layout = pre_out_tensor_info.tensor_layout(); + SetInputLayout(pre_out_tensor_layout); + // infer pre_node output strategy from output_layout. + Dimensions stra = pre_out_tensor_info.InferStrategy(); + if (stra.empty()) { + MS_LOG(ERROR) << "Infer strategy by tensor_info failed"; + return FAILED; + } + std::vector stra_inputs = {stra}; + StrategyPtr reshape_stra = std::make_shared(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs); + if (next_stra_costs.empty()) { + if (Init(nullptr) == FAILED) { + MS_LOG(ERROR) << "Failure:operator reshape init failed"; + return FAILED; + } + SetCostForReshape(reshape_stra); + continue; + } + for (auto next_stra_cost : next_stra_costs) { + std::vector next_in_tensor_infos = next_stra_cost->inputs_ptr; + if (next_in_tensor_infos.size() <= IntToSize(in_index)) { + MS_LOG(ERROR) << "in_index is out of range of the tensor_infos in setting reshape's output_layout"; + return FAILED; + } + TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index]; + TensorLayout next_in_tensor_layout = next_in_tensor_info.tensor_layout(); + SetOutputLayout(next_in_tensor_layout); + if (Init(nullptr) == FAILED) { + MS_LOG(ERROR) << "Failure:operator reshape init failed"; + return FAILED; + } + SetCostForReshape(reshape_stra); + } + } + return SUCCESS; +} } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/reshape_info.h b/mindspore/ccsrc/parallel/ops_info/reshape_info.h index a711e9cb8..2b3bb91ab 100644 --- a/mindspore/ccsrc/parallel/ops_info/reshape_info.h +++ b/mindspore/ccsrc/parallel/ops_info/reshape_info.h @@ -56,6 +56,9 @@ class ReshapeInfo : public OperatorInfo { void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; } void set_pre_operator_index(int32_t pre_index) { pre_operator_index_ = pre_index; } void set_next_operator_index(int32_t next_index) { next_operator_index_ = next_index; } + Status GenetateStrategyCosts(const std::vector> &pre_stra_costs, + const std::vector> &next_stra_costs, int32_t out_index, + int32_t in_index, bool is_prev_param); Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; Status SetCostUnderStrategy(const StrategyPtr &strategy) override; diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc index 0447285e5..9b0c3b111 100644 --- a/mindspore/ccsrc/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc @@ -999,18 +999,6 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator return false; } -void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra) { - Shape shape = pre_out_tensor_info.shape(); - Shape slice_shape = pre_out_tensor_info.slice_shape(); - for (size_t i = 0; i < shape.size(); ++i) { - if ((slice_shape[i] == 0) || (shape[i] % slice_shape[i] != 0)) { - MS_LOG(EXCEPTION) << "slice_shape is wrong in reshape operator"; - } - int32_t dim = (int32_t)(shape[i] / slice_shape[i]); - (*stra).push_back(dim); - } -} - void ReshapeCostCompute(const std::vector &all_nodes) { for (auto node : all_nodes) { auto cnode = node->cast(); @@ -1054,46 +1042,10 @@ void ReshapeCostCompute(const std::vector &all_nodes) { reshape_info->set_next_operator_name(next_operator_info->name()); reshape_info->set_next_operator_index(in_index); } - for (auto pre_stra_cost : pre_stra_costs) { - std::vector pre_out_tensor_infos; - if (pre_node->isa()) { - pre_out_tensor_infos = pre_stra_cost->inputs_ptr; - } else { - pre_out_tensor_infos = pre_stra_cost->outputs_ptr; - } - if (pre_out_tensor_infos.size() <= IntToSize(out_index)) { - MS_LOG(EXCEPTION) << "out_index is out of range of the tensor_infos in setting reshape's input_layout"; - } - TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index]; - TensorLayout pre_out_tensor_layout = pre_out_tensor_info.tensor_layout(); - reshape_info->SetInputLayout(pre_out_tensor_layout); - // infer pre_node output strategy from output_layout. - Dimensions stra; - InferStraByTensorInfo(pre_out_tensor_info, &stra); - std::vector stra_inputs = {stra}; - StrategyPtr reshape_stra = std::make_shared(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs); - if (next_stra_costs.empty()) { - if (reshape_info->Init(nullptr) == FAILED) { - MS_LOG(EXCEPTION) << "Failure:operator reshape init failed"; - } - // set cost for each input_layout and output_layout pairs. - reshape_info->SetCostForReshape(reshape_stra); - continue; - } - for (auto next_stra_cost : next_stra_costs) { - std::vector next_in_tensor_infos = next_stra_cost->inputs_ptr; - if (next_in_tensor_infos.size() <= IntToSize(in_index)) { - MS_LOG(EXCEPTION) << "in_index is out of range of the tensor_infos in setting reshape's output_layout"; - } - TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index]; - TensorLayout next_in_tensor_layout = next_in_tensor_info.tensor_layout(); - reshape_info->SetOutputLayout(next_in_tensor_layout); - if (reshape_info->Init(nullptr) == FAILED) { - MS_LOG(EXCEPTION) << "Failure:operator reshape init failed"; - } - // set cost for each input_layout and output_layout pairs. - reshape_info->SetCostForReshape(reshape_stra); - } + bool is_prev_param = pre_node->isa(); + if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param) != + SUCCESS) { + MS_LOG(EXCEPTION) << "reshape genetate strategy_costs failed!"; } } } diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.h b/mindspore/ccsrc/parallel/step_auto_parallel.h index cf05a36fe..fff9dfa4c 100644 --- a/mindspore/ccsrc/parallel/step_auto_parallel.h +++ b/mindspore/ccsrc/parallel/step_auto_parallel.h @@ -51,8 +51,6 @@ void ConstructCostGraphEdges(const std::vector &all_nodes); void AugmentCostGraph(const std::vector &all_nodes); -void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra); - Status ParallelStrategySearch(const std::vector &all_nodes, const FuncGraphPtr &root); Status ParallelStrategyRecSearch(const std::vector &all_nodes, const FuncGraphPtr &root); diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h index 43286317c..0eee736ce 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h @@ -46,6 +46,17 @@ class TensorInfo { Shape shape() const { return shape_; } void set_reduce_dim(const std::vector &dim) { reduce_dim_ = dim; } std::vector reduce_dim() const { return reduce_dim_; } + Dimensions InferStrategy() const { + Dimensions stra; + for (size_t i = 0; i < shape_.size(); ++i) { + if ((slice_shape_[i] == 0) || (shape_[i] % slice_shape_[i] != 0)) { + return stra; + } + int32_t dim = (int32_t)(shape_[i] / slice_shape_[i]); + stra.push_back(dim); + } + return stra; + } private: TensorLayout tensor_layout_; diff --git a/tests/ut/python/parallel/test_auto_parallel_reshape.py b/tests/ut/python/parallel/test_auto_parallel_reshape.py index 1bce73361..bb2116eec 100644 --- a/tests/ut/python/parallel/test_auto_parallel_reshape.py +++ b/tests/ut/python/parallel/test_auto_parallel_reshape.py @@ -86,6 +86,7 @@ def test_reshape_auto_1(): net = GradWrap(NetWithLoss(Net())) context.set_auto_parallel_context(parallel_mode="auto_parallel") + net.set_auto_parallel() _executor.compile(net, x) def test_reshape_auto_2(): @@ -112,6 +113,7 @@ def test_reshape_auto_2(): net = GradWrap(NetWithLoss(Net())) context.set_auto_parallel_context(parallel_mode="auto_parallel") + net.set_auto_parallel() _executor.compile(net, x) def test_reshape_auto_3(): @@ -135,6 +137,7 @@ def test_reshape_auto_3(): net = GradWrap(NetWithLoss(Net())) context.set_auto_parallel_context(parallel_mode="auto_parallel") + net.set_auto_parallel() _executor.compile(net, x) def test_reshape_auto_4(): @@ -159,6 +162,7 @@ def test_reshape_auto_4(): net = GradWrap(NetWithLoss(Net())) context.set_auto_parallel_context(parallel_mode="auto_parallel") + net.set_auto_parallel() _executor.compile(net, x) @@ -208,6 +212,7 @@ def test_reshape_auto_5(): net = GradWrap(NetWithLoss(Net())) context.set_auto_parallel_context(parallel_mode="auto_parallel") + net.set_auto_parallel() _executor.compile(net, x, y) def test_reshape_auto_6(): @@ -254,4 +259,5 @@ def test_reshape_auto_6(): net = GradWrap(NetWithLoss(Net())) context.set_auto_parallel_context(parallel_mode="auto_parallel") + net.set_auto_parallel() _executor.compile(net, x, y) -- GitLab