Unverified · Commit 2645ee90 authored by Liubov Batanina, committed by GitHub

Merge pull request #16735 from l-bat:flatten_const_onnx

* Supported Flatten for constant nodes

* Added default axis

* Refactoring

* Refactoring

* Added cast layer

* Fix comments

* Add Cast for layers
Parent 0e6ce501
--- a/modules/dnn/src/onnx/onnx_importer.cpp
+++ b/modules/dnn/src/onnx/onnx_importer.cpp
@@ -431,9 +431,20 @@ void ONNXImporter::populateNet(Net dstNet)
         {
             bool isSub = layer_type == "Sub";
             CV_CheckEQ(node_proto.input_size(), 2, "");
-            if (layer_id.find(node_proto.input(1)) == layer_id.end())
+            bool is_const_0 = layer_id.find(node_proto.input(0)) == layer_id.end();
+            bool is_const_1 = layer_id.find(node_proto.input(1)) == layer_id.end();
+            if (is_const_0 && is_const_1)
             {
-                Mat blob = getBlob(node_proto, constBlobs, 1);
+                Mat blob_0 = getBlob(node_proto, constBlobs, 0);
+                Mat blob_1 = getBlob(node_proto, constBlobs, 1);
+                CV_Assert(blob_0.size == blob_1.size);
+                Mat output = isSub ? (blob_0 - blob_1) : (blob_0 + blob_1);
+                constBlobs.insert(std::make_pair(layerParams.name, output));
+                continue;
+            }
+            else if (is_const_0 || is_const_1)
+            {
+                Mat blob = getBlob(node_proto, constBlobs, is_const_0 ? 0 : 1);
                 blob = blob.reshape(1, 1);
                 if (blob.total() == 1) {
                     layerParams.type = "Power";
@@ -808,6 +819,21 @@ void ONNXImporter::populateNet(Net dstNet)
             layerParams.set("end_axis", axis);
             layerParams.type = "Flatten";
         }
+        else if (layer_type == "Flatten")
+        {
+            CV_CheckEQ(node_proto.input_size(), 1, "");
+            if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
+            {
+                Mat input = getBlob(node_proto, constBlobs, 0);
+                int axis = clamp(layerParams.get<int>("axis", 1), input.dims);
+
+                std::vector<int> out_size(&input.size[0], &input.size[0] + axis);
+                out_size.push_back(input.total(axis));
+                Mat output = input.reshape(1, out_size);
+                constBlobs.insert(std::make_pair(layerParams.name, output));
+                continue;
+            }
+        }
         else if (layer_type == "Unsqueeze")
         {
             CV_Assert(node_proto.input_size() == 1);
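Note on the Flatten hunk above: for a constant input, the importer computes the flattened shape directly, keeping dimensions [0, axis) and collapsing the rest into one; a missing "axis" attribute defaults to 1. A hedged, standalone sketch of the same shape arithmetic, with illustrative names:

#include <opencv2/core.hpp>
#include <iostream>
#include <vector>

int main()
{
    // A constant 2x3x4 blob as a stand-in for getBlob()'s result.
    int dims[] = {2, 3, 4};
    cv::Mat input(3, dims, CV_32F, cv::Scalar(0));

    int axis = 1;  // the default when the ONNX attribute is absent

    // Keep dims [0, axis), then append the product of dims [axis, end).
    std::vector<int> out_size(&input.size[0], &input.size[0] + axis);
    out_size.push_back((int)input.total(axis));

    cv::Mat output = input.reshape(1, out_size);
    std::cout << output.size[0] << " x " << output.size[1] << std::endl;  // 2 x 12
    return 0;
}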
@@ -896,6 +922,31 @@ void ONNXImporter::populateNet(Net dstNet)
             constBlobs.insert(std::make_pair(layerParams.name, shapeMat));
             continue;
         }
+        else if (layer_type == "Cast")
+        {
+            if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
+            {
+                Mat blob = getBlob(node_proto, constBlobs, 0);
+                int type;
+                switch (layerParams.get<int>("to"))
+                {
+                    case opencv_onnx::TensorProto_DataType_FLOAT:   type = CV_32F; break;
+                    case opencv_onnx::TensorProto_DataType_UINT8:   type = CV_8U; break;
+                    case opencv_onnx::TensorProto_DataType_UINT16:  type = CV_16U; break;
+                    case opencv_onnx::TensorProto_DataType_FLOAT16: type = CV_16S; break;
+                    case opencv_onnx::TensorProto_DataType_INT8:
+                    case opencv_onnx::TensorProto_DataType_INT16:
+                    case opencv_onnx::TensorProto_DataType_INT32:
+                    case opencv_onnx::TensorProto_DataType_INT64:   type = CV_32S; break;
+                    default: type = blob.type();
+                }
+                blob.convertTo(blob, type);
+                constBlobs.insert(std::make_pair(layerParams.name, blob));
+                continue;
+            }
+            else
+                layerParams.type = "Identity";
+        }
         else if (layer_type == "Gather")
         {
             CV_Assert(node_proto.input_size() == 2);
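Note on the Cast hunk above: constant inputs are converted eagerly with Mat::convertTo, while non-constant inputs fall through to an Identity layer. OpenCV Mats (at the time of this change) have no 64-bit integer type, which is why the signed integer ONNX types all land on CV_32S. A minimal sketch of the mapping, using illustrative placeholder values for the opencv_onnx protobuf enum:

#include <opencv2/core.hpp>
#include <iostream>

// Placeholder values standing in for opencv_onnx::TensorProto_DataType_*.
enum OnnxDType { ONNX_FLOAT = 1, ONNX_UINT8 = 2, ONNX_INT64 = 7 };

int main()
{
    cv::Mat blob = (cv::Mat_<float>(1, 3) << 1.4f, 2.6f, -3.5f);

    int to = ONNX_INT64;  // the node's "to" attribute
    int type;
    switch (to)
    {
        case ONNX_FLOAT: type = CV_32F; break;
        case ONNX_UINT8: type = CV_8U;  break;
        case ONNX_INT64: type = CV_32S; break;  // no 64-bit integer Mat type
        default:         type = blob.type();
    }

    blob.convertTo(blob, type);       // saturating, round-to-nearest-even
    std::cout << blob << std::endl;   // [1, 3, -4]
    return 0;
}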
--- a/modules/dnn/test/test_onnx_importer.cpp
+++ b/modules/dnn/test/test_onnx_importer.cpp
@@ -187,6 +187,11 @@ TEST_P(Test_ONNX_layers, MaxPooling_Sigmoid)
     testONNXModels("maxpooling_sigmoid");
 }
 
+TEST_P(Test_ONNX_layers, Cast)
+{
+    testONNXModels("cast");
+}
+
 TEST_P(Test_ONNX_layers, Concatenation)
 {
     if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
@@ -377,6 +382,7 @@ TEST_P(Test_ONNX_layers, DynamicReshape)
     testONNXModels("dynamic_reshape");
     testONNXModels("dynamic_reshape_opset_11");
     testONNXModels("flatten_by_prod");
+    testONNXModels("flatten_const");
 }
 
 TEST_P(Test_ONNX_layers, Reshape)