提交 f1254ed3 编写于 作者: Y yankai

fix export of output tensor

上级 a3959071
...@@ -76,7 +76,7 @@ bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) { ...@@ -76,7 +76,7 @@ bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) {
hasTupleGetItem = true; hasTupleGetItem = true;
inputs.emplace_back(tupleGetItemNode->input(1)); inputs.emplace_back(tupleGetItemNode->input(1));
AnfNodePtr indexNode = tupleGetItemNode->input(2); AnfNodePtr indexNode = tupleGetItemNode->input(2);
if (utils::isa<ValueNodePtr>(indexNode)) { if (!utils::isa<ValueNode>(indexNode)) {
MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode"; MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode";
return false; return false;
} }
...@@ -300,7 +300,6 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, ...@@ -300,7 +300,6 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
paramTensor->data.resize(paramValue->tensor_size()); paramTensor->data.resize(paramValue->tensor_size());
memcpy(paramTensor->data.data(), paramValue->tensor_addr(), memcpy(paramTensor->data.data(), paramValue->tensor_addr(),
paramValue->tensor_size()); paramValue->tensor_size());
}
for (auto &ite : paramValue->quant_param()) { for (auto &ite : paramValue->quant_param()) {
auto quantPar = std::make_unique<schema::QuantParamT>(); auto quantPar = std::make_unique<schema::QuantParamT>();
quantPar->scale = ite->scale; quantPar->scale = ite->scale;
...@@ -313,6 +312,7 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, ...@@ -313,6 +312,7 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
paramTensor->quantParams.emplace_back(std::move(quantPar)); paramTensor->quantParams.emplace_back(std::move(quantPar));
paramTensor->dataType = paramValue->tensor_type(); paramTensor->dataType = paramValue->tensor_type();
} }
}
nodeIdMap[paramNode->fullname_with_scope()] = nodeIdMap[paramNode->fullname_with_scope()] =
meta_graph->allTensors.size(); meta_graph->allTensors.size();
fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); fbNode->inputIndex.emplace_back(meta_graph->allTensors.size());
...@@ -373,9 +373,11 @@ void AnfExporter::SetOpOutputNode( ...@@ -373,9 +373,11 @@ void AnfExporter::SetOpOutputNode(
int i = 0; int i = 0;
for (auto outputTensor : outputTensors) { for (auto outputTensor : outputTensors) {
std::string name = cnodeName + "_o:" + std::to_string(i); std::string name = cnodeName + "_o:" + std::to_string(i);
auto msTensor = new schema::TensorT();
msTensor->nodeType = schema::NodeType_Parameter;
nodeIdMap[name] = graph->allTensors.size(); nodeIdMap[name] = graph->allTensors.size();
cnode->outputIndex.emplace_back(graph->allTensors.size()); cnode->outputIndex.emplace_back(graph->allTensors.size());
graph->allTensors.emplace_back(outputTensor); graph->allTensors.emplace_back(msTensor);
i++; i++;
} }
return; return;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册