diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc index ddc58bc3b548ad86cb6f60e95d6c9982dd0977f8..ec708411a2533e06d98392f1432319e3b7e87fa0 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc @@ -76,7 +76,7 @@ bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) { hasTupleGetItem = true; inputs.emplace_back(tupleGetItemNode->input(1)); AnfNodePtr indexNode = tupleGetItemNode->input(2); - if (utils::isa(indexNode)) { + if (!utils::isa(indexNode)) { MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode"; return false; } @@ -300,18 +300,18 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, paramTensor->data.resize(paramValue->tensor_size()); memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size()); - } - for (auto &ite : paramValue->quant_param()) { - auto quantPar = std::make_unique(); - quantPar->scale = ite->scale; - quantPar->zeroPoint = ite->zeroPoint; - quantPar->min = ite->min; - quantPar->max = ite->max; - quantPar->narrowRange = ite->narrowRange; - quantPar->inited = ite->inited; - quantPar->numBits = ite->numBits; - paramTensor->quantParams.emplace_back(std::move(quantPar)); - paramTensor->dataType = paramValue->tensor_type(); + for (auto &ite : paramValue->quant_param()) { + auto quantPar = std::make_unique(); + quantPar->scale = ite->scale; + quantPar->zeroPoint = ite->zeroPoint; + quantPar->min = ite->min; + quantPar->max = ite->max; + quantPar->narrowRange = ite->narrowRange; + quantPar->inited = ite->inited; + quantPar->numBits = ite->numBits; + paramTensor->quantParams.emplace_back(std::move(quantPar)); + paramTensor->dataType = paramValue->tensor_type(); + } } nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size(); @@ -373,9 +373,11 @@ void AnfExporter::SetOpOutputNode( int i = 0; for (auto outputTensor : outputTensors) { std::string name = cnodeName + "_o:" + std::to_string(i); + auto msTensor = new schema::TensorT(); + msTensor->nodeType = schema::NodeType_Parameter; nodeIdMap[name] = graph->allTensors.size(); cnode->outputIndex.emplace_back(graph->allTensors.size()); - graph->allTensors.emplace_back(outputTensor); + graph->allTensors.emplace_back(msTensor); i++; } return;