提交 a6f9170f 编写于 作者: D Dmitry Kurtaev

Add ONNX's padding import

上级 850053f9
......@@ -96,6 +96,8 @@ Mat getMatFromTensor(opencv_onnx::TensorProto& tensor_proto)
for (int i = 0; i < tensor_proto.dims_size(); i++) {
sizes.push_back(tensor_proto.dims(i));
}
if (sizes.empty())
sizes.assign(1, 1);
if (datatype == opencv_onnx::TensorProto_DataType_FLOAT) {
if (!tensor_proto.float_data().empty()) {
......@@ -173,11 +175,31 @@ LayerParams ONNXImporter::getLayerParams(const opencv_onnx::NodeProto& node_prot
}
else if(attribute_name == "pads")
{
CV_Assert(attribute_proto.ints_size() == 4);
lp.set("pad_t", saturate_cast<int32_t>(attribute_proto.ints(0)));
lp.set("pad_l", saturate_cast<int32_t>(attribute_proto.ints(1)));
lp.set("pad_b", saturate_cast<int32_t>(attribute_proto.ints(2)));
lp.set("pad_r", saturate_cast<int32_t>(attribute_proto.ints(3)));
if (node_proto.op_type() == "Pad")
{
// Padding layer.
// Paddings are in order begin0, begin1, .. beginN, end0, end1, ..., endN.
// We need to shuffle it to begin0, end0, begin1, end1, ...
CV_Assert(attribute_proto.ints_size() % 2 == 0);
const int dims = attribute_proto.ints_size() / 2;
std::vector<int32_t> paddings;
paddings.reserve(attribute_proto.ints_size());
for (int i = 0; i < dims; ++i)
{
paddings.push_back(attribute_proto.ints(i));
paddings.push_back(attribute_proto.ints(dims + i));
}
lp.set("paddings", DictValue::arrayInt(&paddings[0], paddings.size()));
}
else
{
// Convolution or pooling.
CV_Assert(attribute_proto.ints_size() == 4);
lp.set("pad_t", saturate_cast<int32_t>(attribute_proto.ints(0)));
lp.set("pad_l", saturate_cast<int32_t>(attribute_proto.ints(1)));
lp.set("pad_b", saturate_cast<int32_t>(attribute_proto.ints(2)));
lp.set("pad_r", saturate_cast<int32_t>(attribute_proto.ints(3)));
}
}
else if(attribute_name == "auto_pad")
{
......@@ -543,6 +565,10 @@ void ONNXImporter::populateNet(Net dstNet)
replaceLayerParam(layerParams, "shape", "dim");
}
}
else if (layer_type == "Pad")
{
layerParams.type = "Padding";
}
else
{
for (int j = 0; j < node_proto.input_size(); j++) {
......
......@@ -129,6 +129,11 @@ TEST_P(Test_ONNX_layers, Constant)
testONNXModels("constant");
}
TEST_P(Test_ONNX_layers, Padding)
{
testONNXModels("padding");
}
TEST_P(Test_ONNX_layers, MultyInputs)
{
const String model = _tf("models/multy_inputs.onnx");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册