From e4e774d42b7c2902c36c9e3de35c5d13e17263f6 Mon Sep 17 00:00:00 2001 From: Abduragim Shtanchaev <44877829+Abdurrahheem@users.noreply.github.com> Date: Mon, 24 Apr 2023 13:39:41 +0300 Subject: [PATCH] Merge pull request #23475 from Abdurrahheem:lstm_fix_initialization Fix ONNX parser for single-layer LSTM hidden and cell states #23475 ### Fix ONNX parser for single-layer LSTM hidden and cell states ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake This PR addresses #21118 [issue](https://github.com/opencv/opencv/issues/21118). The problem is that the ONNX parser is unable to read the hidden state and cell state for single-layer LSTMs. This PR fixes the issue by updating the parser to correctly read hidden and cell states. 
--- modules/dnn/src/layers/recurrent_layers.cpp | 34 +++++++++++++++----- modules/dnn/src/onnx/onnx_importer.cpp | 35 ++++++++++++++++----- 2 files changed, 54 insertions(+), 15 deletions(-) diff --git a/modules/dnn/src/layers/recurrent_layers.cpp b/modules/dnn/src/layers/recurrent_layers.cpp index 3961051c8e..8c3b5810bf 100644 --- a/modules/dnn/src/layers/recurrent_layers.cpp +++ b/modules/dnn/src/layers/recurrent_layers.cpp @@ -173,9 +173,17 @@ public: CV_CheckEQ(Wh.rows, Wx.rows, ""); CV_CheckEQ(Wh.rows, (1 + static_cast(bidirectional))*4*Wh.cols, ""); CV_CheckEQ(Wh.rows, (int)bias.total(), ""); - CV_CheckEQ(hInternal.cols, Wh.cols, ""); - CV_CheckEQ(hInternal.cols, cInternal.cols, ""); - CV_CheckEQ(hInternal.rows, cInternal.rows, ""); + // Only perform these checks if hInternal and cInternal are not empty matrices + // e.g. inputs are not given by a user + if(!hInternal.empty()){ + CV_CheckEQ(hInternal.cols, Wh.cols, ""); + } + if(!cInternal.empty()){ + CV_CheckEQ(cInternal.cols, Wh.cols, ""); + } + if (!hInternal.empty() && !cInternal.empty()){ //otherwise check in forward + CV_CheckEQ(hInternal.rows, cInternal.rows, ""); + } CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type()); // Peephole weights. 
@@ -266,7 +274,7 @@ public: std::vector &internals) const CV_OVERRIDE { CV_Assert((!usePeephole && blobs.size() == 5) || (usePeephole && blobs.size() == 8)); - CV_Assert(inputs.size() == 1); + CV_Assert((inputs.size() == 1 || inputs.size() == 3)); const MatShape& inp0 = inputs[0]; const Mat &Wh = blobs[0], &Wx = blobs[1]; @@ -326,7 +334,7 @@ public: inputs_arr.getMatVector(input); CV_Assert((!usePeephole && blobs.size() == 5) || (usePeephole && blobs.size() == 8)); - CV_Assert(input.size() == 1); + CV_Assert((input.size() == 1 || input.size() == 3)); const Mat& inp0 = input[0]; Mat &Wh = blobs[0], &Wx = blobs[1]; @@ -383,8 +391,20 @@ public: Mat Wh = blobs[0]; Mat Wx = blobs[1]; Mat bias = blobs[2]; - Mat h_0 = blobs[3]; - Mat c_0 = blobs[4]; + + Mat h_0, c_0; + // Handle h_0 and c_0 based on input size + h_0 = (input.size() >= 2) ? input[1].reshape(1, input[1].size[0] * input[1].size[1]) : blobs[3]; + c_0 = (input.size() == 3) ? input[2].reshape(1, input[2].size[0] * input[2].size[1]) : blobs[4]; + + // Perform checks if input size is 2 or 3 + if (input.size() >= 2) { + CV_CheckEQ(h_0.cols, Wh.cols, ""); + CV_CheckEQ(h_0.cols, c_0.cols, ""); + CV_CheckEQ(h_0.rows, c_0.rows, ""); + } + + Mat pI, pF, pO; Wh = Wh.rowRange(i * Wh.rows / numDirs, (i + 1) * Wh.rows / numDirs); diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index e074d54169..7421fdbc28 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -1539,10 +1539,17 @@ void transformBlobs(std::vector& blobs) const int numHidden = Wh.size[2]; - Mat h0 = blobs[3]; - h0 = h0.reshape(1, h0.size[0] * h0.size[1]); - Mat c0 = blobs[4]; - c0 = c0.reshape(1, c0.size[0] * c0.size[1]); + Mat h0, c0; + // check whether input is dynamic or not: hx, cx are given by user. 
+ // Reshape only if they are given + if (!blobs[3].empty()){ + h0 = blobs[3]; + h0 = h0.reshape(1, h0.size[0] * h0.size[1]); + } + if (!blobs[4].empty()){ + c0 = blobs[4]; + c0 = c0.reshape(1, c0.size[0] * c0.size[1]); + } b = b.reshape(1, b.size[0]); Mat bx = b.colRange(0, b.cols / 2); @@ -1569,8 +1576,13 @@ void transformBlobs(std::vector& blobs) blobs[0] = Wh; blobs[1] = Wx; blobs[2] = b.reshape(1, 1); - blobs[3] = h0; - blobs[4] = c0; + + if (!blobs[3].empty()){ + blobs[3] = h0; + } + if (!blobs[4].empty()){ + blobs[4] = c0; + } if (blobs.size() == 5) { // so that future patch removing copies can leave all indexing as is @@ -1601,8 +1613,15 @@ void ONNXImporter::lstm_extractConsts(LayerParams& layerParams, const opencv_onn Mat blob; if (idx < lstm_proto.input_size() && !lstm_proto.input(idx).empty()) { - blob = getBlob(lstm_proto, idx); - CV_Assert(shape(blob) == blobShape); + if ((idx == 5 || idx == 6) && (constBlobs.find(lstm_proto.input(idx)) == constBlobs.end())) + { + blob = Mat(); + } + else + { + blob = getBlob(lstm_proto, idx); + CV_Assert(shape(blob) == blobShape); + } } else { -- GitLab