From 050d713da13107e7ea68e7e28a62f6ccae1aaa18 Mon Sep 17 00:00:00 2001
From: yankai <yankai10@huawei.com>
Date: Thu, 6 Aug 2020 21:53:45 +0800
Subject: [PATCH] remove maketuple and getitem

---
 .../src/common/anf_exporter/anf_exporter.cc   | 205 +++++++++++++-----
 .../src/common/anf_exporter/anf_exporter.h    |   6 +-
 2 files changed, 155 insertions(+), 56 deletions(-)

diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc
index 031da4751..37f520086 100644
--- a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc
+++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc
@@ -1,6 +1,4 @@
 /**
- * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
- *
  * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,52 +15,114 @@
  */
 
 #include "src/common/anf_exporter/anf_exporter.h"
+
 #include <memory>
+#include <set>
+#include <string>
 #include <utility>
 #include <vector>
-#include <string>
+
 #include "abstract/abstract_value.h"
-#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
-#include "src/param_value_lite.h"
+#include "base/core_ops.h"
 #include "mindspore/core/ir/primitive.h"
+#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
 #include "src/ir/primitive_t_value.h"
-#include "base/core_ops.h"
 #include "src/ir/tensor.h"
+#include "src/param_value_lite.h"
 
 namespace mindspore::lite {
+std::set<std::string> RemoveNodeInAnfExporter{"tuple_getitem", "make_tuple"};
+
+void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
+  bool hasMakeTuple = false;
+  std::vector<AnfNodePtr> inputs;
+  inputs.clear();
+
+  inputs.emplace_back(cnode->input(0));
+  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
+    AnfNodePtr inputNode = cnode->input(i);
+    if (!inputNode->isa<CNode>()) {
+      inputs.emplace_back(cnode->input(i));
+      continue;
+    }
+    auto makeTupleNode = utils::cast<CNodePtr>(inputNode);
+    if (IsPrimitiveCNode(makeTupleNode, prim::kPrimMakeTuple)) {
+      hasMakeTuple = true;
+      for (size_t j = 1; j < makeTupleNode->inputs().size(); ++j) {
+        inputs.emplace_back(makeTupleNode->input(j));
+      }
+    } else {
+      inputs.emplace_back(cnode->input(i));
+    }
+  }
+  if (hasMakeTuple) {
+    cnode->set_inputs(inputs);
+  }
+}
+
+bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) {
+  bool hasTupleGetItem = false;
+  std::vector<AnfNodePtr> inputs;
+  inputs.clear();
+  inputs.emplace_back(cnode->input(0));
+  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
+    AnfNodePtr inputNode = cnode->input(i);
+    if (!inputNode->isa<CNode>()) {
+      inputs.emplace_back(cnode->input(i));
+      continue;
+    }
+    auto tupleGetItemNode = utils::cast<CNodePtr>(inputNode);
+    if (IsPrimitiveCNode(tupleGetItemNode, prim::kPrimTupleGetItem)) {
+      hasTupleGetItem = true;
+      inputs.emplace_back(tupleGetItemNode->input(1));
+      AnfNodePtr indexNode = tupleGetItemNode->input(2);
+      if (utils::isa<ValueNodePtr>(indexNode)) {
+        MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode";
+        return false;
+      }
+      ValueNodePtr valueNode = utils::cast<ValueNodePtr>(indexNode);
+      mapRemoveGetItem_[tupleGetItemNode->input(1)->fullname_with_scope()] =
+          GetValue<int>(valueNode->value());
+    } else {
+      inputs.emplace_back(cnode->input(i));
+    }
+  }
+  if (hasTupleGetItem) {
+    cnode->set_inputs(inputs);
+  }
+  return true;
+}
+
+bool AnfExporter::AddOutPutIfReturn(const std::unique_ptr<schema::MetaGraphT> &metaGraphT, const CNodePtr &cnode) {
+  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
+    auto inputNode = cnode->input(i);
+    if (!inputNode->isa<CNode>()) {
+      MS_LOG(ERROR) << "Node of Return's input is not CNode";
+      return false;
+    }
+    auto inputCNode = utils::cast<CNodePtr>(inputNode);
+    auto inputPrimitive = GetValueNode<PrimitivePtr>(inputCNode->input(0));
+    std::string inputName = inputNode->fullname_with_scope();
+    auto graphOutput = nodeIdMap[inputName];
+    metaGraphT->outputIndex.emplace_back(graphOutput);
+  }
+  return true;
+}
+
 schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
   auto cnodes = funcGraph->GetOrderedCnodes();
   auto metaGraphT = std::make_unique<schema::MetaGraphT>();
   for (const auto &cnode : cnodes) {
     auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
-    if (primitive != nullptr && primitive == prim::kPrimReturn) {
-      // set graph outputs tensors
-      auto inputNode = cnode->input(1);
-      if (!inputNode->isa<CNode>()) {
-        continue;
-      }
-      auto inputCNode = utils::cast<CNodePtr>(inputNode);
-      auto inputPrimitive = GetValueNode<PrimitivePtr>(inputCNode->input(0));
-      if (inputPrimitive == prim::kPrimMakeTuple) {
-        continue;
-      } else {
-        std::string inputName = inputNode->fullname_with_scope();
-        auto graphOutput = nodeIdMap[inputName];
-        metaGraphT->outputIndex.emplace_back(graphOutput);
-      }
+    if (primitive != nullptr &&
+        RemoveNodeInAnfExporter.count(primitive->name()) != 0) {
       continue;
     }
-    if (primitive != nullptr && primitive == prim::kPrimMakeTuple) {
-      for (size_t i = 1; i < cnode->inputs().size(); i++) {
-        auto graphOutNode = cnode->input(i);
-        if (!graphOutNode->isa<CNode>()) {
-          MS_LOG(ERROR) << "Inputs of MakeTuple should be cNode";
-          return nullptr;
-        }
-        std::string graphOutNodeName = graphOutNode->fullname_with_scope();
-        auto graphOutIndex = nodeIdMap[graphOutNodeName];
-        metaGraphT->outputIndex.emplace_back(graphOutIndex);
-      }
+    mapRemoveGetItem_.clear();
+    RemoveIfMakeTuple(cnode);
+    RemoveIfTupleGetItem(cnode);
+    if (primitive != nullptr && primitive->name() == prim::kPrimReturn->name()) {
+      AddOutPutIfReturn(metaGraphT, cnode);
       continue;
     }
 
@@ -74,19 +134,27 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
       primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
       MS_ASSERT(primitive != nullptr);
       std::string opType = primitive->name();
-      auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType);
+      auto nodeParser =
+          AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType);
       if (nodeParser == nullptr) {
         MS_LOG(ERROR) << "Find op parser failed, opType: " << opType;
         return nullptr;
       }
       std::vector<schema::TensorT *> outputs;
+      if (utils::isa<abstract::AbstractSequeue>(cnode->abstract())) {
+        auto abstract_cnode =
+            utils::cast<abstract::AbstractSequeuePtr>(cnode->abstract());
+        outputs.resize(abstract_cnode->size());
+      }
+
       nodeParser->Parse(cnode, node.get(), &outputs);
       SetOpInputNode(cnode, metaGraphT.get(), node.get());
       SetOpOutputNode(outputs, metaGraphT.get(), node.get());
       metaGraphT->nodes.emplace_back(std::move(node));
       continue;
     }
-    auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
+    auto primitiveT_value =
+        GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
     if (primitiveT_value == nullptr) {
       MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
       return nullptr;
@@ -98,7 +166,8 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
       return nullptr;
     }
 
-    node->primitive = std::unique_ptr<schema::PrimitiveT>(primitiveT_value->GetPrimitiveT());
+    node->primitive =
+        std::unique_ptr<schema::PrimitiveT>(primitiveT_value->GetPrimitiveT());
     std::vector<schema::TensorT *> outputs;
     SetOpInputNode(cnode, metaGraphT.get(), node.get());
     SetOpOutputNode(outputs, metaGraphT.get(), node.get());
@@ -112,10 +181,11 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
       auto tensor_input = metaGraphT->allTensors[activate_index].get();
       auto input_quant_params = primitiveT_value->GetInputQuantParams();
       if (input_quant_params.empty()) {
-        MS_LOG(WARNING) << "node: " << node->name << " input quant params is empty";
+        MS_LOG(WARNING) << "node: " << node->name
+                        << " input quant params is empty";
       } else {
         std::unique_ptr<schema::QuantParamT> input_quant_param =
-          std::make_unique<schema::QuantParamT>(input_quant_params[0]);
+            std::make_unique<schema::QuantParamT>(input_quant_params[0]);
         tensor_input->quantParams.emplace_back(std::move(input_quant_param));
       }
       tensor_input->dataType = kNumberTypeInt8;
@@ -124,18 +194,20 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
       auto tensor_output = metaGraphT->allTensors[output_index].get();
       auto output_quant_params = primitiveT_value->GetOutputQuantParams();
       if (output_quant_params.empty()) {
-        MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty";
+        MS_LOG(WARNING) << "node: " << node->name
+                        << " output quant params is empty";
       } else {
         std::unique_ptr<schema::QuantParamT> output_quant_param =
-          std::make_unique<schema::QuantParamT>(output_quant_params[0]);
+            std::make_unique<schema::QuantParamT>(output_quant_params[0]);
         tensor_output->quantParams.emplace_back(std::move(output_quant_param));
       }
       tensor_output->dataType = kNumberTypeInt8;
       //      // TensorType
       //      valuePtr = primitive->GetAttr(kInputTensorDataType);
       //      if (valuePtr != nullptr) {
-      //        MS_LOG(INFO) << "node: " << node->name << " input tensor data type: " << GetValue<int>(valuePtr);
-      //        for (auto input : node->inputIndex) {
+      //        MS_LOG(INFO) << "node: " << node->name << " input tensor data
+      //        type: " << GetValue<int>(valuePtr); for (auto input :
+      //        node->inputIndex) {
       //          auto tensor = subGraph->allTensors[input].get();
       //          tensor->dataType = kNumberTypeUInt8;
       //        }
@@ -159,7 +231,9 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
   return metaGraphT.release();
 }
 
-void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta_graph, schema::CNodeT *fbNode) {
+void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
+                                 schema::MetaGraphT *meta_graph,
+                                 schema::CNodeT *fbNode) {
   MS_ASSERT(nullptr != meta_graph);
   MS_ASSERT(nullptr != fbNode);
   if (cnode->inputs().size() <= 1) {
@@ -172,6 +246,13 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta
     if (inputNode->isa<CNode>()) {
       isGraphInput = false;
       std::string inputName = inputNode->fullname_with_scope();
+      if (!mapRemoveGetItem_.empty()) {
+        for (auto name : mapRemoveGetItem_) {
+          if (name.first == inputName) {
+            inputName = inputName + "_o:" + std::to_string(name.second);
+          }
+        }
+      }
       if (nodeIdMap.find(inputName) != nodeIdMap.end()) {
         fbNode->inputIndex.emplace_back(nodeIdMap[inputName]);
       }
@@ -187,30 +268,38 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta
       auto paramTensor = std::make_unique<schema::TensorT>();
       auto abstractBase = paramNode->abstract();
       if (abstractBase == nullptr) {
-        MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name();
+        MS_LOG(ERROR) << "Abstract of parameter is nullptr, "
+                      << paramNode->name();
         MS_ASSERT(false);
         return;
       }
       if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
-        MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << paramNode->name();
+        MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, "
+                      << paramNode->name();
         MS_ASSERT(false);
         return;
       }
-      auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
+      auto abstractTensor =
+          utils::cast<abstract::AbstractTensorPtr>(abstractBase);
       auto typePtr = abstractTensor->element()->GetTypeTrack();
       MS_ASSERT(typePtr != nullptr);
       paramTensor->dataType = typePtr->type_id();
       if (!utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
-        MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << paramNode->name();
+        MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, "
+                      << paramNode->name();
         MS_ASSERT(false);
         return;
       }
-      paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
-      auto paramValue = std::dynamic_pointer_cast<ParamValueLite>(paramNode->default_param());
+      paramTensor->dims =
+          utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())
+              ->shape();
+      auto paramValue =
+          std::dynamic_pointer_cast<ParamValueLite>(paramNode->default_param());
       if (paramValue != nullptr) {
         paramTensor->nodeType = schema::NodeType_ValueNode;
         paramTensor->data.resize(paramValue->tensor_size());
-        memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size());
+        memcpy(paramTensor->data.data(), paramValue->tensor_addr(),
+               paramValue->tensor_size());
       }
       for (auto &ite : paramValue->quant_param()) {
         auto quantPar = std::make_unique<schema::QuantParamT>();
@@ -224,7 +313,8 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta
         paramTensor->quantParams.emplace_back(std::move(quantPar));
         paramTensor->dataType = paramValue->tensor_type();
       }
-      nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size();
+      nodeIdMap[paramNode->fullname_with_scope()] =
+          meta_graph->allTensors.size();
       fbNode->inputIndex.emplace_back(meta_graph->allTensors.size());
       meta_graph->allTensors.emplace_back(std::move(paramTensor));
     } else if (inputNode->isa<ValueNode>()) {
@@ -233,15 +323,19 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta
       auto value = valueNode->value();
       if (value->isa<lite::tensor::Tensor>()) {
         auto valueAbstract = valueNode->abstract();
-        auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
+        auto abstractTensor =
+            utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
         auto typePtr = abstractTensor->element()->GetTypeTrack();
         paramTensor->dataType = typePtr->type_id();
-        paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
+        paramTensor->dims =
+            utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())
+                ->shape();
         paramTensor->nodeType = schema::NodeType_ValueNode;
         auto data = value->cast<lite::tensor::TensorPtr>();
         paramTensor->data.resize(data->Size());
         memcpy(paramTensor->data.data(), data->Data(), data->Size());
-        nodeIdMap[valueNode->fullname_with_scope()] = meta_graph->allTensors.size();
+        nodeIdMap[valueNode->fullname_with_scope()] =
+            meta_graph->allTensors.size();
         fbNode->inputIndex.emplace_back(meta_graph->allTensors.size());
         meta_graph->allTensors.emplace_back(std::move(paramTensor));
       } else if (value->isa<mindspore::ValueSequeue>()) {
@@ -257,8 +351,9 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta
   }
 }
 
-void AnfExporter::SetOpOutputNode(const std::vector<schema::TensorT *> &outputTensors, schema::MetaGraphT *graph,
-                                  schema::CNodeT *cnode) {
+void AnfExporter::SetOpOutputNode(
+    const std::vector<schema::TensorT *> &outputTensors,
+    schema::MetaGraphT *graph, schema::CNodeT *cnode) {
   MS_ASSERT(nullptr != graph);
   MS_ASSERT(nullptr != cnode);
   std::string cnodeName = cnode->name;
diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.h b/mindspore/lite/src/common/anf_exporter/anf_exporter.h
index 48d52fd43..8cb04e9d7 100644
--- a/mindspore/lite/src/common/anf_exporter/anf_exporter.h
+++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.h
@@ -22,6 +22,7 @@
 #include <map>
 #include <string>
 #include <vector>
+#include <memory>
 #include "schema/inner/model_generated.h"
 #include "ir/func_graph.h"
 
@@ -34,10 +35,13 @@ class AnfExporter {
   void SetOpOutputNode(const std::vector<schema::TensorT *> &outputTensors, schema::MetaGraphT *graph,
                        schema::CNodeT *cnode);
   void SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta_graph, schema::CNodeT *fbNode);
-
+  void RemoveIfMakeTuple(const CNodePtr &cnode);
+  bool RemoveIfTupleGetItem(const CNodePtr &cnode);
+  bool AddOutPutIfReturn(const std::unique_ptr<schema::MetaGraphT> &metaGraphT, const CNodePtr &cnode);
  private:
   std::map<std::string, int> nodeIdMap;
   std::vector<schema::CNodeT *> graphInputNodes;
+  std::map<std::string, int>  mapRemoveGetItem_;
 };
 
 schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph);
-- 
GitLab