From 050d713da13107e7ea68e7e28a62f6ccae1aaa18 Mon Sep 17 00:00:00 2001 From: yankai <yankai10@huawei.com> Date: Thu, 6 Aug 2020 21:53:45 +0800 Subject: [PATCH] remove maketuple and getitem --- .../src/common/anf_exporter/anf_exporter.cc | 205 +++++++++++++----- .../src/common/anf_exporter/anf_exporter.h | 6 +- 2 files changed, 155 insertions(+), 56 deletions(-) diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc index 031da4751..37f520086 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc @@ -1,6 +1,4 @@ /** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,52 +15,114 @@ */ #include "src/common/anf_exporter/anf_exporter.h" + #include <memory> +#include <set> +#include <string> #include <utility> #include <vector> -#include <string> + #include "abstract/abstract_value.h" -#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" -#include "src/param_value_lite.h" +#include "base/core_ops.h" #include "mindspore/core/ir/primitive.h" +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" #include "src/ir/primitive_t_value.h" -#include "base/core_ops.h" #include "src/ir/tensor.h" +#include "src/param_value_lite.h" namespace mindspore::lite { +std::set<std::string> RemoveNodeInAnfExporter{"tuple_getitem", "make_tuple"}; + +void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { + bool hasMakeTuple = false; + std::vector<AnfNodePtr> inputs; + inputs.clear(); + + inputs.emplace_back(cnode->input(0)); + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + AnfNodePtr inputNode = cnode->input(i); + if (!inputNode->isa<CNode>()) { + inputs.emplace_back(cnode->input(i)); + continue; + } + auto makeTupleNode = utils::cast<CNodePtr>(inputNode); + if (IsPrimitiveCNode(makeTupleNode, prim::kPrimMakeTuple)) { + hasMakeTuple = true; + for (size_t j = 1; j < makeTupleNode->inputs().size(); ++j) { + inputs.emplace_back(makeTupleNode->input(j)); + } + } else { + inputs.emplace_back(cnode->input(i)); + } + } + if (hasMakeTuple) { + cnode->set_inputs(inputs); + } +} + +bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) { + bool hasTupleGetItem = false; + std::vector<AnfNodePtr> inputs; + inputs.clear(); + inputs.emplace_back(cnode->input(0)); + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + AnfNodePtr inputNode = cnode->input(i); + if (!inputNode->isa<CNode>()) { + inputs.emplace_back(cnode->input(i)); + continue; + } + auto tupleGetItemNode = utils::cast<CNodePtr>(inputNode); + if (IsPrimitiveCNode(tupleGetItemNode, prim::kPrimTupleGetItem)) { + hasTupleGetItem = true; + inputs.emplace_back(tupleGetItemNode->input(1)); + AnfNodePtr indexNode = tupleGetItemNode->input(2); + if (utils::isa<ValueNodePtr>(indexNode)) { + MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode"; + return false; + } + ValueNodePtr valueNode = utils::cast<ValueNodePtr>(indexNode); + mapRemoveGetItem_[tupleGetItemNode->input(1)->fullname_with_scope()] = + GetValue<int>(valueNode->value()); + } else { + inputs.emplace_back(cnode->input(i)); + } + } + if (hasTupleGetItem) { + cnode->set_inputs(inputs); + } + return true; +} + +bool AnfExporter::AddOutPutIfReturn(const std::unique_ptr<schema::MetaGraphT> &metaGraphT, const CNodePtr &cnode) { + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + auto inputNode = cnode->input(i); + if (!inputNode->isa<CNode>()) { + MS_LOG(ERROR) << "Node of Return's input is not CNode"; + return false; + } + auto inputCNode = utils::cast<CNodePtr>(inputNode); + auto inputPrimitive = GetValueNode<PrimitivePtr>(inputCNode->input(0)); + std::string inputName = inputNode->fullname_with_scope(); + auto graphOutput = nodeIdMap[inputName]; + metaGraphT->outputIndex.emplace_back(graphOutput); + } + return true; +} + schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { auto cnodes = funcGraph->GetOrderedCnodes(); auto metaGraphT = std::make_unique<schema::MetaGraphT>(); for (const auto &cnode : cnodes) { auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0)); - if (primitive != nullptr && primitive == prim::kPrimReturn) { - // set graph outputs tensors - auto inputNode = cnode->input(1); - if (!inputNode->isa<CNode>()) { - continue; - } - auto inputCNode = utils::cast<CNodePtr>(inputNode); - auto inputPrimitive = GetValueNode<PrimitivePtr>(inputCNode->input(0)); - if (inputPrimitive == prim::kPrimMakeTuple) { - continue; - } else { - std::string inputName = inputNode->fullname_with_scope(); - auto graphOutput = nodeIdMap[inputName]; - metaGraphT->outputIndex.emplace_back(graphOutput); - } + if (primitive != nullptr && + RemoveNodeInAnfExporter.count(primitive->name()) != 0) { continue; } - if (primitive != nullptr && primitive == prim::kPrimMakeTuple) { - for (size_t i = 1; i < cnode->inputs().size(); i++) { - auto graphOutNode = cnode->input(i); - if (!graphOutNode->isa<CNode>()) { - MS_LOG(ERROR) << "Inputs of MakeTuple should be cNode"; - return nullptr; - } - std::string graphOutNodeName = graphOutNode->fullname_with_scope(); - auto graphOutIndex = nodeIdMap[graphOutNodeName]; - metaGraphT->outputIndex.emplace_back(graphOutIndex); - } + mapRemoveGetItem_.clear(); + RemoveIfMakeTuple(cnode); + RemoveIfTupleGetItem(cnode); + if (primitive != nullptr && primitive->name() == prim::kPrimReturn->name()) { + AddOutPutIfReturn(metaGraphT, cnode); continue; } @@ -74,19 +134,27 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { primitive = GetValueNode<PrimitivePtr>(cnode->input(0)); MS_ASSERT(primitive != nullptr); std::string opType = primitive->name(); - auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); + auto nodeParser = + AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); if (nodeParser == nullptr) { MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; return nullptr; } std::vector<schema::TensorT *> outputs; + if (utils::isa<abstract::AbstractSequeue>(cnode->abstract())) { + auto abstract_cnode = + utils::cast<abstract::AbstractSequeuePtr>(cnode->abstract()); + outputs.resize(abstract_cnode->size()); + } + nodeParser->Parse(cnode, node.get(), &outputs); SetOpInputNode(cnode, metaGraphT.get(), node.get()); SetOpOutputNode(outputs, metaGraphT.get(), node.get()); metaGraphT->nodes.emplace_back(std::move(node)); continue; } - auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0)); + auto primitiveT_value = + GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0)); if (primitiveT_value == nullptr) { MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; return nullptr; @@ -98,7 +166,8 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { return nullptr; } - node->primitive = std::unique_ptr<schema::PrimitiveT>(primitiveT_value->GetPrimitiveT()); + node->primitive = + std::unique_ptr<schema::PrimitiveT>(primitiveT_value->GetPrimitiveT()); std::vector<schema::TensorT *> outputs; SetOpInputNode(cnode, metaGraphT.get(), node.get()); SetOpOutputNode(outputs, metaGraphT.get(), node.get()); @@ -112,10 +181,11 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { auto tensor_input = metaGraphT->allTensors[activate_index].get(); auto input_quant_params = primitiveT_value->GetInputQuantParams(); if (input_quant_params.empty()) { - MS_LOG(WARNING) << "node: " << node->name << " input quant params is empty"; + MS_LOG(WARNING) << "node: " << node->name + << " input quant params is empty"; } else { std::unique_ptr<schema::QuantParamT> input_quant_param = - std::make_unique<schema::QuantParamT>(input_quant_params[0]); + std::make_unique<schema::QuantParamT>(input_quant_params[0]); tensor_input->quantParams.emplace_back(std::move(input_quant_param)); } tensor_input->dataType = kNumberTypeInt8; @@ -124,18 +194,20 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { auto tensor_output = metaGraphT->allTensors[output_index].get(); auto output_quant_params = primitiveT_value->GetOutputQuantParams(); if (output_quant_params.empty()) { - MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty"; + MS_LOG(WARNING) << "node: " << node->name + << " output quant params is empty"; } else { std::unique_ptr<schema::QuantParamT> output_quant_param = - std::make_unique<schema::QuantParamT>(output_quant_params[0]); + std::make_unique<schema::QuantParamT>(output_quant_params[0]); tensor_output->quantParams.emplace_back(std::move(output_quant_param)); } tensor_output->dataType = kNumberTypeInt8; // // TensorType // valuePtr = primitive->GetAttr(kInputTensorDataType); // if (valuePtr != nullptr) { - // MS_LOG(INFO) << "node: " << node->name << " input tensor data type: " << GetValue<int>(valuePtr); - // for (auto input : node->inputIndex) { + // MS_LOG(INFO) << "node: " << node->name << " input tensor data + // type: " << GetValue<int>(valuePtr); for (auto input : + // node->inputIndex) { // auto tensor = subGraph->allTensors[input].get(); // tensor->dataType = kNumberTypeUInt8; // } @@ -159,7 +231,9 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { return metaGraphT.release(); } -void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta_graph, schema::CNodeT *fbNode) { +void AnfExporter::SetOpInputNode(const CNodePtr &cnode, + schema::MetaGraphT *meta_graph, + schema::CNodeT *fbNode) { MS_ASSERT(nullptr != meta_graph); MS_ASSERT(nullptr != fbNode); if (cnode->inputs().size() <= 1) { @@ -172,6 +246,13 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta if (inputNode->isa<CNode>()) { isGraphInput = false; std::string inputName = inputNode->fullname_with_scope(); + if (!mapRemoveGetItem_.empty()) { + for (auto name : mapRemoveGetItem_) { + if (name.first == inputName) { + inputName = inputName + "_o:" + std::to_string(name.second); + } + } + } if (nodeIdMap.find(inputName) != nodeIdMap.end()) { fbNode->inputIndex.emplace_back(nodeIdMap[inputName]); } @@ -187,30 +268,38 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta auto paramTensor = std::make_unique<schema::TensorT>(); auto abstractBase = paramNode->abstract(); if (abstractBase == nullptr) { - MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name(); + MS_LOG(ERROR) << "Abstract of parameter is nullptr, " + << paramNode->name(); MS_ASSERT(false); return; } if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) { - MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << paramNode->name(); + MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " + << paramNode->name(); MS_ASSERT(false); return; } - auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase); + auto abstractTensor = + utils::cast<abstract::AbstractTensorPtr>(abstractBase); auto typePtr = abstractTensor->element()->GetTypeTrack(); MS_ASSERT(typePtr != nullptr); paramTensor->dataType = typePtr->type_id(); if (!utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) { - MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << paramNode->name(); + MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " + << paramNode->name(); MS_ASSERT(false); return; } - paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); - auto paramValue = std::dynamic_pointer_cast<ParamValueLite>(paramNode->default_param()); + paramTensor->dims = + utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape()) + ->shape(); + auto paramValue = + std::dynamic_pointer_cast<ParamValueLite>(paramNode->default_param()); if (paramValue != nullptr) { paramTensor->nodeType = schema::NodeType_ValueNode; paramTensor->data.resize(paramValue->tensor_size()); - memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size()); + memcpy(paramTensor->data.data(), paramValue->tensor_addr(), + paramValue->tensor_size()); } for (auto &ite : paramValue->quant_param()) { auto quantPar = std::make_unique<schema::QuantParamT>(); @@ -224,7 +313,8 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta paramTensor->quantParams.emplace_back(std::move(quantPar)); paramTensor->dataType = paramValue->tensor_type(); } - nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size(); + nodeIdMap[paramNode->fullname_with_scope()] = + meta_graph->allTensors.size(); fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); meta_graph->allTensors.emplace_back(std::move(paramTensor)); } else if (inputNode->isa<ValueNode>()) { @@ -233,15 +323,19 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta auto value = valueNode->value(); if (value->isa<lite::tensor::Tensor>()) { auto valueAbstract = valueNode->abstract(); - auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract); + auto abstractTensor = + utils::cast<abstract::AbstractTensorPtr>(valueAbstract); auto typePtr = abstractTensor->element()->GetTypeTrack(); paramTensor->dataType = typePtr->type_id(); - paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); + paramTensor->dims = + utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape()) + ->shape(); paramTensor->nodeType = schema::NodeType_ValueNode; auto data = value->cast<lite::tensor::TensorPtr>(); paramTensor->data.resize(data->Size()); memcpy(paramTensor->data.data(), data->Data(), data->Size()); - nodeIdMap[valueNode->fullname_with_scope()] = meta_graph->allTensors.size(); + nodeIdMap[valueNode->fullname_with_scope()] = + meta_graph->allTensors.size(); fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); meta_graph->allTensors.emplace_back(std::move(paramTensor)); } else if (value->isa<mindspore::ValueSequeue>()) { @@ -257,8 +351,9 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta } } -void AnfExporter::SetOpOutputNode(const std::vector<schema::TensorT *> &outputTensors, schema::MetaGraphT *graph, - schema::CNodeT *cnode) { +void AnfExporter::SetOpOutputNode( + const std::vector<schema::TensorT *> &outputTensors, + schema::MetaGraphT *graph, schema::CNodeT *cnode) { MS_ASSERT(nullptr != graph); MS_ASSERT(nullptr != cnode); std::string cnodeName = cnode->name; diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.h b/mindspore/lite/src/common/anf_exporter/anf_exporter.h index 48d52fd43..8cb04e9d7 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_exporter.h +++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.h @@ -22,6 +22,7 @@ #include <map> #include <string> #include <vector> +#include <memory> #include "schema/inner/model_generated.h" #include "ir/func_graph.h" @@ -34,10 +35,13 @@ class AnfExporter { void SetOpOutputNode(const std::vector<schema::TensorT *> &outputTensors, schema::MetaGraphT *graph, schema::CNodeT *cnode); void SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta_graph, schema::CNodeT *fbNode); - + void RemoveIfMakeTuple(const CNodePtr &cnode); + bool RemoveIfTupleGetItem(const CNodePtr &cnode); + bool AddOutPutIfReturn(const std::unique_ptr<schema::MetaGraphT> &metaGraphT, const CNodePtr &cnode); private: std::map<std::string, int> nodeIdMap; std::vector<schema::CNodeT *> graphInputNodes; + std::map<std::string, int> mapRemoveGetItem_; }; schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph); -- GitLab