未验证 提交 9ed13323 编写于 作者: L Liubov Batanina 提交者: GitHub

Merge pull request #16722 from l-bat:reshape_opset_11

* Supported Div op for constants

* Added Mul test
上级 57cf1201
......@@ -465,31 +465,6 @@ void ONNXImporter::populateNet(Net dstNet)
layerParams.blobs.push_back(-1.0f * blob.reshape(1, 1));
}
}
else if (layer_type == "Div")
{
if (constBlobs.find(node_proto.input(1)) == constBlobs.end())
{
layerParams.type = "Eltwise";
layerParams.set("operation", "div");
}
else
{
Mat blob = getBlob(node_proto, constBlobs, 1);
CV_Assert_N(blob.type() == CV_32F, blob.total());
if (blob.total() == 1)
{
layerParams.set("scale", 1.0f / blob.at<float>(0));
layerParams.type = "Power";
}
else
{
layerParams.type = "Scale";
divide(1.0, blob, blob);
layerParams.blobs.push_back(blob);
layerParams.set("bias_term", false);
}
}
}
else if (layer_type == "Neg")
{
layerParams.type = "Power";
......@@ -638,24 +613,58 @@ void ONNXImporter::populateNet(Net dstNet)
layerParams.set("bias_term", false);
layerParams.set("num_output", layerParams.blobs[0].size[0]);
}
else if (layer_type == "Mul")
else if (layer_type == "Mul" || layer_type == "Div")
{
CV_Assert(node_proto.input_size() == 2);
if (layer_id.find(node_proto.input(1)) == layer_id.end()) {
Mat blob = getBlob(node_proto, constBlobs, 1);
bool isDiv = layer_type == "Div";
int constId = -1;
bool haveVariables = false;
for (int i = 0; i < 2; ++i)
{
if (constBlobs.find(node_proto.input(i)) != constBlobs.end())
constId = i;
else
haveVariables = true;
}
if (constId != -1 && haveVariables)
{
Mat blob = getBlob(node_proto, constBlobs, constId);
blob = blob.reshape(1, 1);
if (blob.total() == 1) {
layerParams.set("scale", blob.at<float>(0));
float coeff = isDiv ? 1.0 / blob.at<float>(0) : blob.at<float>(0);
layerParams.set("scale", coeff);
layerParams.type = "Power";
}
else {
if (isDiv)
divide(1.0, blob, blob);
layerParams.blobs.push_back(blob);
layerParams.type = "Scale";
}
}
else {
layerParams.type = "Eltwise";
layerParams.set("operation", "prod");
layerParams.set("operation", isDiv ? "div" : "prod");
}
if (!haveVariables)
{
Mat inp0 = getBlob(node_proto, constBlobs, 0);
Mat inp1 = getBlob(node_proto, constBlobs, 1);
if (inp0.size != inp1.size)
CV_Error(Error::StsNotImplemented, "Constant multiply with different shapes");
Mat out;
if (isDiv)
divide(inp0, inp1, out);
else
multiply(inp0, inp1, out);
out = out.reshape(1, inp0.dims, inp0.size);
out.dims = inp0.dims; // to workaround dims == 1
constBlobs.insert(std::make_pair(layerParams.name, out));
continue;
}
}
else if (layer_type == "Conv")
......
......@@ -382,6 +382,8 @@ TEST_P(Test_ONNX_layers, DynamicReshape)
if (target == DNN_TARGET_OPENCL) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_OPENCL, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
}
testONNXModels("dynamic_reshape");
testONNXModels("dynamic_reshape_opset_11");
testONNXModels("flatten_by_prod");
}
TEST_P(Test_ONNX_layers, Reshape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册