提交 66553ac3 编写于 作者: H hongxing

optimize code

上级 9febf7fd
......@@ -168,12 +168,11 @@ std::vector<std::vector<int32_t>> PrepareGatherV2(const std::vector<std::shared_
const size_t iter_ops, std::vector<int32_t> s) {
std::vector<std::vector<int32_t>> strategies;
int32_t axis = 0;
auto axis_input = GetValue<int>(ops[iter_ops]->input_value().at(2));
if (axis_input < 0) {
axis_input += SizeToInt(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
}
axis = axis_input;
int32_t axis = axis_input;
if (axis >= SizeToInt(s.size())) {
MS_LOG(EXCEPTION) << "Failure: GatherV2' axis out of range.";
}
......
......@@ -20,7 +20,6 @@
#include <memory>
#include <string>
#include <vector>
#include <set>
#include "ir/value.h"
#include "parallel/auto_parallel/rec_core/rec_graph.h"
......@@ -215,23 +214,16 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
const std::shared_ptr<std::vector<size_t>> &index_list) {
MS_EXCEPTION_IF_NULL(graph);
static const std::set<OperatorType> elementwise_type = {
OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, OperatorType::kRecAdd,
OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, OperatorType::kRecMul,
OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast,
OperatorType::kRecReshape, OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue};
for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) {
auto type = graph->nodes[node_index].apply.op_type;
if (elementwise_type.find(type) != elementwise_type.end()) {
if (ElementWiseOpType.find(type) != ElementWiseOpType.end()) {
Eliminate_Aux(node_index, graph, eli_list);
}
}
index_list->reserve(graph->nodes.size());
for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) {
index_list->push_back(i);
}
for (size_t i = 0; i < (size_t)eli_list->size(); i++) {
if (eli_list->at(i)[0] >= index_list->size()) {
MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
......@@ -241,13 +233,11 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
index_list->at(j)--;
}
}
std::shared_ptr<Graph> new_graph(new Graph);
for (size_t i = 0; i < graph->nodes.size(); i++) {
if (index_list->at(i) > SIZE_MAX / 2) {
continue;
}
new_graph->nodes.push_back(graph->nodes[i]);
auto *node_in = &new_graph->nodes[index_list->at(i)].node_in;
for (size_t j = node_in->size(); j > 0; j--) {
......
......@@ -22,12 +22,19 @@
#include <string>
#include <utility>
#include <vector>
#include <set>
#include "parallel/auto_parallel/rec_core/rec_graph.h"
#include "parallel/ops_info/operator_info.h"
namespace mindspore {
namespace parallel {
static const std::set<OperatorType> ElementWiseOpType = {
OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, OperatorType::kRecAdd,
OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, OperatorType::kRecMul,
OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast,
OperatorType::kRecReshape, OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue};
const std::map<std::string, OperatorType> DictOpType{
{MATMUL, OperatorType::kRecMatMul},
{CONV2D, OperatorType::kRecConvolution},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册