提交 fec2c7e7 编写于 作者: S Smirnov Egor

fix Flatten layer

上级 f0712074
......@@ -100,7 +100,6 @@ public:
{
outputShapeVec.push_back(inputs[0][i]);
}
CV_Assert(outputShapeVec.size() <= 4);
outputs.resize(inputs.size(), outputShapeVec);
......
......@@ -1781,20 +1781,67 @@ void ONNXImporter::parseSqueeze(LayerParams& layerParams, const opencv_onnx::Nod
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseFlatten(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
void ONNXImporter::parseFlatten(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
CV_CheckEQ(node_proto.input_size(), 1, "");
int axis_ = layerParams.get<int>("axis", 1);
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
{
Mat input = getBlob(node_proto, 0);
int axis = normalize_axis(layerParams.get<int>("axis", 1), input.dims);
int axis = normalize_axis(axis_, input.dims);
std::vector<int> out_size(&input.size[0], &input.size[0] + axis);
out_size.push_back(input.total(axis));
Mat output = input.reshape(1, out_size);
int out_size[2] = {1, 1};
for (int i = 0; i < axis; ++i)
{
out_size[0] *= input.size[i];
}
for (int i = axis; i < input.dims; ++i)
{
out_size[1] *= input.size[i];
}
Mat output = input.reshape(1, 2, out_size);
addConstant(layerParams.name, output);
return;
}
IterShape_t shapeIt = outShapes.find(node_proto.input(0));
CV_Assert(shapeIt != outShapes.end());
MatShape inpShape = shapeIt->second;
int axis = normalize_axis(axis_, inpShape.size());
if (axis == 0 || axis == inpShape.size())
{
LayerParams reshapeLp;
reshapeLp.name = layerParams.name + "/reshape";
reshapeLp.type = "Reshape";
CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
inpShape.insert(axis == 0 ? inpShape.begin() : inpShape.end(), 1);
reshapeLp.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
opencv_onnx::NodeProto proto;
proto.add_input(node_proto.input(0));
proto.add_output(reshapeLp.name);
addLayer(reshapeLp, proto);
node_proto.set_input(0, reshapeLp.name);
axis += 1;
}
LayerParams first_pass;
first_pass.name = layerParams.name + "/flatten";
CV_Assert(layer_id.find(first_pass.name) == layer_id.end());
first_pass.type = "Flatten";
first_pass.set("axis", 0);
first_pass.set("end_axis", axis - 1);
opencv_onnx::NodeProto proto;
proto.add_input(node_proto.input(0));
proto.add_output(first_pass.name);
addLayer(first_pass, proto);
layerParams.set("axis", 1);
node_proto.set_input(0, first_pass.name);
addLayer(layerParams, node_proto);
}
......
......@@ -17,12 +17,6 @@
"test_elu",
"test_elu_default",
"test_exp",
"test_flatten_axis0",
"test_flatten_axis2",
"test_flatten_axis3",
"test_flatten_negative_axis1",
"test_flatten_negative_axis2",
"test_flatten_negative_axis4",
"test_leakyrelu",
"test_leakyrelu_default",
"test_logsoftmax_axis_1",
......
......@@ -561,35 +561,23 @@ CASE(test_eyelike_with_dtype)
CASE(test_eyelike_without_dtype)
// no filter
CASE(test_flatten_axis0)
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
SKIP;
#endif
// no filter
CASE(test_flatten_axis1)
// no filter
CASE(test_flatten_axis2)
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
SKIP;
#endif
// no filter
CASE(test_flatten_axis3)
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
SKIP;
#endif
// no filter
CASE(test_flatten_default_axis)
// no filter
CASE(test_flatten_negative_axis1)
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
SKIP;
#endif
// no filter
CASE(test_flatten_negative_axis2)
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
SKIP;
#endif
// no filter
CASE(test_flatten_negative_axis3)
// no filter
CASE(test_flatten_negative_axis4)
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
SKIP;
#endif
// no filter
CASE(test_floor)
// no filter
CASE(test_floor_example)
......
......@@ -7,12 +7,6 @@
"test_castlike_FLOAT_to_STRING_expanded",
"test_castlike_STRING_to_FLOAT_expanded",
"test_concat_1d_axis_negative_1",
"test_flatten_axis0",
"test_flatten_axis2",
"test_flatten_axis3",
"test_flatten_negative_axis1",
"test_flatten_negative_axis2",
"test_flatten_negative_axis4",
"test_logsoftmax_default_axis",
"test_maxpool_2d_dilations",
"test_maxpool_2d_same_lower",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册