Commit 46acf238 authored by mindspore-ci-bot, committed by Gitee

!405 [AutoParallel] Adapt rec-prog generator to new parser

Merge pull request !405 from Chong/str-gen
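Summary of the change (assembled from the diff below): GenerateStrategy no longer needs the ops_nodes_list / index_list / eli_list bookkeeping required by the old parser. It now takes the parsed graph, a flag that decides whether MaskSpecialOps pre-fixes the strategies of special operators (Convolution, BN, SparseSoftmaxCrossEntropyWithLogits), and the operator list, and the per-operator helpers return their strategy vectors instead of filling an in/out parameter. A minimal sketch of the new call pattern, mirroring the updated ParallelStrategyRecSearch call site further down (surrounding setup elided, so this is illustrative rather than a standalone example):

  std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names);
  // With mask_special_ops == true, GenerateStrategy first calls MaskSpecialOps(graph),
  // then derives one Strategy per operator from the recursive-programming result.
  bool mask_special_ops = true;
  GenerateStrategy(graph, mask_special_ops, ops);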
@@ -27,44 +27,27 @@
 namespace mindspore {
 namespace parallel {
-void GenerateStrategy(const std::shared_ptr<Graph> graph, std::vector<std::shared_ptr<OperatorInfo>> ops,
-                      const std::shared_ptr<std::vector<size_t>> ops_nodes_list,
-                      const std::shared_ptr<std::vector<size_t>> index_list,
-                      const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list) {
-  MaskNoSupportedOps(graph);
+void GenerateStrategy(std::shared_ptr<Graph> graph, bool mask_special_ops,
+                      const std::vector<std::shared_ptr<OperatorInfo>> &ops) {
+  MS_EXCEPTION_IF_NULL(graph);
+  if (mask_special_ops) {
+    MaskSpecialOps(graph);
+  }
   for (size_t iter_ops = 0; iter_ops < ops.size(); iter_ops++) {
-    auto type = ops[iter_ops]->type();
-    size_t iter_nodes = index_list->at(ops_nodes_list->at(iter_ops));
     std::vector<std::vector<int32_t>> stra;
-    iter_nodes = IterNodes(ops_nodes_list, index_list, eli_list, iter_ops, iter_nodes);
     for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
-      std::vector<int32_t> s = PrepareStrategy(graph, ops, type, iter_ops, iter_nodes, iter_op_inputs);
-      stra.push_back(s);
+      stra.push_back(PrepareStrategy(graph, ops, iter_ops, iter_op_inputs));
     }
     StrategyPtr sp = std::make_shared<Strategy>(0, stra);
     ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
   }
 }
-size_t IterNodes(const std::shared_ptr<std::vector<size_t>> ops_nodes_list,
-                 const std::shared_ptr<std::vector<size_t>> index_list,
-                 const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list, const size_t iter_ops,
-                 size_t iter_nodes) {
-  if (iter_nodes > SIZE_MAX / 2) {
-    for (size_t iter_eli = 0; iter_eli < eli_list->size(); iter_eli++) {
-      if (eli_list->at(iter_eli)[0] == ops_nodes_list->at(iter_ops)) {
-        iter_nodes = index_list->at(eli_list->at(iter_eli)[1]);
-        break;
-      }
-    }
-  }
-  return iter_nodes;
-}
-void PrepareMatMul(const std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                   const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs,
-                   std::vector<int32_t> s) {
-  auto attrs = ops[iter_ops]->attrs();
+std::vector<int32_t> PrepareMatMul(const std::shared_ptr<Graph> &graph,
+                                   const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_nodes,
+                                   const size_t iter_op_inputs) {
+  std::vector<int32_t> s;
+  auto attrs = ops[iter_nodes]->attrs();
   bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
   bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
   if (transpose_a && (iter_op_inputs == 0)) {
@@ -77,10 +60,12 @@ void PrepareMatMul(const std::shared_ptr<Graph> graph, const std::vector<std::sh
     s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h));
     s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w));
   }
+  return s;
 }
-void PrepareConv2D(const std::shared_ptr<Graph> graph, const size_t iter_nodes, size_t iter_op_inputs,
-                   std::vector<int32_t> s) {
+std::vector<int32_t> PrepareConv2D(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
+                                   size_t iter_op_inputs) {
+  std::vector<int32_t> s;
   if (iter_op_inputs == 0) {
     s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_n));
     s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_c));
@@ -92,20 +77,24 @@ void PrepareConv2D(const std::shared_ptr<Graph> graph, const size_t iter_nodes,
     s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[1].tensor_str.str_h));
     s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[1].tensor_str.str_w));
   }
+  return s;
 }
-void PrepareBiasAdd(const std::shared_ptr<Graph> graph, const size_t iter_nodes, const size_t iter_op_inputs,
-                    std::vector<int32_t> s) {
+std::vector<int32_t> PrepareBiasAdd(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
+                                    const size_t iter_op_inputs) {
+  std::vector<int32_t> s;
   if (iter_op_inputs == 0) {
     s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h));
     s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w));
   } else {
     s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w));
   }
+  return s;
 }
-void PrepareBN(const std::shared_ptr<Graph> graph, const size_t iter_nodes, const size_t iter_op_inputs,
-               std::vector<int32_t> s) {
+std::vector<int32_t> PrepareBN(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
+                               const size_t iter_op_inputs) {
+  std::vector<int32_t> s;
   if (iter_op_inputs == 0) {
     s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_n));
     s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_c));
@@ -114,97 +103,133 @@ void PrepareBN(const std::shared_ptr<Graph> graph, const size_t iter_nodes, cons
   } else {
     s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[1].tensor_str.str_w));
   }
+  return s;
 }
-void PrepareSparse(const size_t iter_op_inputs, std::vector<int32_t> s) {
+std::vector<int32_t> PrepareSparse(const size_t iter_op_inputs) {
+  std::vector<int32_t> s;
   if (iter_op_inputs == 0) {
     s.push_back(g_device_manager->DeviceNum());
     s.push_back(1);
   } else {
     s.push_back(g_device_manager->DeviceNum());
   }
+  return s;
 }
-void RefillOrigin(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
-                  const size_t iter_op_inputs, std::vector<int32_t> s) {
+std::vector<int32_t> MakeOriginalStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
+                                          const size_t iter_op_inputs) {
+  std::vector<int32_t> s;
+  if (ops.empty()) {
+    MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
+  }
+  if (iter_ops >= ops.size()) {
+    MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
+  }
+  if (iter_op_inputs >= ops[iter_ops]->strategy()->GetInputDim().size())
+    MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
+  size_t input_size = ops[iter_ops]->strategy()->GetInputDim()[iter_op_inputs].size();
+  for (size_t dim = 0; dim < input_size; dim++) {
+    s.push_back(1);
+  }
+  return s;
+}
+std::vector<int32_t> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph, const size_t iter_ops,
+                                           const size_t iter_op_inputs) {
+  std::vector<int32_t> s;
+  s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_n));
+  s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_c));
+  s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_h));
+  s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_w));
+  return s;
+}
+std::vector<int32_t> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                              const size_t iter_ops, const size_t iter_op_inputs) {
+  std::vector<int32_t> s;
+  if (ops.empty()) {
+    MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
+  }
+  if (iter_ops >= ops.size()) {
+    MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
+  }
   StrategyPtr origin_strategy = ops[iter_ops]->strategy();
-  if (iter_op_inputs == 0) {
-    for (size_t j = 0; j < origin_strategy->GetInputDim()[0].size(); j++) {
-      s.push_back(1);
-    }
-  } else {
-    for (size_t k = 0; k < origin_strategy->GetInputDim()[iter_op_inputs].size(); k++) {
+  if (iter_op_inputs >= origin_strategy->GetInputDim().size())
+    MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
+  size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
+  for (size_t dim = 0; dim < input_size; dim++) {
+    if (dim == 0 && input_size == 4) {
+      size_t max_device_num = g_device_manager->DeviceNum();
+      size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0];
+      s.push_back(std::min(max_device_num, target_tensor_batch));
+    } else {
       s.push_back(1);
     }
   }
+  return s;
 }
-std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> graph,
-                                     const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::string &type,
-                                     const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs) {
-  std::vector<int32_t> s;
+std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> &graph,
                                      const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
+                                     const size_t iter_op_inputs) {
+  if (ops.empty()) {
+    MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
+  }
+  if (iter_ops >= ops.size()) {
+    MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
+  }
+  auto type = ops[iter_ops]->type();
   if (type == MATMUL) {
-    PrepareMatMul(graph, ops, iter_ops, iter_nodes, iter_op_inputs, s);
+    return PrepareMatMul(graph, ops, iter_ops, iter_op_inputs);
   } else if ((type == MAXPOOL) || (type == SIMPLE_MEAN) || (type == TENSOR_ADD)) {
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_n));
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_c));
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h));
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w));
+    return MakeRecSearchStrategy(graph, iter_ops, iter_op_inputs);
   } else if (type == CONV2D) {
-    PrepareConv2D(graph, iter_nodes, iter_op_inputs, s);
+    return PrepareConv2D(graph, iter_ops, iter_op_inputs);
   } else if (type == BIAS_ADD) {
-    PrepareBiasAdd(graph, iter_nodes, iter_op_inputs, s);
+    return PrepareBiasAdd(graph, iter_ops, iter_op_inputs);
   } else if (type == RESHAPE) {
-    s.push_back(1);
-    s.push_back(1);
-    s.push_back(1);
-    s.push_back(1);
+    return MakeOriginalStrategy(ops, iter_ops, iter_op_inputs);
   } else if (type == RELU) {
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_n));
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_c));
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_h));
-    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_w));
+    return MakeRecSearchStrategy(graph, iter_ops, iter_op_inputs);
   } else if (type == BATCH_NORM || (type == FUSE_BATCH_NORM)) {
-    PrepareBN(graph, iter_nodes, iter_op_inputs, s);
+    return PrepareBN(graph, iter_ops, iter_op_inputs);
   } else if (type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
-    PrepareSparse(iter_op_inputs, s);
+    return PrepareSparse(iter_op_inputs);
   } else {
-    RefillOrigin(ops, iter_ops, iter_op_inputs, s);
+    return MakeDataParallelStrategy(ops, iter_ops, iter_op_inputs);
   }
-  return s;
 }
-void MaskNoSupportedOps(const std::shared_ptr<Graph> graph) {
+void MaskSpecialOps(std::shared_ptr<Graph> graph) {
   size_t iter_nodes = graph->nodes.size();
   for (size_t i = 0; i < iter_nodes; i++) {
-    if (0 == graph->nodes[i].info) {
-      Graph::NodeType &node = graph->nodes[i];
+    Graph::NodeType &node = graph->nodes[i];
     if (node.apply.op_type == 1) {  // For Convolution
       // cover input tensor strategy
       node.apply.arguments[0].tensor_str.str_n = 1.0 / static_cast<float>(g_device_manager->DeviceNum());
       node.apply.arguments[0].tensor_str.str_c = 1;
       node.apply.arguments[0].tensor_str.str_h = 1;
       node.apply.arguments[0].tensor_str.str_w = 1;
       // cover filter tensor strategy
       node.apply.arguments[1].tensor_str.str_n = 1;
       node.apply.arguments[1].tensor_str.str_c = 1;
       node.apply.arguments[1].tensor_str.str_h = 1;
       node.apply.arguments[1].tensor_str.str_w = 1;
     } else if (node.apply.op_type == 8) {  // For BN
       node.apply.arguments[0].tensor_str.str_n = 1.0 / static_cast<float>(g_device_manager->DeviceNum());
       node.apply.arguments[0].tensor_str.str_c = 1;
       node.apply.arguments[0].tensor_str.str_h = 1;
       node.apply.arguments[0].tensor_str.str_w = 1;
       // cover 1-d argument blobs
-      node.apply.arguments[1].tensor_str.str_w = 1;
-      node.apply.arguments[2].tensor_str.str_w = 1;
-      node.apply.arguments[3].tensor_str.str_w = 1;
+      node.apply.arguments[1].tensor_str.str_n = 1;
+      node.apply.arguments[2].tensor_str.str_c = 1;
+      node.apply.arguments[3].tensor_str.str_h = 1;
       node.apply.arguments[4].tensor_str.str_w = 1;
     } else if (node.apply.op_type == 4 || node.apply.op_type == 9) {  // For SparseSoftmaxCrossEntropyWithLogits
       node.tensor_parm.tensor_str.str_h = 1.0 / static_cast<float>(g_device_manager->DeviceNum());
       node.tensor_parm.tensor_str.str_w = 1;
     }
-    }
   }
 }
...
@@ -27,29 +27,28 @@
 namespace mindspore {
 namespace parallel {
-void GenerateStrategy(const std::shared_ptr<Graph> graph, std::vector<std::shared_ptr<OperatorInfo>> ops,
-                      const std::shared_ptr<std::vector<size_t>> ops_nodes_list,
-                      const std::shared_ptr<std::vector<size_t>> index_list,
-                      const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list);
-void PrepareMatMul(const std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                   const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs, std::vector<int32_t> s);
-void PrepareConv2D(const std::shared_ptr<Graph> graph, const size_t iter_nodes, const size_t iter_op_inputs,
-                   std::vector<int32_t> s);
-void PrepareBiasAdd(const std::shared_ptr<Graph> graph, const size_t iter_nodes, const size_t iter_op_inputs,
-                    std::vector<int32_t> s);
-void PrepareBN(const std::shared_ptr<Graph> graph, const size_t iter_nodes, const size_t iter_op_inputs,
-               std::vector<int32_t> s);
-void PrepareSparse(const size_t iter_op_inputs, std::vector<int32_t> s);
-void RefillOrigin(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
-                  const size_t iter_op_inputs, std::vector<int32_t> s);
-std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> graph,
-                                     const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::string &type,
-                                     const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs);
-size_t IterNodes(const std::shared_ptr<std::vector<size_t>> ops_nodes_list,
-                 const std::shared_ptr<std::vector<size_t>> index_list,
-                 const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list, const size_t iter_ops,
-                 size_t iter_nodes);
-void MaskNoSupportedOps(const std::shared_ptr<Graph> graph);
+void GenerateStrategy(std::shared_ptr<Graph> graph, bool mask_special_ops,
+                      const std::vector<std::shared_ptr<OperatorInfo>> &ops);
+std::vector<int32_t> PrepareMatMul(const std::shared_ptr<Graph> &graph,
+                                   const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_nodes,
+                                   const size_t iter_op_inputs);
+std::vector<int32_t> PrepareConv2D(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
+                                   const size_t iter_op_inputs);
+std::vector<int32_t> PrepareBiasAdd(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
+                                    const size_t iter_op_inputs);
+std::vector<int32_t> PrepareBN(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
+                               const size_t iter_op_inputs);
+std::vector<int32_t> PrepareSparse(const size_t iter_op_inputs);
+std::vector<int32_t> MakeOriginalStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
+                                          const size_t iter_op_inputs);
+std::vector<int32_t> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph, const size_t iter_ops,
+                                           const size_t iter_op_inputs);
+std::vector<int32_t> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                              const size_t iter_ops, const size_t iter_op_inputs);
+std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> &graph,
+                                     const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
+                                     const size_t iter_op_inputs);
+void MaskSpecialOps(std::shared_ptr<Graph> graph);
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_
@@ -931,8 +931,6 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
   }
   std::shared_ptr<std::vector<size_t>> ops_nodes_list(new std::vector<size_t>);
-  std::shared_ptr<std::vector<size_t>> index_list(new std::vector<size_t>);
-  std::shared_ptr<std::vector<std::vector<size_t>>> eli_list(new std::vector<std::vector<size_t>>);
   std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names);
@@ -944,7 +942,8 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
     return FAILED;
   }
-  GenerateStrategy(graph, ops, ops_nodes_list, index_list, eli_list);
+  bool mask_special_ops = true;
+  GenerateStrategy(graph, mask_special_ops, ops);
   if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
     MS_LOG(INFO) << "Init selected strategy succeeded.";
...
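A note on the default branch added in the first file (a reading of the diff above, not part of the commit message): operators without a dedicated Prepare* helper now fall through to MakeDataParallelStrategy, which shards only the leading (batch) dimension of 4-D inputs and leaves every other dimension at 1, capping the split at the output batch size. For example, assuming 8 devices and an output batch of 32, the strategy for an NCHW input would come out as [8, 1, 1, 1]; with a batch of 4 it would be [4, 1, 1, 1], because std::min(max_device_num, target_tensor_batch) keeps the cut no larger than the batch.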