提交 467c3ef0 编写于 作者: D Dmitry Kurtaev

Add checks for LSTM initial h and c

上级 84336202
......@@ -496,6 +496,7 @@ void ONNXImporter::populateNet(Net dstNet)
runLayer(layerParams, inputs, sliced);
CV_Assert(sliced.size() == 1);
constBlobs.insert(std::make_pair(layerParams.name, sliced[0]));
outShapes[layerParams.name] = shape(sliced[0]);
continue;
}
}
......@@ -630,6 +631,8 @@ void ONNXImporter::populateNet(Net dstNet)
Mat Wx = getBlob(node_proto, constBlobs, 1);
Mat Wh = getBlob(node_proto, constBlobs, 2);
Mat b = getBlob(node_proto, constBlobs, 3);
CV_CheckEQ(countNonZero(getBlob(node_proto, constBlobs, 5)), 0, "Unsupported non zero initial_h");
CV_CheckEQ(countNonZero(getBlob(node_proto, constBlobs, 6)), 0, "Unsupported non zero initial_c");
b = b.reshape(1, b.size[0]);
const int numHidden = lstmParams.get<int>("hidden_size");
......@@ -1007,6 +1010,16 @@ void ONNXImporter::populateNet(Net dstNet)
}
else
layerParams.type = "Identity";
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
{
Mat inp = getBlob(node_proto, constBlobs, 0);
Mat out = inp.reshape(1, outShape);
out.dims = outShape.size(); // to workaround dims == 1
constBlobs.insert(std::make_pair(layerParams.name, out));
outShapes[layerParams.name] = shape(out);
continue;
}
}
else if (layer_type == "Flatten")
{
......@@ -1136,15 +1149,6 @@ void ONNXImporter::populateNet(Net dstNet)
else
layerParams.type = "Identity";
}
else if (layer_type == "ConstantFill" || layer_type == "ConstantOfShape")
{
CV_Assert_N(node_proto.input_size());
MatShape inpShape = getBlob(node_proto, constBlobs, 0);
float value = layerParams.get("value", 0);
Mat fill(inpShape.size(), &inpShape[0], CV_32F, Scalar(value));
constBlobs.insert(std::make_pair(layerParams.name, fill));
continue;
}
else if (layer_type == "ConstantOfShape" || layer_type == "ConstantFill")
{
float fill_value;
......
......@@ -405,6 +405,8 @@ TEST_P(Test_ONNX_layers, Reshape)
TEST_P(Test_ONNX_layers, Squeeze)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target == DNN_TARGET_MYRIAD)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
testONNXModels("squeeze");
}
......@@ -453,12 +455,12 @@ TEST_P(Test_ONNX_layers, Split_EltwiseMax)
TEST_P(Test_ONNX_layers, LSTM)
{
testONNXModels("lstm");
testONNXModels("lstm", npy, 0, 0, false, false);
}
TEST_P(Test_ONNX_layers, LSTM_bidirectional)
{
testONNXModels("lstm_bidirectional");
testONNXModels("lstm_bidirectional", npy, 0, 0, false, false);
}
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册