From 11d565ca629d5b36993752941472a26244600e79 Mon Sep 17 00:00:00 2001
From: Dmitry Kurtaev
Date: Wed, 18 Mar 2020 00:00:24 +0300
Subject: [PATCH] Fix LSTM from ONNX with batch==1

---
 modules/dnn/src/layers/recurrent_layers.cpp |  9 +-
 modules/dnn/src/onnx/onnx_importer.cpp      | 97 ++++++++++++++-------
 2 files changed, 69 insertions(+), 37 deletions(-)

diff --git a/modules/dnn/src/layers/recurrent_layers.cpp b/modules/dnn/src/layers/recurrent_layers.cpp
index 3f9a229516..26d2ea9de5 100644
--- a/modules/dnn/src/layers/recurrent_layers.cpp
+++ b/modules/dnn/src/layers/recurrent_layers.cpp
@@ -110,10 +110,11 @@ public:
             const Mat& Wh = blobs[0];
             const Mat& Wx = blobs[1];
             const Mat& bias = blobs[2];
-            CV_Assert(Wh.dims == 2 && Wx.dims == 2);
-            CV_Assert(Wh.rows == Wx.rows);
-            CV_Assert(Wh.rows == 4*Wh.cols);
-            CV_Assert(Wh.rows == (int)bias.total());
+            CV_CheckEQ(Wh.dims, 2, "");
+            CV_CheckEQ(Wx.dims, 2, "");
+            CV_CheckEQ(Wh.rows, Wx.rows, "");
+            CV_CheckEQ(Wh.rows, 4*Wh.cols, "");
+            CV_CheckEQ(Wh.rows, (int)bias.total(), "");
             CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());
 
             // Peephole weights.
diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp
index 2bcba9e6ad..b243a986e7 100644
--- a/modules/dnn/src/onnx/onnx_importer.cpp
+++ b/modules/dnn/src/onnx/onnx_importer.cpp
@@ -49,6 +49,11 @@ class ONNXImporter
     LayerParams getLayerParams(const opencv_onnx::NodeProto& node_proto);
     bool isCeilMode(const LayerParams& layerParams);
 
+    void addLayer(Net& dstNet, LayerParams& layerParams,
+                  const opencv_onnx::NodeProto& node_proto,
+                  std::map<std::string, LayerInfo>& layer_id,
+                  std::map<std::string, MatShape>& outShapes);
+
 public:
 
     ONNXImporter(const char *onnxFile)
@@ -259,6 +264,42 @@ Mat ONNXImporter::getBlob(const opencv_onnx::NodeProto& node_proto,
     return constBlob->second;
 }
 
+void ONNXImporter::addLayer(Net& dstNet, LayerParams& layerParams,
+                            const opencv_onnx::NodeProto& node_proto,
+                            std::map<std::string, LayerInfo>& layer_id,
+                            std::map<std::string, MatShape>& outShapes)
+{
+    std::map<std::string, LayerInfo>::iterator layerId;
+    std::map<std::string, MatShape>::iterator shapeIt;
+
+    int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
+    for (int i = 0; i < node_proto.output_size(); ++i)
+    {
+        layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
+    }
+
+    std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
+    int inpNum = 0;
+    for (int j = 0; j < node_proto.input_size(); j++) {
+        layerId = layer_id.find(node_proto.input(j));
+        if (layerId != layer_id.end()) {
+            dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, inpNum);
+            ++inpNum;
+            // Collect input shapes.
+            shapeIt = outShapes.find(node_proto.input(j));
+            CV_Assert(shapeIt != outShapes.end());
+            layerInpShapes.push_back(shapeIt->second);
+        }
+    }
+    // Compute shape of output blob for this layer.
+    Ptr<Layer> layer = dstNet.getLayer(id);
+    layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
+    for (int i = 0; i < node_proto.output_size() && i < (int)layerOutShapes.size(); ++i)
+    {
+        outShapes[node_proto.output(i)] = layerOutShapes[i];
+    }
+}
+
 void ONNXImporter::populateNet(Net dstNet)
 {
     CV_Assert(model_proto.has_graph());
@@ -581,13 +622,16 @@ void ONNXImporter::populateNet(Net dstNet)
         }
         else if (layer_type == "LSTM")
         {
+            LayerParams lstmParams = layerParams;
+            lstmParams.name += "/lstm";
+
             // https://pytorch.org/docs/stable/nn.html#lstm
             CV_Assert(node_proto.input_size() == 7);
             Mat Wx = getBlob(node_proto, constBlobs, 1);
             Mat Wh = getBlob(node_proto, constBlobs, 2);
             Mat b = getBlob(node_proto, constBlobs, 3);
 
-            const int numHidden = Wh.size[2];
+            const int numHidden = lstmParams.get<int>("hidden_size");
 
             Wx = Wx.reshape(1, Wx.size[1]);
             Wh = Wh.reshape(1, Wh.size[1]);
@@ -612,10 +656,24 @@ void ONNXImporter::populateNet(Net dstNet)
                 }
                 std::swap(biasData[numHidden + j], biasData[numHidden * 2 + j]);
             }
-            layerParams.blobs.resize(3);
-            layerParams.blobs[0] = Wh;
-            layerParams.blobs[1] = Wx;
-            layerParams.blobs[2] = b;
+
+            lstmParams.blobs.resize(3);
+            lstmParams.blobs[0] = Wh;
+            lstmParams.blobs[1] = Wx;
+            lstmParams.blobs[2] = b;
+
+            node_proto.set_output(0, lstmParams.name);  // set a different name so output shapes are registered under it
+            addLayer(dstNet, lstmParams, node_proto, layer_id, outShapes);
+
+            MatShape lstmShape = outShapes[node_proto.output(0)];
+
+            // Add a fake dimension of size 1, matching the ONNX output layout.
+            lstmShape.insert(lstmShape.begin() + 1, 1);
+
+            layerParams.type = "Reshape";
+            layerParams.set("dim", DictValue::arrayInt(&lstmShape[0], lstmShape.size()));
+            node_proto.set_input(0, lstmParams.name);   // redirect input to the internal LSTM
+            node_proto.set_output(0, layerParams.name); // keep the original LSTM's name
         }
         else if (layer_type == "ImageScaler")
         {
@@ -1228,34 +1286,7 @@ void ONNXImporter::populateNet(Net dstNet)
                     layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
                 }
             }
-
-        int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
-        for (int i = 0; i < node_proto.output_size(); ++i)
-        {
-            layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
-        }
-
-        std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
-        int inpNum = 0;
-        for (int j = 0; j < node_proto.input_size(); j++) {
-            layerId = layer_id.find(node_proto.input(j));
-            if (layerId != layer_id.end()) {
-                dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, inpNum);
-                ++inpNum;
-                // Collect input shapes.
-                shapeIt = outShapes.find(node_proto.input(j));
-                CV_Assert(shapeIt != outShapes.end());
-                layerInpShapes.push_back(shapeIt->second);
-            }
-        }
-
-        // Compute shape of output blob for this layer.
-        Ptr<Layer> layer = dstNet.getLayer(id);
-        layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
-        for (int i = 0; i < node_proto.output_size() && i < (int)layerOutShapes.size(); ++i)
-        {
-            outShapes[node_proto.output(i)] = layerOutShapes[i];
-        }
+        addLayer(dstNet, layerParams, node_proto, layer_id, outShapes);
     }
 }
--
GitLab
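
A minimal usage sketch (not part of the patch) of how the fixed import path can be exercised. It assumes a hypothetical "lstm.onnx" exported from PyTorch's nn.LSTM with batch size 1; the concrete dimensions (seq_len=5, input_size=10) are illustrative only:

    #include <opencv2/dnn.hpp>
    #include <iostream>

    int main()
    {
        // Load a hypothetical single-layer LSTM exported to ONNX.
        cv::dnn::Net net = cv::dnn::readNetFromONNX("lstm.onnx");

        // ONNX LSTM input layout: [seq_len, batch, input_size];
        // batch == 1 is the case this patch fixes.
        int sz[] = {5, 1, 10};
        cv::Mat input(3, sz, CV_32F);
        cv::randu(input, cv::Scalar(-1), cv::Scalar(1));

        net.setInput(input);
        cv::Mat out = net.forward();

        // The appended Reshape restores the ONNX output layout:
        // [seq_len, num_directions, batch, hidden_size].
        std::cout << out.size[0] << " " << out.size[1] << " "
                  << out.size[2] << " " << out.size[3] << std::endl;
        return 0;
    }

Internally the importer now builds two layers for each ONNX LSTM node: an LSTM layer named "<node>/lstm", whose output shape is registered via the new addLayer() helper, and a Reshape carrying the node's original name that re-inserts the num_directions axis dropped by OpenCV's recurrent layer.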