未验证 提交 e4e774d4 编写于 作者: A Abduragim Shtanchaev 提交者: GitHub

Merge pull request #23475 from Abdurrahheem:lstm_fix_initialization

Fix ONNX parser for single-layer LSTM hidden and cell states #23475

### Fix ONNX parser for single-layer LSTM hidden and cell states

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake


This PR addresses #21118 [issue](https://github.com/opencv/opencv/issues/21118). The problem is that the ONNX parser is unable to read the hidden state and cell state for single-layer LSTMs. This PR fixes the issue by updating the parser to correctly read hidden and cell states.
上级 89c5a758
...@@ -173,9 +173,17 @@ public: ...@@ -173,9 +173,17 @@ public:
CV_CheckEQ(Wh.rows, Wx.rows, ""); CV_CheckEQ(Wh.rows, Wx.rows, "");
CV_CheckEQ(Wh.rows, (1 + static_cast<int>(bidirectional))*4*Wh.cols, ""); CV_CheckEQ(Wh.rows, (1 + static_cast<int>(bidirectional))*4*Wh.cols, "");
CV_CheckEQ(Wh.rows, (int)bias.total(), ""); CV_CheckEQ(Wh.rows, (int)bias.total(), "");
CV_CheckEQ(hInternal.cols, Wh.cols, ""); // Only perform these checks if hInternal and cInternal are not empty matrices
CV_CheckEQ(hInternal.cols, cInternal.cols, ""); // e.g. inputs are not given by a user
CV_CheckEQ(hInternal.rows, cInternal.rows, ""); if(!hInternal.empty()){
CV_CheckEQ(hInternal.cols, Wh.cols, "");
}
if(!cInternal.empty()){
CV_CheckEQ(cInternal.cols, Wh.cols, "");
}
if (!hInternal.empty() && !cInternal.empty()){ //otherwise check in forward
CV_CheckEQ(hInternal.rows, cInternal.rows, "");
}
CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type()); CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());
// Peephole weights. // Peephole weights.
...@@ -266,7 +274,7 @@ public: ...@@ -266,7 +274,7 @@ public:
std::vector<MatShape> &internals) const CV_OVERRIDE std::vector<MatShape> &internals) const CV_OVERRIDE
{ {
CV_Assert((!usePeephole && blobs.size() == 5) || (usePeephole && blobs.size() == 8)); CV_Assert((!usePeephole && blobs.size() == 5) || (usePeephole && blobs.size() == 8));
CV_Assert(inputs.size() == 1); CV_Assert((inputs.size() == 1 || inputs.size() == 3));
const MatShape& inp0 = inputs[0]; const MatShape& inp0 = inputs[0];
const Mat &Wh = blobs[0], &Wx = blobs[1]; const Mat &Wh = blobs[0], &Wx = blobs[1];
...@@ -326,7 +334,7 @@ public: ...@@ -326,7 +334,7 @@ public:
inputs_arr.getMatVector(input); inputs_arr.getMatVector(input);
CV_Assert((!usePeephole && blobs.size() == 5) || (usePeephole && blobs.size() == 8)); CV_Assert((!usePeephole && blobs.size() == 5) || (usePeephole && blobs.size() == 8));
CV_Assert(input.size() == 1); CV_Assert((input.size() == 1 || input.size() == 3));
const Mat& inp0 = input[0]; const Mat& inp0 = input[0];
Mat &Wh = blobs[0], &Wx = blobs[1]; Mat &Wh = blobs[0], &Wx = blobs[1];
...@@ -383,8 +391,20 @@ public: ...@@ -383,8 +391,20 @@ public:
Mat Wh = blobs[0]; Mat Wh = blobs[0];
Mat Wx = blobs[1]; Mat Wx = blobs[1];
Mat bias = blobs[2]; Mat bias = blobs[2];
Mat h_0 = blobs[3];
Mat c_0 = blobs[4]; Mat h_0, c_0;
// Handle h_0 and c_0 based on input size
h_0 = (input.size() >= 2) ? input[1].reshape(1, input[1].size[0] * input[1].size[1]) : blobs[3];
c_0 = (input.size() == 3) ? input[2].reshape(1, input[2].size[0] * input[2].size[1]) : blobs[4];
// Perform checks if input size is 2 or 3
if (input.size() >= 2) {
CV_CheckEQ(h_0.cols, Wh.cols, "");
CV_CheckEQ(h_0.cols, c_0.cols, "");
CV_CheckEQ(h_0.rows, c_0.rows, "");
}
Mat pI, pF, pO; Mat pI, pF, pO;
Wh = Wh.rowRange(i * Wh.rows / numDirs, (i + 1) * Wh.rows / numDirs); Wh = Wh.rowRange(i * Wh.rows / numDirs, (i + 1) * Wh.rows / numDirs);
......
...@@ -1539,10 +1539,17 @@ void transformBlobs(std::vector<Mat>& blobs) ...@@ -1539,10 +1539,17 @@ void transformBlobs(std::vector<Mat>& blobs)
const int numHidden = Wh.size[2]; const int numHidden = Wh.size[2];
Mat h0 = blobs[3]; Mat h0, c0;
h0 = h0.reshape(1, h0.size[0] * h0.size[1]); // check weather input is dynamic or not: hx, cx are given by user.
Mat c0 = blobs[4]; // Resahpe if only they are given
c0 = c0.reshape(1, c0.size[0] * c0.size[1]); if (!blobs[3].empty()){
h0 = blobs[3];
h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
}
if (!blobs[4].empty()){
c0 = blobs[4];
c0 = c0.reshape(1, c0.size[0] * c0.size[1]);
}
b = b.reshape(1, b.size[0]); b = b.reshape(1, b.size[0]);
Mat bx = b.colRange(0, b.cols / 2); Mat bx = b.colRange(0, b.cols / 2);
...@@ -1569,8 +1576,13 @@ void transformBlobs(std::vector<Mat>& blobs) ...@@ -1569,8 +1576,13 @@ void transformBlobs(std::vector<Mat>& blobs)
blobs[0] = Wh; blobs[0] = Wh;
blobs[1] = Wx; blobs[1] = Wx;
blobs[2] = b.reshape(1, 1); blobs[2] = b.reshape(1, 1);
blobs[3] = h0;
blobs[4] = c0; if (!blobs[3].empty()){
blobs[3] = h0;
}
if (!blobs[4].empty()){
blobs[4] = c0;
}
if (blobs.size() == 5) { if (blobs.size() == 5) {
// so that future patch removing copies can leave all indexing as is // so that future patch removing copies can leave all indexing as is
...@@ -1601,8 +1613,15 @@ void ONNXImporter::lstm_extractConsts(LayerParams& layerParams, const opencv_onn ...@@ -1601,8 +1613,15 @@ void ONNXImporter::lstm_extractConsts(LayerParams& layerParams, const opencv_onn
Mat blob; Mat blob;
if (idx < lstm_proto.input_size() && !lstm_proto.input(idx).empty()) if (idx < lstm_proto.input_size() && !lstm_proto.input(idx).empty())
{ {
blob = getBlob(lstm_proto, idx); if ((idx == 5 || idx == 6) && (constBlobs.find(lstm_proto.input(idx)) == constBlobs.end()))
CV_Assert(shape(blob) == blobShape); {
blob = Mat();
}
else
{
blob = getBlob(lstm_proto, idx);
CV_Assert(shape(blob) == blobShape);
}
} }
else else
{ {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册