Commit 716def7c authored by yao_yf

move InferStraByTensorInfo to tensor_info.h

Parent dd2062bf
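For context, the helper being moved infers a tensor's parallel strategy from its global shape and local slice shape: dimension i is split shape[i] / slice_shape[i] ways. Below is a minimal standalone sketch of that logic (the free function, type aliases, and main are hypothetical illustrations, not the MindSpore API):

```cpp
// Hypothetical standalone sketch (not part of this commit) of what the new
// TensorInfo::InferStrategy computes: the number of parallel slices per
// dimension, i.e. shape[i] / slice_shape[i].
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int32_t>;
using Dimensions = std::vector<int32_t>;

// Returns an empty vector on failure, mirroring the committed code's
// "stra.empty()" error check.
Dimensions InferStrategy(const Shape &shape, const Shape &slice_shape) {
  Dimensions stra;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (slice_shape[i] == 0 || shape[i] % slice_shape[i] != 0) {
      return Dimensions();  // slice shape does not evenly divide the shape
    }
    stra.push_back(shape[i] / slice_shape[i]);
  }
  return stra;
}

int main() {
  // A [128, 64] tensor whose local slice is [32, 64] is split 4 ways on
  // dim 0 and left unsplit on dim 1, so the inferred strategy is (4, 1).
  for (int32_t d : InferStrategy({128, 64}, {32, 64})) {
    std::cout << d << " ";  // prints: 4 1
  }
  std::cout << std::endl;
  return 0;
}
```

Returning an empty vector on failure is what lets the callers in the diff below guard with `stra.empty()` instead of throwing inside the helper.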
@@ -1377,7 +1377,6 @@ Status CostGraph::InitSelectedStrategy() {
    if (pre_iter != in_edges.end()) {
      MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name();
      int32_t pre_index = reshape_info->pre_operator_index();
      Dimensions stra;
      TensorInfo pre_info;
      if (ops_[i]->name() == (*pre_iter)->prev_operator()->name()) {
        pre_info = (*pre_iter)->prev_operator()->inputs_tensor_info()[pre_index];
@@ -1385,7 +1384,10 @@ Status CostGraph::InitSelectedStrategy() {
        pre_info = (*pre_iter)->prev_operator()->outputs_tensor_info()[pre_index];
      }
      reshape_info->SetInputLayout(pre_info.tensor_layout());
      InferStraByTensorInfo(pre_info, &stra);
      Dimensions stra = pre_info.InferStrategy();
      if (stra.empty()) {
        MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed";
      }
      std::vector<Dimensions> stra_inputs = {stra};
      StrategyPtr reshape_stra =
        std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
......
@@ -440,5 +440,57 @@ Status ReshapeInfo::GenerateStrategies(int32_t stage_id) {
  }
  return SUCCESS;
}

// For every candidate strategy of the previous operator (and, if present, the
// next operator), set the reshape's input/output layouts accordingly and
// record a cost for each layout pair.
Status ReshapeInfo::GenerateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
                                          const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs,
                                          int32_t out_index, int32_t in_index, bool is_prev_param) {
  for (auto pre_stra_cost : pre_stra_costs) {
    std::vector<TensorInfo> pre_out_tensor_infos;
    if (is_prev_param) {
      pre_out_tensor_infos = pre_stra_cost->inputs_ptr;
    } else {
      pre_out_tensor_infos = pre_stra_cost->outputs_ptr;
    }
    if (pre_out_tensor_infos.size() <= IntToSize(out_index)) {
      MS_LOG(ERROR) << "out_index is out of range of the tensor_infos in setting reshape's input_layout";
      return FAILED;
    }
    TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index];
    TensorLayout pre_out_tensor_layout = pre_out_tensor_info.tensor_layout();
    SetInputLayout(pre_out_tensor_layout);
    // infer pre_node output strategy from output_layout.
    Dimensions stra = pre_out_tensor_info.InferStrategy();
    if (stra.empty()) {
      MS_LOG(ERROR) << "Infer strategy by tensor_info failed";
      return FAILED;
    }
    std::vector<Dimensions> stra_inputs = {stra};
    StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs);
    if (next_stra_costs.empty()) {
      if (Init(nullptr) == FAILED) {
        MS_LOG(ERROR) << "Failure: operator reshape init failed";
        return FAILED;
      }
      SetCostForReshape(reshape_stra);
      continue;
    }
    for (auto next_stra_cost : next_stra_costs) {
      std::vector<TensorInfo> next_in_tensor_infos = next_stra_cost->inputs_ptr;
      if (next_in_tensor_infos.size() <= IntToSize(in_index)) {
        MS_LOG(ERROR) << "in_index is out of range of the tensor_infos in setting reshape's output_layout";
        return FAILED;
      }
      TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index];
      TensorLayout next_in_tensor_layout = next_in_tensor_info.tensor_layout();
      SetOutputLayout(next_in_tensor_layout);
      if (Init(nullptr) == FAILED) {
        MS_LOG(ERROR) << "Failure: operator reshape init failed";
        return FAILED;
      }
      SetCostForReshape(reshape_stra);
    }
  }
  return SUCCESS;
}
} // namespace parallel
} // namespace mindspore
@@ -56,6 +56,9 @@ class ReshapeInfo : public OperatorInfo {
  void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; }
  void set_pre_operator_index(int32_t pre_index) { pre_operator_index_ = pre_index; }
  void set_next_operator_index(int32_t next_index) { next_operator_index_ = next_index; }
  Status GenerateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
                               const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs,
                               int32_t out_index, int32_t in_index, bool is_prev_param);
  Status InitForCostModel(const StrategyPtr &strategy) override;
  Status GenerateStrategies(int32_t stage_id) override;
  Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
......
@@ -999,18 +999,6 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator
  return false;
}

void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra) {
  Shape shape = pre_out_tensor_info.shape();
  Shape slice_shape = pre_out_tensor_info.slice_shape();
  for (size_t i = 0; i < shape.size(); ++i) {
    if ((slice_shape[i] == 0) || (shape[i] % slice_shape[i] != 0)) {
      MS_LOG(EXCEPTION) << "slice_shape is wrong in reshape operator";
    }
    int32_t dim = (int32_t)(shape[i] / slice_shape[i]);
    (*stra).push_back(dim);
  }
}

void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
  for (auto node : all_nodes) {
    auto cnode = node->cast<CNodePtr>();
@@ -1054,46 +1042,10 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
      reshape_info->set_next_operator_name(next_operator_info->name());
      reshape_info->set_next_operator_index(in_index);
    }
    for (auto pre_stra_cost : pre_stra_costs) {
      std::vector<TensorInfo> pre_out_tensor_infos;
      if (pre_node->isa<Parameter>()) {
        pre_out_tensor_infos = pre_stra_cost->inputs_ptr;
      } else {
        pre_out_tensor_infos = pre_stra_cost->outputs_ptr;
      }
      if (pre_out_tensor_infos.size() <= IntToSize(out_index)) {
        MS_LOG(EXCEPTION) << "out_index is out of range of the tensor_infos in setting reshape's input_layout";
      }
      TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index];
      TensorLayout pre_out_tensor_layout = pre_out_tensor_info.tensor_layout();
      reshape_info->SetInputLayout(pre_out_tensor_layout);
      // infer pre_node output strategy from output_layout.
      Dimensions stra;
      InferStraByTensorInfo(pre_out_tensor_info, &stra);
      std::vector<Dimensions> stra_inputs = {stra};
      StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs);
      if (next_stra_costs.empty()) {
        if (reshape_info->Init(nullptr) == FAILED) {
          MS_LOG(EXCEPTION) << "Failure: operator reshape init failed";
        }
        // set cost for each input_layout and output_layout pairs.
        reshape_info->SetCostForReshape(reshape_stra);
        continue;
      }
      for (auto next_stra_cost : next_stra_costs) {
        std::vector<TensorInfo> next_in_tensor_infos = next_stra_cost->inputs_ptr;
        if (next_in_tensor_infos.size() <= IntToSize(in_index)) {
          MS_LOG(EXCEPTION) << "in_index is out of range of the tensor_infos in setting reshape's output_layout";
        }
        TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index];
        TensorLayout next_in_tensor_layout = next_in_tensor_info.tensor_layout();
        reshape_info->SetOutputLayout(next_in_tensor_layout);
        if (reshape_info->Init(nullptr) == FAILED) {
          MS_LOG(EXCEPTION) << "Failure: operator reshape init failed";
        }
        // set cost for each input_layout and output_layout pairs.
        reshape_info->SetCostForReshape(reshape_stra);
      }
    }
    bool is_prev_param = pre_node->isa<Parameter>();
    if (reshape_info->GenerateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param) !=
        SUCCESS) {
      MS_LOG(EXCEPTION) << "Generating strategy costs for reshape failed!";
    }
  }
}
......
@@ -51,8 +51,6 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes);
void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes);
void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra);
Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
......
@@ -46,6 +46,17 @@ class TensorInfo {
  Shape shape() const { return shape_; }
  void set_reduce_dim(const std::vector<int32_t> &dim) { reduce_dim_ = dim; }
  std::vector<int32_t> reduce_dim() const { return reduce_dim_; }
  // Infer the parallel strategy of this tensor: each dimension's split count
  // is shape_[i] / slice_shape_[i]; an empty result signals failure.
  Dimensions InferStrategy() const {
    Dimensions stra;
    for (size_t i = 0; i < shape_.size(); ++i) {
      if ((slice_shape_[i] == 0) || (shape_[i] % slice_shape_[i] != 0)) {
        return stra;
      }
      int32_t dim = (int32_t)(shape_[i] / slice_shape_[i]);
      stra.push_back(dim);
    }
    return stra;
  }

 private:
  TensorLayout tensor_layout_;
......
@@ -86,6 +86,7 @@ def test_reshape_auto_1():
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    _executor.compile(net, x)


def test_reshape_auto_2():
@@ -112,6 +113,7 @@ def test_reshape_auto_2():
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    _executor.compile(net, x)


def test_reshape_auto_3():
@@ -135,6 +137,7 @@ def test_reshape_auto_3():
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    _executor.compile(net, x)


def test_reshape_auto_4():
@@ -159,6 +162,7 @@ def test_reshape_auto_4():
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    _executor.compile(net, x)
@@ -208,6 +212,7 @@ def test_reshape_auto_5():
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    _executor.compile(net, x, y)


def test_reshape_auto_6():
@@ -254,4 +259,5 @@ def test_reshape_auto_6():
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    _executor.compile(net, x, y)