Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
Opencv
提交
a67228cd
O
Opencv
项目概览
Greenplum
/
Opencv
大约 1 年 前同步成功
通知
7
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
Opencv
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
a67228cd
编写于
1月 14, 2020
作者:
A
Alexander Alekhin
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #16291 from dkurt:dnn_onnx_graph_simplifier
上级
f4daf14b
c1c84d2f
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
565 addition
and
192 deletion
+565
-192
modules/dnn/src/graph_simplifier.cpp
modules/dnn/src/graph_simplifier.cpp
+207
-0
modules/dnn/src/graph_simplifier.hpp
modules/dnn/src/graph_simplifier.hpp
+100
-0
modules/dnn/src/onnx/onnx_graph_simplifier.cpp
modules/dnn/src/onnx/onnx_graph_simplifier.cpp
+157
-0
modules/dnn/src/onnx/onnx_graph_simplifier.hpp
modules/dnn/src/onnx/onnx_graph_simplifier.hpp
+30
-0
modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/src/onnx/onnx_importer.cpp
+5
-0
modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
+65
-192
modules/dnn/test/test_onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp
+1
-0
未找到文件。
modules/dnn/src/graph_simplifier.cpp
0 → 100644
浏览文件 @
a67228cd
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
// Copyright (C) 2020, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
#include "precomp.hpp"
#include "graph_simplifier.hpp"
#include <queue>
namespace
cv
{
namespace
dnn
{
Subgraph
::~
Subgraph
()
{}
int
Subgraph
::
addNodeToMatch
(
const
std
::
string
&
op
,
int
input_0
,
int
input_1
,
int
input_2
,
int
input_3
)
{
int
nodeInputs
[]
=
{
input_0
,
input_1
,
input_2
,
input_3
};
int
numInputs
=
0
;
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
numInputs
+=
(
int
)(
nodeInputs
[
i
]
!=
-
1
);
}
return
addNodeToMatch
(
op
,
std
::
vector
<
int
>
(
&
nodeInputs
[
0
],
&
nodeInputs
[
0
]
+
numInputs
));
}
int
Subgraph
::
addNodeToMatch
(
const
std
::
string
&
op
,
const
std
::
vector
<
int
>&
inputs_
)
{
for
(
int
i
=
0
;
i
<
inputs_
.
size
();
++
i
)
{
CV_Assert
(
inputs_
[
i
]
<
(
int
)
nodes
.
size
());
}
nodes
.
push_back
(
op
);
inputs
.
push_back
(
inputs_
);
return
nodes
.
size
()
-
1
;
}
void
Subgraph
::
setFusedNode
(
const
std
::
string
&
op
,
int
input_0
,
int
input_1
,
int
input_2
,
int
input_3
,
int
input_4
,
int
input_5
)
{
int
nodeInputs
[]
=
{
input_0
,
input_1
,
input_2
,
input_3
,
input_4
,
input_5
};
int
numInputs
=
0
;
for
(
int
i
=
0
;
i
<
6
;
++
i
)
{
CV_Assert
(
nodeInputs
[
i
]
<
(
int
)
nodes
.
size
());
numInputs
+=
(
int
)(
nodeInputs
[
i
]
!=
-
1
);
}
setFusedNode
(
op
,
std
::
vector
<
int
>
(
&
nodeInputs
[
0
],
&
nodeInputs
[
0
]
+
numInputs
));
}
void
Subgraph
::
setFusedNode
(
const
std
::
string
&
op
,
const
std
::
vector
<
int
>&
inputs_
)
{
fusedNodeInputs
=
inputs_
;
fusedNodeOp
=
op
;
}
int
Subgraph
::
getInputNodeId
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
Ptr
<
ImportNodeWrapper
>&
node
,
int
inpId
)
{
CV_Assert
(
inpId
<
node
->
getNumInputs
());
std
::
string
name
=
node
->
getInputName
(
inpId
);
// If operation produces several tensors, they are specified by index
// after ':' character. In example, "input:0".
name
=
name
.
substr
(
0
,
name
.
rfind
(
':'
));
const
int
numNodes
=
net
->
getNumNodes
();
for
(
int
i
=
0
;
i
<
numNodes
;
++
i
)
{
if
(
net
->
getNodeName
(
i
)
==
name
)
return
i
;
}
CV_Error
(
Error
::
StsParseError
,
"Input node with name "
+
name
+
" not found"
);
}
bool
Subgraph
::
match
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
int
nodeId
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
targetNodesIds
)
{
matchedNodesIds
.
clear
();
targetNodesIds
.
clear
();
std
::
queue
<
int
>
nodesToMatch
;
std
::
queue
<
int
>
targetNodes
;
nodesToMatch
.
push
(
nodeId
);
targetNodes
.
push
(
nodes
.
size
()
-
1
);
while
(
!
nodesToMatch
.
empty
())
{
int
nodeToMatch
=
nodesToMatch
.
front
();
int
targetNodeId
=
targetNodes
.
front
();
nodesToMatch
.
pop
();
targetNodes
.
pop
();
if
(
std
::
find
(
matchedNodesIds
.
begin
(),
matchedNodesIds
.
end
(),
nodeToMatch
)
!=
matchedNodesIds
.
end
())
continue
;
const
Ptr
<
ImportNodeWrapper
>
node
=
net
->
getNode
(
nodeToMatch
);
if
(
node
->
getType
()
!=
nodes
[
targetNodeId
])
return
false
;
std
::
vector
<
int
>&
inputNodes
=
inputs
[
targetNodeId
];
if
(
inputNodes
.
size
()
!=
node
->
getNumInputs
())
return
false
;
for
(
int
j
=
0
;
j
<
inputNodes
.
size
();
++
j
)
{
if
(
nodes
[
inputNodes
[
j
]].
empty
())
// Unknown input node type.
continue
;
nodeId
=
getInputNodeId
(
net
,
node
,
j
);
const
Ptr
<
ImportNodeWrapper
>
inpNode
=
net
->
getNode
(
nodeId
);
if
(
inpNode
->
getType
()
!=
"Const"
)
{
nodesToMatch
.
push
(
nodeId
);
targetNodes
.
push
(
inputNodes
[
j
]);
}
else
if
(
nodes
[
inputNodes
[
j
]]
!=
"Const"
)
return
false
;
}
matchedNodesIds
.
push_back
(
nodeToMatch
);
targetNodesIds
.
push_back
(
targetNodeId
);
}
const
int
n
=
matchedNodesIds
.
size
();
std
::
vector
<
std
::
pair
<
int
,
int
>
>
elements
(
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
elements
[
i
]
=
std
::
make_pair
(
matchedNodesIds
[
i
],
targetNodesIds
[
i
]);
std
::
sort
(
elements
.
begin
(),
elements
.
end
());
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
matchedNodesIds
[
i
]
=
elements
[
i
].
first
;
targetNodesIds
[
i
]
=
elements
[
i
].
second
;
}
return
true
;
}
void
Subgraph
::
replace
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
std
::
vector
<
int
>&
matchedNodesIds
,
const
std
::
vector
<
int
>&
targetNodesIds
)
{
// Extract names of input nodes.
std
::
vector
<
std
::
string
>
inputsNames
(
fusedNodeInputs
.
size
());
for
(
int
i
=
0
;
i
<
fusedNodeInputs
.
size
();
++
i
)
{
std
::
string
inpName
;
// Find input node name looking at inputs of fused nodes.
for
(
int
j
=
0
;
j
<
matchedNodesIds
.
size
()
&&
inpName
.
empty
();
++
j
)
{
Ptr
<
ImportNodeWrapper
>
node
=
net
->
getNode
(
matchedNodesIds
[
j
]);
std
::
vector
<
int
>&
inpIndices
=
inputs
[
targetNodesIds
[
j
]];
CV_Assert
(
node
->
getNumInputs
()
==
inpIndices
.
size
());
for
(
int
k
=
0
;
k
<
inpIndices
.
size
();
++
k
)
{
if
(
inpIndices
[
k
]
==
fusedNodeInputs
[
i
])
{
inpName
=
node
->
getInputName
(
k
);
break
;
}
}
}
CV_Assert
(
!
inpName
.
empty
());
inputsNames
[
i
]
=
inpName
;
}
// Remove matched nodes except the last one. Indices in ascending order are expected.
Ptr
<
ImportNodeWrapper
>
node
=
net
->
getNode
(
matchedNodesIds
.
back
());
for
(
int
i
=
matchedNodesIds
.
size
()
-
2
;
i
>=
0
;
--
i
)
net
->
removeNode
(
matchedNodesIds
[
i
]);
// Modify the last node to be a fused one.
node
->
setType
(
fusedNodeOp
);
node
->
setInputNames
(
inputsNames
);
std
::
vector
<
Ptr
<
ImportNodeWrapper
>
>
inputNodes
(
inputsNames
.
size
());
for
(
int
i
=
0
;
i
<
inputsNames
.
size
();
++
i
)
{
inputNodes
[
i
]
=
net
->
getNode
(
getInputNodeId
(
net
,
node
,
i
));
}
finalize
(
net
,
node
,
inputNodes
);
}
void
Subgraph
::
finalize
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
Ptr
<
ImportNodeWrapper
>&
fusedNode
,
std
::
vector
<
Ptr
<
ImportNodeWrapper
>
>&
inputs
)
{}
void
simplifySubgraphs
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
std
::
vector
<
Ptr
<
Subgraph
>
>&
patterns
)
{
int
numNodes
=
net
->
getNumNodes
();
std
::
vector
<
int
>
matchedNodesIds
,
targetNodesIds
;
for
(
int
i
=
0
;
i
<
numNodes
;
++
i
)
{
for
(
int
j
=
0
;
j
<
patterns
.
size
();
++
j
)
{
if
(
patterns
[
j
]
->
match
(
net
,
i
,
matchedNodesIds
,
targetNodesIds
))
{
patterns
[
j
]
->
replace
(
net
,
matchedNodesIds
,
targetNodesIds
);
numNodes
-=
matchedNodesIds
.
size
()
-
1
;
// #matchedNodes removed and one added.
break
;
}
}
}
}
}}
// namespace cv::dnn
modules/dnn/src/graph_simplifier.hpp
0 → 100644
浏览文件 @
a67228cd
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
// Copyright (C) 2020, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
#ifndef __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__
#define __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__
#include <string>
#include <opencv2/core.hpp>
namespace
cv
{
namespace
dnn
{
class
ImportNodeWrapper
{
public:
virtual
~
ImportNodeWrapper
()
{};
virtual
int
getNumInputs
()
const
=
0
;
virtual
std
::
string
getInputName
(
int
idx
)
const
=
0
;
virtual
std
::
string
getType
()
const
=
0
;
virtual
void
setType
(
const
std
::
string
&
type
)
=
0
;
virtual
void
setInputNames
(
const
std
::
vector
<
std
::
string
>&
inputs
)
=
0
;
};
class
ImportGraphWrapper
{
public:
virtual
~
ImportGraphWrapper
()
{};
virtual
Ptr
<
ImportNodeWrapper
>
getNode
(
int
idx
)
const
=
0
;
virtual
int
getNumNodes
()
const
=
0
;
virtual
std
::
string
getNodeName
(
int
idx
)
const
=
0
;
virtual
void
removeNode
(
int
idx
)
=
0
;
};
class
Subgraph
// Interface to match and replace subgraphs.
{
public:
virtual
~
Subgraph
();
// Add a node to be matched in the origin graph. Specify ids of nodes that
// are expected to be inputs. Returns id of a newly added node.
// TODO: Replace inputs to std::vector<int> in C++11
int
addNodeToMatch
(
const
std
::
string
&
op
,
int
input_0
=
-
1
,
int
input_1
=
-
1
,
int
input_2
=
-
1
,
int
input_3
=
-
1
);
int
addNodeToMatch
(
const
std
::
string
&
op
,
const
std
::
vector
<
int
>&
inputs_
);
// Specify resulting node. All the matched nodes in subgraph excluding
// input nodes will be fused into this single node.
// TODO: Replace inputs to std::vector<int> in C++11
void
setFusedNode
(
const
std
::
string
&
op
,
int
input_0
=
-
1
,
int
input_1
=
-
1
,
int
input_2
=
-
1
,
int
input_3
=
-
1
,
int
input_4
=
-
1
,
int
input_5
=
-
1
);
void
setFusedNode
(
const
std
::
string
&
op
,
const
std
::
vector
<
int
>&
inputs_
);
static
int
getInputNodeId
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
Ptr
<
ImportNodeWrapper
>&
node
,
int
inpId
);
// Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused.
// Const nodes are skipped during matching. Returns true if nodes are matched and can be fused.
virtual
bool
match
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
int
nodeId
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
targetNodesIds
);
// Fuse matched subgraph.
void
replace
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
std
::
vector
<
int
>&
matchedNodesIds
,
const
std
::
vector
<
int
>&
targetNodesIds
);
virtual
void
finalize
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
Ptr
<
ImportNodeWrapper
>&
fusedNode
,
std
::
vector
<
Ptr
<
ImportNodeWrapper
>
>&
inputs
);
private:
std
::
vector
<
std
::
string
>
nodes
;
// Nodes to be matched in the origin graph.
std
::
vector
<
std
::
vector
<
int
>
>
inputs
;
// Connections of an every node to it's inputs.
std
::
string
fusedNodeOp
;
// Operation name of resulting fused node.
std
::
vector
<
int
>
fusedNodeInputs
;
// Inputs of fused node.
};
void
simplifySubgraphs
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
std
::
vector
<
Ptr
<
Subgraph
>
>&
patterns
);
}}
// namespace dnn, namespace cv
#endif // __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__
modules/dnn/src/onnx/onnx_graph_simplifier.cpp
0 → 100644
浏览文件 @
a67228cd
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
// Copyright (C) 2020, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
#include "../precomp.hpp"
#include "../graph_simplifier.hpp"
#include "onnx_graph_simplifier.hpp"
#include <queue>
namespace
cv
{
namespace
dnn
{
CV__DNN_EXPERIMENTAL_NS_BEGIN
// This wrapper can behave differently for fake input nodes and real graph nodes.
class
ONNXNodeWrapper
:
public
ImportNodeWrapper
{
public:
ONNXNodeWrapper
(
opencv_onnx
::
NodeProto
*
_node
=
0
)
:
node
(
_node
)
{}
virtual
int
getNumInputs
()
const
CV_OVERRIDE
{
return
node
?
node
->
input_size
()
:
0
;
}
virtual
std
::
string
getInputName
(
int
idx
)
const
CV_OVERRIDE
{
CV_Assert_N
(
node
,
idx
<
node
->
input_size
());
return
node
->
input
(
idx
);
}
virtual
std
::
string
getType
()
const
CV_OVERRIDE
{
return
node
?
node
->
op_type
()
:
""
;
}
virtual
void
setType
(
const
std
::
string
&
type
)
CV_OVERRIDE
{
CV_Assert
(
node
);
node
->
set_op_type
(
type
);
}
virtual
void
setInputNames
(
const
std
::
vector
<
std
::
string
>&
inputs
)
CV_OVERRIDE
{
CV_Assert
(
node
);
node
->
clear_input
();
for
(
int
i
=
0
;
i
<
inputs
.
size
();
++
i
)
node
->
add_input
(
inputs
[
i
]);
}
opencv_onnx
::
NodeProto
*
node
;
};
// ONNX graph's inputs are separate from nodes so we index them before the rest of nodes.
class
ONNXGraphWrapper
:
public
ImportGraphWrapper
{
public:
ONNXGraphWrapper
(
opencv_onnx
::
GraphProto
&
_net
)
:
net
(
_net
)
{
numInputs
=
net
.
input_size
();
}
virtual
Ptr
<
ImportNodeWrapper
>
getNode
(
int
idx
)
const
CV_OVERRIDE
{
opencv_onnx
::
NodeProto
*
node
=
0
;
if
(
idx
>=
numInputs
)
node
=
net
.
mutable_node
(
idx
-
numInputs
);
return
makePtr
<
ONNXNodeWrapper
>
(
node
);
}
virtual
int
getNumNodes
()
const
CV_OVERRIDE
{
return
numInputs
+
net
.
node_size
();
}
virtual
std
::
string
getNodeName
(
int
idx
)
const
CV_OVERRIDE
{
if
(
idx
<
numInputs
)
return
net
.
input
(
idx
).
name
();
else
return
net
.
node
(
idx
-
numInputs
).
output
(
0
);
}
virtual
void
removeNode
(
int
idx
)
CV_OVERRIDE
{
CV_Assert
(
idx
>=
numInputs
);
net
.
mutable_node
()
->
DeleteSubrange
(
idx
-
numInputs
,
1
);
}
private:
int
numInputs
;
opencv_onnx
::
GraphProto
&
net
;
};
class
SoftMaxSubgraph
:
public
Subgraph
{
public:
SoftMaxSubgraph
()
{
int
input
=
addNodeToMatch
(
""
);
int
inpExp
=
addNodeToMatch
(
"Exp"
,
input
);
int
sum
=
addNodeToMatch
(
"ReduceSum"
,
inpExp
);
addNodeToMatch
(
"Div"
,
inpExp
,
sum
);
setFusedNode
(
"Softmax"
,
input
);
}
virtual
bool
match
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
int
nodeId
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
targetNodesIds
)
CV_OVERRIDE
{
if
(
Subgraph
::
match
(
net
,
nodeId
,
matchedNodesIds
,
targetNodesIds
))
{
Ptr
<
ImportNodeWrapper
>
sum
=
net
->
getNode
(
matchedNodesIds
[
1
]);
opencv_onnx
::
NodeProto
*
node
=
sum
.
dynamicCast
<
ONNXNodeWrapper
>
()
->
node
;
for
(
int
i
=
0
;
i
<
node
->
attribute_size
();
i
++
)
{
opencv_onnx
::
AttributeProto
attr
=
node
->
attribute
(
i
);
if
(
attr
.
name
()
!=
"axes"
)
continue
;
if
(
attr
.
ints_size
()
!=
1
)
CV_Error
(
Error
::
StsNotImplemented
,
format
(
"Unexpected number of axes: %d"
,
attr
.
ints_size
()));
axis
=
attr
.
ints
(
0
);
return
true
;
}
CV_Error
(
Error
::
StsNotImplemented
,
"Missed axes attribute"
);
}
return
false
;
}
virtual
void
finalize
(
const
Ptr
<
ImportGraphWrapper
>&
,
const
Ptr
<
ImportNodeWrapper
>&
fusedNode
,
std
::
vector
<
Ptr
<
ImportNodeWrapper
>
>&
)
CV_OVERRIDE
{
opencv_onnx
::
NodeProto
*
node
=
fusedNode
.
dynamicCast
<
ONNXNodeWrapper
>
()
->
node
;
opencv_onnx
::
AttributeProto
*
attr
=
node
->
add_attribute
();
attr
->
set_name
(
"axis"
);
attr
->
set_i
(
axis
);
}
private:
int
axis
;
};
void
simplifySubgraphs
(
opencv_onnx
::
GraphProto
&
net
)
{
std
::
vector
<
Ptr
<
Subgraph
>
>
subgraphs
;
subgraphs
.
push_back
(
makePtr
<
SoftMaxSubgraph
>
());
simplifySubgraphs
(
Ptr
<
ImportGraphWrapper
>
(
new
ONNXGraphWrapper
(
net
)),
subgraphs
);
}
CV__DNN_EXPERIMENTAL_NS_END
}}
// namespace cv::dnn
modules/dnn/src/onnx/onnx_graph_simplifier.hpp
0 → 100644
浏览文件 @
a67228cd
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
// Copyright (C) 2020, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
#ifndef __OPENCV_DNN_ONNX_SIMPLIFIER_HPP__
#define __OPENCV_DNN_ONNX_SIMPLIFIER_HPP__
#include "../precomp.hpp"
#if defined(__GNUC__) && __GNUC__ >= 5
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wsuggest-override"
#endif
#include "opencv-onnx.pb.h"
#if defined(__GNUC__) && __GNUC__ >= 5
#pragma GCC diagnostic pop
#endif
namespace
cv
{
namespace
dnn
{
CV__DNN_EXPERIMENTAL_NS_BEGIN
void
simplifySubgraphs
(
opencv_onnx
::
GraphProto
&
net
);
CV__DNN_EXPERIMENTAL_NS_END
}}
// namespace dnn, namespace cv
#endif // __OPENCV_DNN_ONNX_SIMPLIFIER_HPP__
modules/dnn/src/onnx/onnx_importer.cpp
浏览文件 @
a67228cd
...
...
@@ -26,6 +26,8 @@
#pragma GCC diagnostic pop
#endif
#include "onnx_graph_simplifier.hpp"
namespace
cv
{
namespace
dnn
{
CV__DNN_EXPERIMENTAL_NS_BEGIN
...
...
@@ -326,6 +328,9 @@ void ONNXImporter::populateNet(Net dstNet)
{
CV_Assert
(
model_proto
.
has_graph
());
opencv_onnx
::
GraphProto
graph_proto
=
model_proto
.
graph
();
simplifySubgraphs
(
graph_proto
);
std
::
map
<
std
::
string
,
Mat
>
constBlobs
=
getGraphTensors
(
graph_proto
);
// List of internal blobs shapes.
std
::
map
<
std
::
string
,
MatShape
>
outShapes
;
...
...
modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
浏览文件 @
a67228cd
...
...
@@ -9,6 +9,7 @@
#ifdef HAVE_PROTOBUF
#include "../graph_simplifier.hpp"
#include "tf_graph_simplifier.hpp"
#include <queue>
...
...
@@ -18,203 +19,87 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
using
::
google
::
protobuf
::
RepeatedField
;
using
::
google
::
protobuf
::
MapPair
;
class
Subgraph
// Interface to match and replace TensorFlow subgraphs.
class
TFNodeWrapper
:
public
ImportNodeWrapper
{
public:
virtual
~
Subgraph
(
)
{}
TFNodeWrapper
(
tensorflow
::
NodeDef
*
_node
)
:
node
(
_node
)
{}
// Add a node to be matched in the origin graph. Specify ids of nodes that
// are expected to be inputs. Returns id of a newly added node.
// TODO: Replace inputs to std::vector<int> in C++11
int
addNodeToMatch
(
const
std
::
string
&
op
,
int
input_0
=
-
1
,
int
input_1
=
-
1
,
int
input_2
=
-
1
,
int
input_3
=
-
1
)
virtual
int
getNumInputs
()
const
CV_OVERRIDE
{
int
nodeInputs
[]
=
{
input_0
,
input_1
,
input_2
,
input_3
};
int
numInputs
=
0
;
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
numInputs
+=
(
int
)(
nodeInputs
[
i
]
!=
-
1
);
}
return
addNodeToMatch
(
op
,
std
::
vector
<
int
>
(
&
nodeInputs
[
0
],
&
nodeInputs
[
0
]
+
numInputs
));
return
node
->
input_size
();
}
int
addNodeToMatch
(
const
std
::
string
&
op
,
const
std
::
vector
<
int
>&
inputs_
)
virtual
std
::
string
getInputName
(
int
idx
)
const
CV_OVERRIDE
{
for
(
int
i
=
0
;
i
<
inputs_
.
size
();
++
i
)
{
CV_Assert
(
inputs_
[
i
]
<
(
int
)
nodes
.
size
());
}
nodes
.
push_back
(
op
);
inputs
.
push_back
(
inputs_
);
return
nodes
.
size
()
-
1
;
return
node
->
input
(
idx
);
}
// Specify resulting node. All the matched nodes in subgraph excluding
// input nodes will be fused into this single node.
// TODO: Replace inputs to std::vector<int> in C++11
void
setFusedNode
(
const
std
::
string
&
op
,
int
input_0
=
-
1
,
int
input_1
=
-
1
,
int
input_2
=
-
1
,
int
input_3
=
-
1
,
int
input_4
=
-
1
,
int
input_5
=
-
1
)
virtual
std
::
string
getType
()
const
CV_OVERRIDE
{
int
nodeInputs
[]
=
{
input_0
,
input_1
,
input_2
,
input_3
,
input_4
,
input_5
};
int
numInputs
=
0
;
for
(
int
i
=
0
;
i
<
6
;
++
i
)
{
CV_Assert
(
nodeInputs
[
i
]
<
(
int
)
nodes
.
size
());
numInputs
+=
(
int
)(
nodeInputs
[
i
]
!=
-
1
);
}
setFusedNode
(
op
,
std
::
vector
<
int
>
(
&
nodeInputs
[
0
],
&
nodeInputs
[
0
]
+
numInputs
));
return
node
->
op
();
}
v
oid
setFusedNode
(
const
std
::
string
&
op
,
const
std
::
vector
<
int
>&
inputs_
)
v
irtual
void
setType
(
const
std
::
string
&
type
)
CV_OVERRIDE
{
fusedNodeInputs
=
inputs_
;
fusedNodeOp
=
op
;
node
->
set_op
(
type
);
}
static
int
getInputNodeId
(
const
tensorflow
::
GraphDef
&
net
,
const
tensorflow
::
NodeDef
&
node
,
int
inpId
)
virtual
void
setInputNames
(
const
std
::
vector
<
std
::
string
>&
inputs
)
CV_OVERRIDE
{
CV_Assert
(
inpId
<
node
.
input_size
());
std
::
string
name
=
node
.
input
(
inpId
);
// If operation produces several tensors, they are specified by index
// after ':' character. In example, "input:0".
name
=
name
.
substr
(
0
,
name
.
rfind
(
':'
));
const
int
numNodes
=
net
.
node_size
();
for
(
int
i
=
0
;
i
<
numNodes
;
++
i
)
{
if
(
net
.
node
(
i
).
name
()
==
name
)
return
i
;
}
CV_Error
(
Error
::
StsParseError
,
"Input node with name "
+
name
+
" not found"
);
node
->
clear_input
();
for
(
int
i
=
0
;
i
<
inputs
.
size
();
++
i
)
node
->
add_input
(
inputs
[
i
]);
}
// Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused.
// Const nodes are skipped during matching. Returns true if nodes are matched and can be fused.
virtual
bool
match
(
const
tensorflow
::
GraphDef
&
net
,
int
nodeId
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
targetNodesIds
)
{
matchedNodesIds
.
clear
();
targetNodesIds
.
clear
();
std
::
queue
<
int
>
nodesToMatch
;
std
::
queue
<
int
>
targetNodes
;
nodesToMatch
.
push
(
nodeId
);
targetNodes
.
push
(
nodes
.
size
()
-
1
);
while
(
!
nodesToMatch
.
empty
())
{
int
nodeToMatch
=
nodesToMatch
.
front
();
int
targetNodeId
=
targetNodes
.
front
();
nodesToMatch
.
pop
();
targetNodes
.
pop
();
if
(
std
::
find
(
matchedNodesIds
.
begin
(),
matchedNodesIds
.
end
(),
nodeToMatch
)
!=
matchedNodesIds
.
end
())
continue
;
const
tensorflow
::
NodeDef
&
node
=
net
.
node
(
nodeToMatch
);
if
(
node
.
op
()
!=
nodes
[
targetNodeId
])
return
false
;
std
::
vector
<
int
>&
inputNodes
=
inputs
[
targetNodeId
];
if
(
inputNodes
.
size
()
!=
node
.
input_size
())
return
false
;
tensorflow
::
NodeDef
*
node
;
};
for
(
int
j
=
0
;
j
<
inputNodes
.
size
();
++
j
)
{
if
(
nodes
[
inputNodes
[
j
]].
empty
())
// Unknown input node type.
continue
;
nodeId
=
getInputNodeId
(
net
,
node
,
j
);
const
tensorflow
::
NodeDef
&
inpNode
=
net
.
node
(
nodeId
);
if
(
inpNode
.
op
()
!=
"Const"
)
{
nodesToMatch
.
push
(
nodeId
);
targetNodes
.
push
(
inputNodes
[
j
]);
}
else
if
(
nodes
[
inputNodes
[
j
]]
!=
"Const"
)
return
false
;
}
matchedNodesIds
.
push_back
(
nodeToMatch
);
targetNodesIds
.
push_back
(
targetNodeId
);
}
class
TFGraphWrapper
:
public
ImportGraphWrapper
{
public:
TFGraphWrapper
(
tensorflow
::
GraphDef
&
_net
)
:
net
(
_net
)
{}
const
int
n
=
matchedNodesIds
.
size
();
std
::
vector
<
std
::
pair
<
int
,
int
>
>
elements
(
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
elements
[
i
]
=
std
::
make_pair
(
matchedNodesIds
[
i
],
targetNodesIds
[
i
]);
std
::
sort
(
elements
.
begin
(),
elements
.
end
());
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
matchedNodesIds
[
i
]
=
elements
[
i
].
first
;
targetNodesIds
[
i
]
=
elements
[
i
].
second
;
}
return
true
;
virtual
Ptr
<
ImportNodeWrapper
>
getNode
(
int
idx
)
const
CV_OVERRIDE
{
return
makePtr
<
TFNodeWrapper
>
(
net
.
mutable_node
(
idx
));
}
// Fuse matched subgraph.
void
replace
(
tensorflow
::
GraphDef
&
net
,
const
std
::
vector
<
int
>&
matchedNodesIds
,
const
std
::
vector
<
int
>&
targetNodesIds
)
virtual
int
getNumNodes
()
const
CV_OVERRIDE
{
// Extract names of input nodes.
std
::
vector
<
std
::
string
>
inputsNames
(
fusedNodeInputs
.
size
());
for
(
int
i
=
0
;
i
<
fusedNodeInputs
.
size
();
++
i
)
{
std
::
string
inpName
;
// Find input node name looking at inputs of fused nodes.
for
(
int
j
=
0
;
j
<
matchedNodesIds
.
size
()
&&
inpName
.
empty
();
++
j
)
{
const
tensorflow
::
NodeDef
&
node
=
net
.
node
(
matchedNodesIds
[
j
]);
std
::
vector
<
int
>&
inpIndices
=
inputs
[
targetNodesIds
[
j
]];
CV_Assert
(
node
.
input_size
()
==
inpIndices
.
size
());
for
(
int
k
=
0
;
k
<
inpIndices
.
size
();
++
k
)
{
if
(
inpIndices
[
k
]
==
fusedNodeInputs
[
i
])
{
inpName
=
node
.
input
(
k
);
break
;
}
}
}
CV_Assert
(
!
inpName
.
empty
());
inputsNames
[
i
]
=
inpName
;
}
// Remove matched nodes except the last one. Indices in ascending order are expected.
tensorflow
::
NodeDef
*
node
=
net
.
mutable_node
(
matchedNodesIds
.
back
());
for
(
int
i
=
matchedNodesIds
.
size
()
-
2
;
i
>=
0
;
--
i
)
net
.
mutable_node
()
->
DeleteSubrange
(
matchedNodesIds
[
i
],
1
);
return
net
.
node_size
();
}
// Modify the last node to be a fused one.
node
->
set_op
(
fusedNodeOp
);
node
->
clear_input
();
for
(
int
i
=
0
;
i
<
inputsNames
.
size
();
++
i
)
{
node
->
add_input
(
inputsNames
[
i
]);
}
virtual
std
::
string
getNodeName
(
int
idx
)
const
CV_OVERRIDE
{
return
net
.
node
(
idx
).
name
();
}
std
::
vector
<
tensorflow
::
NodeDef
*>
inputNodes
(
inputsNames
.
size
());
for
(
int
i
=
0
;
i
<
inputsNames
.
size
();
++
i
)
{
inputNodes
[
i
]
=
net
.
mutable_node
(
getInputNodeId
(
net
,
*
node
,
i
));
}
finalize
(
net
,
node
,
inputNodes
);
virtual
void
removeNode
(
int
idx
)
CV_OVERRIDE
{
net
.
mutable_node
()
->
DeleteSubrange
(
idx
,
1
);
}
virtual
void
finalize
(
tensorflow
::
GraphDef
&
,
tensorflow
::
NodeDef
*
,
std
::
vector
<
tensorflow
::
NodeDef
*>&
)
{}
tensorflow
::
GraphDef
&
net
;
};
private:
std
::
vector
<
std
::
string
>
nodes
;
// Nodes to be matched in the origin graph.
std
::
vector
<
std
::
vector
<
int
>
>
inputs
;
// Connections of an every node to it's inputs.
class
TFSubgraph
:
public
Subgraph
{
virtual
void
finalize
(
const
Ptr
<
ImportGraphWrapper
>&
netWrapper
,
const
Ptr
<
ImportNodeWrapper
>&
fusedNodeWrapper
,
std
::
vector
<
Ptr
<
ImportNodeWrapper
>
>&
inputs
)
CV_OVERRIDE
{
std
::
vector
<
tensorflow
::
NodeDef
*>
inputNodes
(
inputs
.
size
());
for
(
int
i
=
0
;
i
<
inputs
.
size
();
++
i
)
inputNodes
[
i
]
=
inputs
[
i
].
dynamicCast
<
TFNodeWrapper
>
()
->
node
;
finalize
(
netWrapper
.
dynamicCast
<
TFGraphWrapper
>
()
->
net
,
fusedNodeWrapper
.
dynamicCast
<
TFNodeWrapper
>
()
->
node
,
inputNodes
);
}
std
::
string
fusedNodeOp
;
// Operation name of resulting fused node.
std
::
vector
<
int
>
fusedNodeInputs
;
// Inputs of fused node.
virtual
void
finalize
(
tensorflow
::
GraphDef
&
,
tensorflow
::
NodeDef
*
fusedNode
,
std
::
vector
<
tensorflow
::
NodeDef
*>&
inputNodes
)
{}
};
class
BatchNormSubgraph
:
public
Subgraph
class
BatchNormSubgraph
:
public
TF
Subgraph
{
public:
BatchNormSubgraph
()
...
...
@@ -250,7 +135,7 @@ public:
}
};
class
BatchNormNoGammaSubgraph
:
public
Subgraph
class
BatchNormNoGammaSubgraph
:
public
TF
Subgraph
{
public:
BatchNormNoGammaSubgraph
()
...
...
@@ -366,20 +251,21 @@ public:
setFusedNode
(
"Relu6"
,
input
);
}
virtual
bool
match
(
const
tensorflow
::
GraphDef
&
net
,
int
nodeId
,
virtual
bool
match
(
const
Ptr
<
ImportGraphWrapper
>
&
net
,
int
nodeId
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
targetNodesIds
)
CV_OVERRIDE
{
if
(
!
Subgraph
::
match
(
net
,
nodeId
,
matchedNodesIds
,
targetNodesIds
))
return
false
;
Mat
maxValue
=
getTensorContent
(
net
.
node
(
matchedNodesIds
.
front
()
+
1
).
attr
().
at
(
"value"
).
tensor
());
tensorflow
::
NodeDef
*
node
=
net
->
getNode
(
matchedNodesIds
.
front
()
+
1
).
dynamicCast
<
TFNodeWrapper
>
()
->
node
;
Mat
maxValue
=
getTensorContent
(
node
->
attr
().
at
(
"value"
).
tensor
());
return
maxValue
.
type
()
==
CV_32FC1
&&
maxValue
.
total
()
==
1
&&
maxValue
.
at
<
float
>
(
0
)
==
6
;
}
};
// Keras' reshape stores output shape in separate Const nodes by one value.
// Need to merge them into a single Const node.
class
ReshapeKerasSubgraph
:
public
Subgraph
class
ReshapeKerasSubgraph
:
public
TF
Subgraph
{
public:
ReshapeKerasSubgraph
(
int
_numOutDims
)
:
numOutDims
(
_numOutDims
)
...
...
@@ -402,15 +288,15 @@ public:
setFusedNode
(
"Reshape"
,
ids
);
}
virtual
bool
match
(
const
tensorflow
::
GraphDef
&
net
,
int
nodeId
,
virtual
bool
match
(
const
Ptr
<
ImportGraphWrapper
>
&
net
,
int
nodeId
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
targetNodesIds
)
CV_OVERRIDE
{
const
tensorflow
::
NodeDef
&
node
=
net
.
n
ode
(
nodeId
);
if
(
node
.
input_size
()
==
0
)
Ptr
<
ImportNodeWrapper
>
node
=
net
->
getN
ode
(
nodeId
);
if
(
node
->
getNumInputs
()
==
0
)
return
false
;
inpName
=
node
.
input
(
0
);
inpName
=
node
->
getInputName
(
0
);
return
Subgraph
::
match
(
net
,
nodeId
,
matchedNodesIds
,
targetNodesIds
);
}
...
...
@@ -457,7 +343,7 @@ public:
}
};
class
DeconvolutionValidKerasSubgraph
:
public
Subgraph
class
DeconvolutionValidKerasSubgraph
:
public
TF
Subgraph
{
public:
DeconvolutionValidKerasSubgraph
()
...
...
@@ -518,7 +404,7 @@ public:
}
};
class
DeconvolutionSameKerasSubgraph
:
public
Subgraph
class
DeconvolutionSameKerasSubgraph
:
public
TF
Subgraph
{
public:
DeconvolutionSameKerasSubgraph
()
...
...
@@ -608,7 +494,7 @@ public:
};
// In case of resizing by factor.
class
UpsamplingKerasSubgraph
:
public
Subgraph
class
UpsamplingKerasSubgraph
:
public
TF
Subgraph
{
public:
UpsamplingKerasSubgraph
(
const
std
::
string
&
type
)
...
...
@@ -703,7 +589,7 @@ public:
}
};
class
KerasMVNSubgraph
:
public
Subgraph
class
KerasMVNSubgraph
:
public
TF
Subgraph
{
public:
KerasMVNSubgraph
()
...
...
@@ -758,20 +644,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
subgraphs
.
push_back
(
Ptr
<
Subgraph
>
(
new
ReshapeAsShapeSubgraph
()));
subgraphs
.
push_back
(
Ptr
<
Subgraph
>
(
new
KerasMVNSubgraph
()));
int
numNodes
=
net
.
node_size
();
std
::
vector
<
int
>
matchedNodesIds
,
targetNodesIds
;
for
(
int
i
=
0
;
i
<
numNodes
;
++
i
)
{
for
(
int
j
=
0
;
j
<
subgraphs
.
size
();
++
j
)
{
if
(
subgraphs
[
j
]
->
match
(
net
,
i
,
matchedNodesIds
,
targetNodesIds
))
{
subgraphs
[
j
]
->
replace
(
net
,
matchedNodesIds
,
targetNodesIds
);
numNodes
-=
matchedNodesIds
.
size
()
-
1
;
// #matchedNodes removed and one added.
break
;
}
}
}
simplifySubgraphs
(
Ptr
<
ImportGraphWrapper
>
(
new
TFGraphWrapper
(
net
)),
subgraphs
);
}
void
RemoveIdentityOps
(
tensorflow
::
GraphDef
&
net
)
...
...
modules/dnn/test/test_onnx_importer.cpp
浏览文件 @
a67228cd
...
...
@@ -399,6 +399,7 @@ TEST_P(Test_ONNX_layers, Softmax)
{
testONNXModels
(
"softmax"
);
testONNXModels
(
"log_softmax"
,
npy
,
0
,
0
,
false
,
false
);
testONNXModels
(
"softmax_unfused"
);
}
TEST_P
(
Test_ONNX_layers
,
Split_EltwiseMax
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录