提交 d4671497 编写于 作者: Y yeyunpeng

fix op multi output problem

上级 8e3c8f3d
...@@ -189,7 +189,9 @@ union PrimitiveType { ...@@ -189,7 +189,9 @@ union PrimitiveType {
ActivationGrad, ActivationGrad,
PriorBox, PriorBox,
SpaceToBatchND, SpaceToBatchND,
TopKV2 TopKV2,
Return,
MakeTuple
} }
enum QuantType: int { enum QuantType: int {
......
...@@ -864,3 +864,9 @@ table TopKV2 { ...@@ -864,3 +864,9 @@ table TopKV2 {
sorted : bool = true; sorted : bool = true;
} }
table MakeTuple {
}
table Return {
}
\ No newline at end of file
...@@ -81,8 +81,7 @@ bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) { ...@@ -81,8 +81,7 @@ bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) {
return false; return false;
} }
ValueNodePtr valueNode = utils::cast<ValueNodePtr>(indexNode); ValueNodePtr valueNode = utils::cast<ValueNodePtr>(indexNode);
mapRemoveGetItem_[tupleGetItemNode->input(1)->fullname_with_scope()] = mapRemoveGetItem_[tupleGetItemNode->input(1)->fullname_with_scope()] = GetValue<int>(valueNode->value());
GetValue<int>(valueNode->value());
} else { } else {
inputs.emplace_back(cnode->input(i)); inputs.emplace_back(cnode->input(i));
} }
...@@ -114,16 +113,34 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { ...@@ -114,16 +113,34 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
auto metaGraphT = std::make_unique<schema::MetaGraphT>(); auto metaGraphT = std::make_unique<schema::MetaGraphT>();
for (const auto &cnode : cnodes) { for (const auto &cnode : cnodes) {
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0)); auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
if (primitive != nullptr && if (primitive != nullptr) {
RemoveNodeInAnfExporter.count(primitive->name()) != 0) { if (RemoveNodeInAnfExporter.count(primitive->name()) != 0) {
continue; continue;
}
} else {
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
auto primT = primitiveT_value->GetPrimitiveT();
if (primT->value.type == schema::PrimitiveType_TupleGetItem ||
primT->value.type == schema::PrimitiveType_MakeTuple) {
continue;
}
} }
mapRemoveGetItem_.clear(); mapRemoveGetItem_.clear();
RemoveIfMakeTuple(cnode); RemoveIfMakeTuple(cnode);
RemoveIfTupleGetItem(cnode); RemoveIfTupleGetItem(cnode);
if (primitive != nullptr && primitive->name() == prim::kPrimReturn->name()) {
AddOutPutIfReturn(metaGraphT, cnode); if (primitive != nullptr) {
continue; if (primitive->name() == prim::kPrimReturn->name()) {
AddOutPutIfReturn(metaGraphT, cnode);
continue;
}
} else {
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
auto primT = primitiveT_value->GetPrimitiveT();
if (primT->value.type == schema::PrimitiveType_Return) {
AddOutPutIfReturn(metaGraphT, cnode);
continue;
}
} }
auto node = std::make_unique<schema::CNodeT>(); auto node = std::make_unique<schema::CNodeT>();
...@@ -134,27 +151,24 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { ...@@ -134,27 +151,24 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
primitive = GetValueNode<PrimitivePtr>(cnode->input(0)); primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
std::string opType = primitive->name(); std::string opType = primitive->name();
auto nodeParser = auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType);
AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType);
if (nodeParser == nullptr) { if (nodeParser == nullptr) {
MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; MS_LOG(ERROR) << "Find op parser failed, opType: " << opType;
return nullptr; return nullptr;
} }
std::vector<schema::TensorT *> outputs; std::vector<schema::TensorT *> outputs;
if (utils::isa<abstract::AbstractSequeue>(cnode->abstract())) { if (utils::isa<abstract::AbstractSequeue>(cnode->abstract())) {
auto abstract_cnode = auto abstract_cnode = utils::cast<abstract::AbstractSequeuePtr>(cnode->abstract());
utils::cast<abstract::AbstractSequeuePtr>(cnode->abstract());
outputs.resize(abstract_cnode->size()); outputs.resize(abstract_cnode->size());
} }
nodeParser->Parse(cnode, node.get(), &outputs); nodeParser->Parse(cnode, node.get(), &outputs);
SetOpInputNode(cnode, metaGraphT.get(), node.get()); SetOpInputNode(cnode, metaGraphT.get(), node.get());
SetOpOutputNode(outputs, metaGraphT.get(), node.get()); SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get());
metaGraphT->nodes.emplace_back(std::move(node)); metaGraphT->nodes.emplace_back(std::move(node));
continue; continue;
} }
auto primitiveT_value = auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
if (primitiveT_value == nullptr) { if (primitiveT_value == nullptr) {
MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
return nullptr; return nullptr;
...@@ -166,11 +180,10 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { ...@@ -166,11 +180,10 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
return nullptr; return nullptr;
} }
node->primitive = node->primitive = std::unique_ptr<schema::PrimitiveT>(primitiveT_value->GetPrimitiveT());
std::unique_ptr<schema::PrimitiveT>(primitiveT_value->GetPrimitiveT());
std::vector<schema::TensorT *> outputs; std::vector<schema::TensorT *> outputs;
SetOpInputNode(cnode, metaGraphT.get(), node.get()); SetOpInputNode(cnode, metaGraphT.get(), node.get());
SetOpOutputNode(outputs, metaGraphT.get(), node.get()); SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get());
// add quant param // add quant param
node->quantType = primitiveT_value->GetQuantType(); node->quantType = primitiveT_value->GetQuantType();
...@@ -244,9 +257,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { ...@@ -244,9 +257,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
return metaGraphT.release(); return metaGraphT.release();
} }
void AnfExporter::SetOpInputNode(const CNodePtr &cnode, void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta_graph, schema::CNodeT *fbNode) {
schema::MetaGraphT *meta_graph,
schema::CNodeT *fbNode) {
MS_ASSERT(nullptr != meta_graph); MS_ASSERT(nullptr != meta_graph);
MS_ASSERT(nullptr != fbNode); MS_ASSERT(nullptr != fbNode);
if (cnode->inputs().size() <= 1) { if (cnode->inputs().size() <= 1) {
...@@ -281,38 +292,30 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, ...@@ -281,38 +292,30 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
auto paramTensor = std::make_unique<schema::TensorT>(); auto paramTensor = std::make_unique<schema::TensorT>();
auto abstractBase = paramNode->abstract(); auto abstractBase = paramNode->abstract();
if (abstractBase == nullptr) { if (abstractBase == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name();
<< paramNode->name();
MS_ASSERT(false); MS_ASSERT(false);
return; return;
} }
if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) { if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << paramNode->name();
<< paramNode->name();
MS_ASSERT(false); MS_ASSERT(false);
return; return;
} }
auto abstractTensor = auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
utils::cast<abstract::AbstractTensorPtr>(abstractBase);
auto typePtr = abstractTensor->element()->GetTypeTrack(); auto typePtr = abstractTensor->element()->GetTypeTrack();
MS_ASSERT(typePtr != nullptr); MS_ASSERT(typePtr != nullptr);
paramTensor->dataType = typePtr->type_id(); paramTensor->dataType = typePtr->type_id();
if (!utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) { if (!utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << paramNode->name();
<< paramNode->name();
MS_ASSERT(false); MS_ASSERT(false);
return; return;
} }
paramTensor->dims = paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape()) auto paramValue = std::dynamic_pointer_cast<ParamValueLite>(paramNode->default_param());
->shape();
auto paramValue =
std::dynamic_pointer_cast<ParamValueLite>(paramNode->default_param());
if (paramValue != nullptr) { if (paramValue != nullptr) {
paramTensor->nodeType = schema::NodeType_ValueNode; paramTensor->nodeType = schema::NodeType_ValueNode;
paramTensor->data.resize(paramValue->tensor_size()); paramTensor->data.resize(paramValue->tensor_size());
memcpy(paramTensor->data.data(), paramValue->tensor_addr(), memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size());
paramValue->tensor_size());
for (auto &ite : paramValue->quant_param()) { for (auto &ite : paramValue->quant_param()) {
auto quantPar = std::make_unique<schema::QuantParamT>(); auto quantPar = std::make_unique<schema::QuantParamT>();
quantPar->scale = ite->scale; quantPar->scale = ite->scale;
...@@ -326,8 +329,7 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, ...@@ -326,8 +329,7 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
paramTensor->dataType = paramValue->tensor_type(); paramTensor->dataType = paramValue->tensor_type();
} }
} }
nodeIdMap[paramNode->fullname_with_scope()] = nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size();
meta_graph->allTensors.size();
fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); fbNode->inputIndex.emplace_back(meta_graph->allTensors.size());
meta_graph->allTensors.emplace_back(std::move(paramTensor)); meta_graph->allTensors.emplace_back(std::move(paramTensor));
} else if (inputNode->isa<ValueNode>()) { } else if (inputNode->isa<ValueNode>()) {
...@@ -336,19 +338,15 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, ...@@ -336,19 +338,15 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
auto value = valueNode->value(); auto value = valueNode->value();
if (value->isa<lite::tensor::Tensor>()) { if (value->isa<lite::tensor::Tensor>()) {
auto valueAbstract = valueNode->abstract(); auto valueAbstract = valueNode->abstract();
auto abstractTensor = auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
auto typePtr = abstractTensor->element()->GetTypeTrack(); auto typePtr = abstractTensor->element()->GetTypeTrack();
paramTensor->dataType = typePtr->type_id(); paramTensor->dataType = typePtr->type_id();
paramTensor->dims = paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())
->shape();
paramTensor->nodeType = schema::NodeType_ValueNode; paramTensor->nodeType = schema::NodeType_ValueNode;
auto data = value->cast<lite::tensor::TensorPtr>(); auto data = value->cast<lite::tensor::TensorPtr>();
paramTensor->data.resize(data->Size()); paramTensor->data.resize(data->Size());
memcpy(paramTensor->data.data(), data->Data(), data->Size()); memcpy(paramTensor->data.data(), data->Data(), data->Size());
nodeIdMap[valueNode->fullname_with_scope()] = nodeIdMap[valueNode->fullname_with_scope()] = meta_graph->allTensors.size();
meta_graph->allTensors.size();
fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); fbNode->inputIndex.emplace_back(meta_graph->allTensors.size());
meta_graph->allTensors.emplace_back(std::move(paramTensor)); meta_graph->allTensors.emplace_back(std::move(paramTensor));
} else if (value->isa<mindspore::Int32Imm>()) { } else if (value->isa<mindspore::Int32Imm>()) {
...@@ -376,30 +374,44 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, ...@@ -376,30 +374,44 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
} }
} }
void AnfExporter::SetOpOutputNode( void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::vector<schema::TensorT *> &outputTensors,
const std::vector<schema::TensorT *> &outputTensors, schema::MetaGraphT *graph, schema::CNodeT *fbnode) {
schema::MetaGraphT *graph, schema::CNodeT *cnode) {
MS_ASSERT(nullptr != graph); MS_ASSERT(nullptr != graph);
MS_ASSERT(nullptr != cnode); MS_ASSERT(nullptr != fbnode);
std::string cnodeName = cnode->name; std::string cnodeName = fbnode->name;
if (!outputTensors.empty()) { if (!outputTensors.empty()) {
int i = 0; int i = 0;
for (auto outputTensor : outputTensors) { for (auto outputTensor : outputTensors) {
std::string name = cnodeName + "_o:" + std::to_string(i); std::string name = cnodeName + "_o:" + std::to_string(i);
auto msTensor = new schema::TensorT();
msTensor->nodeType = schema::NodeType_Parameter;
nodeIdMap[name] = graph->allTensors.size(); nodeIdMap[name] = graph->allTensors.size();
cnode->outputIndex.emplace_back(graph->allTensors.size()); fbnode->outputIndex.emplace_back(graph->allTensors.size());
graph->allTensors.emplace_back(msTensor); graph->allTensors.emplace_back(outputTensor);
i++; i++;
} }
return; return;
} }
auto msTensor = new schema::TensorT();
msTensor->nodeType = schema::NodeType_Parameter; if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
cnode->outputIndex.emplace_back(graph->allTensors.size()); auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
nodeIdMap[cnodeName] = graph->allTensors.size(); for (int i = 0; i < tuple->size(); i++) {
graph->allTensors.emplace_back(msTensor); auto msTensor = new schema::TensorT();
msTensor->nodeType = schema::NodeType_Parameter;
fbnode->outputIndex.emplace_back(graph->allTensors.size());
if (tuple->size() == 1) {
nodeIdMap[cnodeName] = graph->allTensors.size();
} else {
std::string name = cnodeName + "_o:" + std::to_string(i);
nodeIdMap[name] = graph->allTensors.size();
}
graph->allTensors.emplace_back(msTensor);
}
} else {
auto msTensor = new schema::TensorT();
msTensor->nodeType = schema::NodeType_Parameter;
fbnode->outputIndex.emplace_back(graph->allTensors.size());
nodeIdMap[cnodeName] = graph->allTensors.size();
graph->allTensors.emplace_back(msTensor);
}
} }
schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph) { schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph) {
......
...@@ -32,8 +32,8 @@ class AnfExporter { ...@@ -32,8 +32,8 @@ class AnfExporter {
AnfExporter() = default; AnfExporter() = default;
virtual ~AnfExporter() = default; virtual ~AnfExporter() = default;
schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph); schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph);
void SetOpOutputNode(const std::vector<schema::TensorT *> &outputTensors, schema::MetaGraphT *graph, void SetOpOutputNode(const CNodePtr &cnode, const std::vector<schema::TensorT *> &outputTensors,
schema::CNodeT *cnode); schema::MetaGraphT *graph, schema::CNodeT *fbnode);
void SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta_graph, schema::CNodeT *fbNode); void SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta_graph, schema::CNodeT *fbNode);
void RemoveIfMakeTuple(const CNodePtr &cnode); void RemoveIfMakeTuple(const CNodePtr &cnode);
bool RemoveIfTupleGetItem(const CNodePtr &cnode); bool RemoveIfTupleGetItem(const CNodePtr &cnode);
...@@ -47,4 +47,3 @@ class AnfExporter { ...@@ -47,4 +47,3 @@ class AnfExporter {
schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph); schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph);
} // namespace mindspore::lite } // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ #endif // MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_
...@@ -71,11 +71,11 @@ int AnfImporterFromMetaGraphT::ConverterCNode() { ...@@ -71,11 +71,11 @@ int AnfImporterFromMetaGraphT::ConverterCNode() {
for (size_t i = 0; i < meta_graph_->nodes.size(); i++) { for (size_t i = 0; i < meta_graph_->nodes.size(); i++) {
auto &cNode = meta_graph_->nodes.at(i); auto &cNode = meta_graph_->nodes.at(i);
MS_EXCEPTION_IF_NULL(cNode); MS_EXCEPTION_IF_NULL(cNode);
auto tensor_id = cNode->outputIndex.front();
if (nullptr != GetNode(tensor_id)) {
continue;
}
bool flag = false;
if (cNode->outputIndex.size() > 1) {
flag = true;
}
auto primTValue = std::make_shared<PrimitiveTValue>(cNode->primitive.release()); auto primTValue = std::make_shared<PrimitiveTValue>(cNode->primitive.release());
cNode->primitive = nullptr; cNode->primitive = nullptr;
auto value_node = NewValueNode(primTValue); auto value_node = NewValueNode(primTValue);
...@@ -90,9 +90,39 @@ int AnfImporterFromMetaGraphT::ConverterCNode() { ...@@ -90,9 +90,39 @@ int AnfImporterFromMetaGraphT::ConverterCNode() {
// todo: CheckInputNodeType, the first node should be op; // todo: CheckInputNodeType, the first node should be op;
op_inputs.push_back(node); op_inputs.push_back(node);
} }
auto cnode = func_graph_->NewCNode(op_inputs);
cnode->set_fullname_with_scope(cNode->name); auto new_cnode = func_graph_->NewCNode(op_inputs);
AddNode(tensor_id, cnode); new_cnode->set_fullname_with_scope(cNode->name);
std::vector<uint32_t> out_tensor_ids = cNode->outputIndex;
AbstractBasePtrList ptr_list;
int total = 0;
for (auto out_tensor_id : out_tensor_ids) {
if (nullptr != GetNode(out_tensor_id)) {
ptr_list.push_back(GetNode(out_tensor_id)->abstract());
continue;
}
std::vector<int> shape;
auto &tensor = meta_graph_->allTensors.at(out_tensor_id);
for (int &dim : tensor->dims) {
shape.push_back(dim);
}
auto type_id = static_cast<TypeId>(tensor->dataType);
auto type_ptr = TypeIdToType(type_id);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
auto getItemPrim = NewValueNode(prim::kPrimTupleGetItem);
if (flag) {
auto getItemIndex = NewValueNode(MakeValue<int>(total++));
std::vector<AnfNodePtr> inputs{getItemPrim, new_cnode, getItemIndex};
CNodePtr new_item_cnode = func_graph_->NewCNode(inputs);
AddNode(out_tensor_id, new_item_cnode);
} else {
AddNode(out_tensor_id, new_cnode);
}
ptr_list.push_back(std::move(abstract_tensor));
}
new_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
} }
return RET_OK; return RET_OK;
} }
...@@ -120,4 +150,3 @@ void AnfImporterFromMetaGraphT::AddReturnCNode() { ...@@ -120,4 +150,3 @@ void AnfImporterFromMetaGraphT::AddReturnCNode() {
FuncGraphPtr AnfImporterFromMetaGraphT::GetResult() { return this->func_graph_; } FuncGraphPtr AnfImporterFromMetaGraphT::GetResult() { return this->func_graph_; }
} // namespace mindspore::lite } // namespace mindspore::lite
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册