From fec2c7e715d3eaeca20e7051afd4476c4412aef6 Mon Sep 17 00:00:00 2001 From: Smirnov Egor Date: Thu, 16 Dec 2021 22:41:47 +0300 Subject: [PATCH] fix Flatten layer --- modules/dnn/src/layers/flatten_layer.cpp | 1 - modules/dnn/src/onnx/onnx_importer.cpp | 57 +++++++++++++++++-- ...ance_layer_filter__halide_denylist.inl.hpp | 6 -- ...conformance_layer_filter__openvino.inl.hpp | 24 ++------ ...e_layer_filter_opencv_all_denylist.inl.hpp | 6 -- 5 files changed, 58 insertions(+), 36 deletions(-) diff --git a/modules/dnn/src/layers/flatten_layer.cpp b/modules/dnn/src/layers/flatten_layer.cpp index c59b71248e..1e0b010167 100644 --- a/modules/dnn/src/layers/flatten_layer.cpp +++ b/modules/dnn/src/layers/flatten_layer.cpp @@ -100,7 +100,6 @@ public: { outputShapeVec.push_back(inputs[0][i]); } - CV_Assert(outputShapeVec.size() <= 4); outputs.resize(inputs.size(), outputShapeVec); diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index df404a93d3..5d88844e61 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -1781,20 +1781,67 @@ void ONNXImporter::parseSqueeze(LayerParams& layerParams, const opencv_onnx::Nod addLayer(layerParams, node_proto); } -void ONNXImporter::parseFlatten(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) +void ONNXImporter::parseFlatten(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_) { + opencv_onnx::NodeProto node_proto = node_proto_; CV_CheckEQ(node_proto.input_size(), 1, ""); + int axis_ = layerParams.get("axis", 1); if (constBlobs.find(node_proto.input(0)) != constBlobs.end()) { Mat input = getBlob(node_proto, 0); - int axis = normalize_axis(layerParams.get("axis", 1), input.dims); + int axis = normalize_axis(axis_, input.dims); - std::vector out_size(&input.size[0], &input.size[0] + axis); - out_size.push_back(input.total(axis)); - Mat output = input.reshape(1, out_size); + int out_size[2] = {1, 1}; + for (int i = 0; i < axis; ++i) + { + out_size[0] *= input.size[i]; + } + for (int i = axis; i < input.dims; ++i) + { + out_size[1] *= input.size[i]; + } + + Mat output = input.reshape(1, 2, out_size); addConstant(layerParams.name, output); return; } + IterShape_t shapeIt = outShapes.find(node_proto.input(0)); + CV_Assert(shapeIt != outShapes.end()); + MatShape inpShape = shapeIt->second; + int axis = normalize_axis(axis_, inpShape.size()); + + if (axis == 0 || axis == inpShape.size()) + { + LayerParams reshapeLp; + reshapeLp.name = layerParams.name + "/reshape"; + reshapeLp.type = "Reshape"; + CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end()); + + inpShape.insert(axis == 0 ? inpShape.begin() : inpShape.end(), 1); + reshapeLp.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size())); + + opencv_onnx::NodeProto proto; + proto.add_input(node_proto.input(0)); + proto.add_output(reshapeLp.name); + addLayer(reshapeLp, proto); + node_proto.set_input(0, reshapeLp.name); + axis += 1; + } + + LayerParams first_pass; + first_pass.name = layerParams.name + "/flatten"; + CV_Assert(layer_id.find(first_pass.name) == layer_id.end()); + first_pass.type = "Flatten"; + first_pass.set("axis", 0); + first_pass.set("end_axis", axis - 1); + + opencv_onnx::NodeProto proto; + proto.add_input(node_proto.input(0)); + proto.add_output(first_pass.name); + addLayer(first_pass, proto); + + layerParams.set("axis", 1); + node_proto.set_input(0, first_pass.name); addLayer(layerParams, node_proto); } diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter__halide_denylist.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter__halide_denylist.inl.hpp index dd0a249081..08938c9ad7 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_filter__halide_denylist.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_filter__halide_denylist.inl.hpp @@ -17,12 +17,6 @@ "test_elu", "test_elu_default", "test_exp", -"test_flatten_axis0", -"test_flatten_axis2", -"test_flatten_axis3", -"test_flatten_negative_axis1", -"test_flatten_negative_axis2", -"test_flatten_negative_axis4", "test_leakyrelu", "test_leakyrelu_default", "test_logsoftmax_axis_1", diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp index 25bb8dff9a..c3b62a7f75 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp @@ -561,35 +561,23 @@ CASE(test_eyelike_with_dtype) CASE(test_eyelike_without_dtype) // no filter CASE(test_flatten_axis0) -#if INF_ENGINE_VER_MAJOR_EQ(2021040000) - SKIP; -#endif + // no filter CASE(test_flatten_axis1) // no filter CASE(test_flatten_axis2) -#if INF_ENGINE_VER_MAJOR_EQ(2021040000) - SKIP; -#endif + // no filter CASE(test_flatten_axis3) -#if INF_ENGINE_VER_MAJOR_EQ(2021040000) - SKIP; -#endif + // no filter CASE(test_flatten_default_axis) // no filter CASE(test_flatten_negative_axis1) -#if INF_ENGINE_VER_MAJOR_EQ(2021040000) - SKIP; -#endif + // no filter CASE(test_flatten_negative_axis2) -#if INF_ENGINE_VER_MAJOR_EQ(2021040000) - SKIP; -#endif + // no filter CASE(test_flatten_negative_axis3) // no filter CASE(test_flatten_negative_axis4) -#if INF_ENGINE_VER_MAJOR_EQ(2021040000) - SKIP; -#endif + // no filter CASE(test_floor) // no filter CASE(test_floor_example) diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_all_denylist.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_all_denylist.inl.hpp index cff1e93aa0..c9966f11d5 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_all_denylist.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_all_denylist.inl.hpp @@ -7,12 +7,6 @@ "test_castlike_FLOAT_to_STRING_expanded", "test_castlike_STRING_to_FLOAT_expanded", "test_concat_1d_axis_negative_1", -"test_flatten_axis0", -"test_flatten_axis2", -"test_flatten_axis3", -"test_flatten_negative_axis1", -"test_flatten_negative_axis2", -"test_flatten_negative_axis4", "test_logsoftmax_default_axis", "test_maxpool_2d_dilations", "test_maxpool_2d_same_lower", -- GitLab