Commit 39790ccf authored by hongxing

Optimize code

Parent bee57fda
@@ -146,16 +146,6 @@ std::vector<std::vector<int32_t>> PrepareBatchNorm(const std::shared_ptr<Graph>
   return strategies;
 }
-std::vector<std::vector<int32_t>> PrepareSoftmaxWithLogits(const std::shared_ptr<Graph> &graph,
-                                                           const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                           const size_t iter_graph, const size_t iter_ops) {
-  std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
-  graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = graph->nodes[iter_graph].tensor_parm.tensor_str.str_h;
-  graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_c;
-  graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = graph->nodes[iter_graph].tensor_parm.tensor_str.str_n;
-  return strategies;
-}
 std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vector<int32_t>> &s) {
   std::vector<std::vector<int32_t>> strategies;
   strategies.push_back(*s);
@@ -299,9 +289,8 @@ std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &
     return PreparePReLU(graph, ops, iter_graph, iter_ops);
   } else if (type == BATCH_NORM) {
     return PrepareBatchNorm(graph, ops, iter_graph, iter_ops);
-  } else if (type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
-    return PrepareSoftmaxWithLogits(graph, ops, iter_graph, iter_ops);
-  } else if (type == SOFTMAX || type == LOG_SOFTMAX || type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
+  } else if (type == SOFTMAX || type == LOG_SOFTMAX || type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS ||
+             type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
     return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
   } else {
     return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
......
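Note: the two hunks above drop the PrepareSoftmaxWithLogits special case. Instead of generating a data-parallel strategy and then rotating the output tensor strides (str_w ← str_h ← str_c ← str_n), SOFTMAX_CROSS_ENTROPY_WITH_LOGITS is now dispatched to the same MakeDataParallelStrategy path as the other softmax variants. As a rough illustration of what a data-parallel strategy means here — hypothetical helper, not part of this commit; the real MakeDataParallelStrategy also updates stride bookkeeping in the graph — splitting falls only on the batch (first) dimension of each input:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical sketch: a pure data-parallel strategy splits only the batch
// dimension of every input and leaves all other dimensions unsplit.
// Assumes input_rank >= 1.
std::vector<std::vector<int32_t>> DataParallelStrategySketch(size_t input_count,
                                                             size_t input_rank,
                                                             int32_t device_num) {
  std::vector<std::vector<int32_t>> strategies;
  for (size_t i = 0; i < input_count; i++) {
    std::vector<int32_t> s(input_rank, 1);  // no split on non-batch dims
    s[0] = device_num;                      // split the batch dimension
    strategies.push_back(s);
  }
  return strategies;
}

// e.g. logits and labels of shape [N, C] on 8 devices:
// DataParallelStrategySketch(2, 2, 8) -> {{8, 1}, {8, 1}}
```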
@@ -40,9 +40,6 @@ std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &gra
 std::vector<std::vector<int32_t>> PrepareBatchNorm(const std::shared_ptr<Graph> &graph,
                                                    const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                    const size_t iter_graph, const size_t iter_ops);
-std::vector<std::vector<int32_t>> PrepareSoftmaxWithLogits(const std::shared_ptr<Graph> &graph,
-                                                           const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                           const size_t iter_graph, const size_t iter_ops);
 std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vector<int32_t>> &s);
 std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<std::vector<int32_t>> &s);
 std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s);
......
@@ -69,6 +69,7 @@ class Graph {
   std::vector<size_t> node_in;
   // Nodes that point from this node
   std::vector<size_t> node_out;
+  std::vector<size_t> node_in_aux;
   // Node Type Info: Application or Constant. Defined in enum <InfoType> .
   InfoType info;
   // Operator info. Defined in struct <OperatorRec> .
......
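Note: node_in_aux is the new member this commit introduces. Eliminate_Aux (next hunk) uses it to preserve edges when a node with several inputs is spliced out of the graph: the first surviving input overwrites the eliminated node's slot in each successor's node_in, and the remaining inputs are parked in the successor's node_in_aux. A condensed sketch of the adjacency bookkeeping — member names are taken from the hunk, everything else in the real class is omitted:

```cpp
#include <cstddef>
#include <vector>

// Sketch of the per-node adjacency fields; indices refer into Graph::nodes.
struct NodeSketch {
  std::vector<size_t> node_in;      // primary incoming edges
  std::vector<size_t> node_out;     // outgoing edges
  std::vector<size_t> node_in_aux;  // auxiliary incoming edges, filled when
                                    // eliminated nodes are spliced out
};
```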
@@ -171,21 +171,41 @@ void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> graph,
     eli.push_back(graph->nodes[node_index].node_out[i]);
   }
   eli_list->push_back(eli);
-  for (auto input_index : graph->nodes[node_index].node_in) {
-    auto it = find(graph->nodes[input_index].node_out.begin(), graph->nodes[input_index].node_out.end(), node_index);
-    if (it != graph->nodes[input_index].node_out.end()) {
-      graph->nodes[input_index].node_out.erase(it);
-      for (auto output_index : graph->nodes[node_index].node_out) {
-        graph->nodes[input_index].node_out.push_back(output_index);
-      }
-    }
-  }
+  for (size_t i = 0; i < graph->nodes[node_index].node_in.size(); i++) {
+    auto *incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in[i]].node_out;
+    auto it = find(incoming_outputs->begin(), incoming_outputs->end(), node_index);
+    if (it != incoming_outputs->end()) {
+      it = incoming_outputs->erase(it);
+      incoming_outputs->insert(it, graph->nodes[node_index].node_out.begin(), graph->nodes[node_index].node_out.end());
+    }
+  }
+  for (size_t i = 0; i < graph->nodes[node_index].node_in_aux.size(); i++) {
+    auto *aux_incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in_aux[i]].node_out;
+    auto it = find(aux_incoming_outputs->begin(), aux_incoming_outputs->end(), node_index);
+    if (it != aux_incoming_outputs->end()) {
+      it = aux_incoming_outputs->erase(it);
+      aux_incoming_outputs->insert(it, graph->nodes[node_index].node_out.begin(),
+                                   graph->nodes[node_index].node_out.end());
+    }
+  }
-  for (auto output_index : graph->nodes[node_index].node_out) {
-    auto it = find(graph->nodes[output_index].node_in.begin(), graph->nodes[output_index].node_in.end(), node_index);
-    if (it != graph->nodes[output_index].node_in.end()) {
-      graph->nodes[output_index].node_in.erase(it);
-      for (auto input_index : graph->nodes[node_index].node_in) {
-        graph->nodes[output_index].node_in.push_back(input_index);
-      }
-    }
-  }
+  for (size_t i = 0; i < graph->nodes[node_index].node_out.size(); i++) {
+    auto *outgoing_inputs = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in;
+    auto it = find(outgoing_inputs->begin(), outgoing_inputs->end(), node_index);
+    if (it != outgoing_inputs->end()) {
+      if (graph->nodes[node_index].node_in.size() > 0) {
+        outgoing_inputs->at(std::distance(outgoing_inputs->begin(), it)) = graph->nodes[node_index].node_in[0];
+        for (size_t j = 1; j < graph->nodes[node_index].node_in.size(); j++) {
+          graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back(graph->nodes[node_index].node_in[j]);
+        }
+        for (size_t j = 1; j < graph->nodes[node_index].node_in_aux.size(); j++) {
+          graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back(
+            graph->nodes[node_index].node_in_aux[j]);
+        }
+      } else {
+        outgoing_inputs->erase(it);
+      }
+    }
+  }
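Note: the rewritten loops above rely on std::vector::erase returning an iterator to the element after the one removed, so the eliminated node's successors can be spliced into a predecessor's node_out at the exact position the node occupied, instead of being appended at the end as the old code did. A standalone demonstration of that erase-then-insert pattern (illustrative only, not from the commit):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // A predecessor's outgoing edges; node 3 is about to be eliminated.
  std::vector<size_t> node_out = {1, 3, 5};
  // Successors of the eliminated node 3.
  const std::vector<size_t> successors = {7, 8};

  auto it = std::find(node_out.begin(), node_out.end(), size_t{3});
  if (it != node_out.end()) {
    it = node_out.erase(it);  // erase returns the position after the removed entry
    node_out.insert(it, successors.begin(), successors.end());
  }

  for (size_t n : node_out) std::cout << n << ' ';  // prints: 1 7 8 5
  std::cout << '\n';
  return 0;
}
```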
@@ -206,10 +226,12 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
       Eliminate_Aux(node_index, graph, eli_list);
     }
   }
+  index_list->reserve(graph->nodes.size());
   for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) {
     index_list->push_back(i);
   }
   for (size_t i = 0; i < (size_t)eli_list->size(); i++) {
     if (eli_list->at(i)[0] >= index_list->size()) {
       MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
@@ -219,6 +241,7 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
       index_list->at(j)--;
     }
   }
+
   std::shared_ptr<Graph> new_graph(new Graph);
   for (size_t i = 0; i < graph->nodes.size(); i++) {
     if (index_list->at(i) > SIZE_MAX / 2) {
@@ -226,11 +249,13 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
     }
     new_graph->nodes.push_back(graph->nodes[i]);
-    for (size_t j = 0; j < new_graph->nodes[index_list->at(i)].node_in.size(); j++) {
-      new_graph->nodes[index_list->at(i)].node_in[j] = index_list->at(new_graph->nodes[index_list->at(i)].node_in[j]);
+    auto *node_in = &new_graph->nodes[index_list->at(i)].node_in;
+    for (size_t j = 0; j < node_in->size(); j++) {
+      node_in->at(j) = index_list->at(node_in->at(j));
     }
-    for (size_t j = 0; j < new_graph->nodes[index_list->at(i)].node_out.size(); j++) {
-      new_graph->nodes[index_list->at(i)].node_out[j] = index_list->at(new_graph->nodes[index_list->at(i)].node_out[j]);
+    auto *node_out = &new_graph->nodes[index_list->at(i)].node_out;
+    for (size_t j = 0; j < node_out->size(); j++) {
+      node_out->at(j) = index_list->at(node_out->at(j));
     }
   }
   return new_graph;
......
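Note: EliminateGraph compacts the node array after elimination. index_list starts as the identity mapping (now built after a reserve() to avoid reallocation), every index behind an eliminated node is decremented, and the surviving nodes' node_in/node_out entries are rewritten through index_list via the new node_in/node_out alias pointers. A self-contained sketch of that remapping idea — simplified; the real code also detects eliminated slots through the SIZE_MAX / 2 range check:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const size_t node_count = 5;
  const std::vector<size_t> eliminated = {1, 3};  // nodes spliced out earlier

  // Identity map from old node index to new node index.
  std::vector<size_t> index_list(node_count);
  for (size_t i = 0; i < node_count; i++) index_list[i] = i;

  // Shift every index that follows an eliminated node down by one.
  for (size_t e : eliminated) {
    for (size_t j = e + 1; j < node_count; j++) index_list[j]--;
  }

  // Old edges 2 and 4 become 1 and 2 in the compacted graph.
  std::vector<size_t> node_in = {2, 4};
  for (size_t &v : node_in) v = index_list[v];
  for (size_t v : node_in) std::cout << v << ' ';  // prints: 1 2
  std::cout << '\n';
  return 0;
}
```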
@@ -59,6 +59,7 @@ const std::map<std::string, OperatorType> DictOpType{
     {PRELU, OperatorType::kRecPReLU},
     {L2_NORMALIZE, OperatorType::kRecElmWiseOp},
     {TENSOR_ADD, OperatorType::kRecElmWiseOp},
+    {SUB, OperatorType::kRecElmWiseOp},
     {MUL, OperatorType::kRecElmWiseOp},
......
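Note: DictOpType maps framework operator names onto the coarse operator classes used by the recursive strategy search; registering SUB as kRecElmWiseOp lets subtraction be partitioned the same way as TENSOR_ADD and MUL. A hedged sketch of how such a table is typically consumed — names invented for illustration, not the repository's API:

```cpp
#include <map>
#include <string>

enum class OperatorType { kRecElmWiseOp, kRecPReLU, kRecUnknown };

// Hypothetical lookup: unknown operator names fall back to a default class.
OperatorType ClassifyOp(const std::map<std::string, OperatorType> &dict,
                        const std::string &op_name) {
  auto it = dict.find(op_name);
  return it == dict.end() ? OperatorType::kRecUnknown : it->second;
}
```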