提交 5e80191d 编写于 作者: L Lubov Batanina 提交者: Alexander Alekhin

Merge pull request #14697 from l-bat:Slice_ONNX

* Support Slice layer in ONNX importer

* Add IE support

* Fix ONNX importer

* Fix Slice
上级 254f88f8
......@@ -174,16 +174,16 @@ public:
for (int i = 0; i < outputs.size(); ++i)
{
CV_Assert(sliceRanges[i].size() <= inpShape.dims());
// Clamp.
for (int j = 0; j < sliceRanges[i].size(); ++j)
{
sliceRanges[i][j] = clamp(sliceRanges[i][j], inpShape[j]);
}
// Fill the rest of ranges.
for (int j = sliceRanges[i].size(); j < inpShape.dims(); ++j)
{
sliceRanges[i].push_back(Range::all());
}
// Clamp.
for (int j = 0; j < sliceRanges[i].size(); ++j)
{
sliceRanges[i][j] = clamp(sliceRanges[i][j], inpShape[j]);
}
}
}
......
......@@ -401,6 +401,47 @@ void ONNXImporter::populateNet(Net dstNet)
layerParams.set("pool", layer_type == "GlobalAveragePool" ? "AVE" : "MAX");
layerParams.set("global_pooling", true);
}
else if (layer_type == "Slice")
{
if (layerParams.has("steps")) {
DictValue steps = layerParams.get("steps");
for (int i = 0; i < steps.size(); ++i) {
if (steps.get<int>(i) != 1)
CV_Error(Error::StsNotImplemented,
"Slice layer only supports steps = 1");
}
}
int axis = 0;
if (layerParams.has("axes")) {
DictValue axes = layerParams.get("axes");
for (int i = 1; i < axes.size(); ++i) {
CV_Assert(axes.get<int>(i - 1) == axes.get<int>(i) - 1);
}
axis = axes.get<int>(0);
}
layerParams.set("axis", axis);
DictValue starts = layerParams.get("starts");
DictValue ends = layerParams.get("ends");
CV_Assert(starts.size() == ends.size());
std::vector<int> begin;
std::vector<int> end;
if (axis > 0) {
begin.resize(axis, 0);
end.resize(axis, -1);
}
for (int i = 0; i < starts.size(); ++i)
{
begin.push_back(starts.get<int>(i));
int finish = ends.get<int>(i);
end.push_back((finish < 0) ? --finish : finish); // numpy doesn't include last dim
}
layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
}
else if (layer_type == "Add" || layer_type == "Sum")
{
if (layer_id.find(node_proto.input(1)) == layer_id.end())
......
......@@ -245,6 +245,11 @@ TEST_P(Test_ONNX_layers, Reshape)
testONNXModels("unsqueeze");
}
TEST_P(Test_ONNX_layers, Slice)
{
testONNXModels("slice");
}
TEST_P(Test_ONNX_layers, Softmax)
{
testONNXModels("softmax");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册