Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
Opencv
提交
472b71ec
O
Opencv
项目概览
Greenplum
/
Opencv
10 个月 前同步成功
通知
7
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
Opencv
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
472b71ec
编写于
8月 24, 2018
作者:
D
Dmitry Kurtaev
提交者:
Alexander Alekhin
8月 24, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Merge pull request #12243 from dkurt:dnn_tf_mask_rcnn
* Support Mask-RCNN from TensorFlow * Fix a sample
上级
4f360f8b
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
600 addition
and
153 deletion
+600
-153
modules/dnn/src/layers/crop_and_resize_layer.cpp
modules/dnn/src/layers/crop_and_resize_layer.cpp
+7
-0
modules/dnn/src/layers/detection_output_layer.cpp
modules/dnn/src/layers/detection_output_layer.cpp
+40
-10
modules/dnn/src/layers/resize_layer.cpp
modules/dnn/src/layers/resize_layer.cpp
+11
-5
modules/dnn/test/test_tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp
+52
-0
samples/dnn/mask_rcnn.py
samples/dnn/mask_rcnn.py
+143
-0
samples/dnn/tf_text_graph_common.py
samples/dnn/tf_text_graph_common.py
+95
-0
samples/dnn/tf_text_graph_faster_rcnn.py
samples/dnn/tf_text_graph_faster_rcnn.py
+15
-98
samples/dnn/tf_text_graph_mask_rcnn.py
samples/dnn/tf_text_graph_mask_rcnn.py
+230
-0
samples/dnn/tf_text_graph_ssd.py
samples/dnn/tf_text_graph_ssd.py
+7
-40
未找到文件。
modules/dnn/src/layers/crop_and_resize_layer.cpp
浏览文件 @
472b71ec
...
...
@@ -99,6 +99,13 @@ public:
}
}
}
if
(
boxes
.
rows
<
out
.
size
[
0
])
{
// left = top = right = bottom = 0
std
::
vector
<
cv
::
Range
>
dstRanges
(
4
,
Range
::
all
());
dstRanges
[
0
]
=
Range
(
boxes
.
rows
,
out
.
size
[
0
]);
out
(
dstRanges
).
setTo
(
inp
.
ptr
<
float
>
(
0
,
0
,
0
)[
0
]);
}
}
private:
...
...
modules/dnn/src/layers/detection_output_layer.cpp
浏览文件 @
472b71ec
...
...
@@ -115,6 +115,7 @@ public:
// It's true whenever predicted bounding boxes and proposals are normalized to [0, 1].
bool
_bboxesNormalized
;
bool
_clip
;
bool
_groupByClasses
;
enum
{
_numAxes
=
4
};
static
const
std
::
string
_layerName
;
...
...
@@ -183,6 +184,7 @@ public:
_locPredTransposed
=
getParameter
<
bool
>
(
params
,
"loc_pred_transposed"
,
0
,
false
,
false
);
_bboxesNormalized
=
getParameter
<
bool
>
(
params
,
"normalized_bbox"
,
0
,
false
,
true
);
_clip
=
getParameter
<
bool
>
(
params
,
"clip"
,
0
,
false
,
false
);
_groupByClasses
=
getParameter
<
bool
>
(
params
,
"group_by_classes"
,
0
,
false
,
true
);
getCodeType
(
params
);
...
...
@@ -381,7 +383,7 @@ public:
{
count
+=
outputDetections_
(
i
,
&
outputsData
[
count
*
7
],
allDecodedBBoxes
[
i
],
allConfidenceScores
[
i
],
allIndices
[
i
]);
allIndices
[
i
]
,
_groupByClasses
);
}
CV_Assert
(
count
==
numKept
);
}
...
...
@@ -497,7 +499,7 @@ public:
{
count
+=
outputDetections_
(
i
,
&
outputsData
[
count
*
7
],
allDecodedBBoxes
[
i
],
allConfidenceScores
[
i
],
allIndices
[
i
]);
allIndices
[
i
]
,
_groupByClasses
);
}
CV_Assert
(
count
==
numKept
);
}
...
...
@@ -505,9 +507,36 @@ public:
size_t
outputDetections_
(
const
int
i
,
float
*
outputsData
,
const
LabelBBox
&
decodeBBoxes
,
Mat
&
confidenceScores
,
const
std
::
map
<
int
,
std
::
vector
<
int
>
>&
indicesMap
const
std
::
map
<
int
,
std
::
vector
<
int
>
>&
indicesMap
,
bool
groupByClasses
)
{
std
::
vector
<
int
>
dstIndices
;
std
::
vector
<
std
::
pair
<
float
,
int
>
>
allScores
;
for
(
std
::
map
<
int
,
std
::
vector
<
int
>
>::
const_iterator
it
=
indicesMap
.
begin
();
it
!=
indicesMap
.
end
();
++
it
)
{
int
label
=
it
->
first
;
if
(
confidenceScores
.
rows
<=
label
)
CV_Error_
(
cv
::
Error
::
StsError
,
(
"Could not find confidence predictions for label %d"
,
label
));
const
std
::
vector
<
float
>&
scores
=
confidenceScores
.
row
(
label
);
const
std
::
vector
<
int
>&
indices
=
it
->
second
;
const
int
numAllScores
=
allScores
.
size
();
allScores
.
reserve
(
numAllScores
+
indices
.
size
());
for
(
size_t
j
=
0
;
j
<
indices
.
size
();
++
j
)
{
allScores
.
push_back
(
std
::
make_pair
(
scores
[
indices
[
j
]],
numAllScores
+
j
));
}
}
if
(
!
groupByClasses
)
std
::
sort
(
allScores
.
begin
(),
allScores
.
end
(),
util
::
SortScorePairDescend
<
int
>
);
dstIndices
.
resize
(
allScores
.
size
());
for
(
size_t
j
=
0
;
j
<
dstIndices
.
size
();
++
j
)
{
dstIndices
[
allScores
[
j
].
second
]
=
j
;
}
size_t
count
=
0
;
for
(
std
::
map
<
int
,
std
::
vector
<
int
>
>::
const_iterator
it
=
indicesMap
.
begin
();
it
!=
indicesMap
.
end
();
++
it
)
{
...
...
@@ -524,14 +553,15 @@ public:
for
(
size_t
j
=
0
;
j
<
indices
.
size
();
++
j
,
++
count
)
{
int
idx
=
indices
[
j
];
int
dstIdx
=
dstIndices
[
count
];
const
util
::
NormalizedBBox
&
decode_bbox
=
label_bboxes
->
second
[
idx
];
outputsData
[
count
*
7
]
=
i
;
outputsData
[
count
*
7
+
1
]
=
label
;
outputsData
[
count
*
7
+
2
]
=
scores
[
idx
];
outputsData
[
count
*
7
+
3
]
=
decode_bbox
.
xmin
;
outputsData
[
count
*
7
+
4
]
=
decode_bbox
.
ymin
;
outputsData
[
count
*
7
+
5
]
=
decode_bbox
.
xmax
;
outputsData
[
count
*
7
+
6
]
=
decode_bbox
.
ymax
;
outputsData
[
dstIdx
*
7
]
=
i
;
outputsData
[
dstIdx
*
7
+
1
]
=
label
;
outputsData
[
dstIdx
*
7
+
2
]
=
scores
[
idx
];
outputsData
[
dstIdx
*
7
+
3
]
=
decode_bbox
.
xmin
;
outputsData
[
dstIdx
*
7
+
4
]
=
decode_bbox
.
ymin
;
outputsData
[
dstIdx
*
7
+
5
]
=
decode_bbox
.
xmax
;
outputsData
[
dstIdx
*
7
+
6
]
=
decode_bbox
.
ymax
;
}
}
return
count
;
...
...
modules/dnn/src/layers/resize_layer.cpp
浏览文件 @
472b71ec
...
...
@@ -33,9 +33,7 @@ public:
interpolation
=
params
.
get
<
String
>
(
"interpolation"
);
CV_Assert
(
interpolation
==
"nearest"
||
interpolation
==
"bilinear"
);
bool
alignCorners
=
params
.
get
<
bool
>
(
"align_corners"
,
false
);
if
(
alignCorners
)
CV_Error
(
Error
::
StsNotImplemented
,
"Resize with align_corners=true is not implemented"
);
alignCorners
=
params
.
get
<
bool
>
(
"align_corners"
,
false
);
}
bool
getMemoryShapes
(
const
std
::
vector
<
MatShape
>
&
inputs
,
...
...
@@ -66,8 +64,15 @@ public:
outHeight
=
outputs
[
0
].
size
[
2
];
outWidth
=
outputs
[
0
].
size
[
3
];
}
scaleHeight
=
static_cast
<
float
>
(
inputs
[
0
]
->
size
[
2
])
/
outHeight
;
scaleWidth
=
static_cast
<
float
>
(
inputs
[
0
]
->
size
[
3
])
/
outWidth
;
if
(
alignCorners
&&
outHeight
>
1
)
scaleHeight
=
static_cast
<
float
>
(
inputs
[
0
]
->
size
[
2
]
-
1
)
/
(
outHeight
-
1
);
else
scaleHeight
=
static_cast
<
float
>
(
inputs
[
0
]
->
size
[
2
])
/
outHeight
;
if
(
alignCorners
&&
outWidth
>
1
)
scaleWidth
=
static_cast
<
float
>
(
inputs
[
0
]
->
size
[
3
]
-
1
)
/
(
outWidth
-
1
);
else
scaleWidth
=
static_cast
<
float
>
(
inputs
[
0
]
->
size
[
3
])
/
outWidth
;
}
void
forward
(
InputArrayOfArrays
inputs_arr
,
OutputArrayOfArrays
outputs_arr
,
OutputArrayOfArrays
internals_arr
)
CV_OVERRIDE
...
...
@@ -166,6 +171,7 @@ protected:
int
outWidth
,
outHeight
,
zoomFactorWidth
,
zoomFactorHeight
;
String
interpolation
;
float
scaleWidth
,
scaleHeight
;
bool
alignCorners
;
};
...
...
modules/dnn/test/test_tf_importer.cpp
浏览文件 @
472b71ec
...
...
@@ -537,4 +537,56 @@ TEST(Test_TensorFlow, two_inputs)
normAssert
(
out
,
firstInput
+
secondInput
);
}
TEST
(
Test_TensorFlow
,
Mask_RCNN
)
{
std
::
string
proto
=
findDataFile
(
"dnn/mask_rcnn_inception_v2_coco_2018_01_28.pbtxt"
,
false
);
std
::
string
model
=
findDataFile
(
"dnn/mask_rcnn_inception_v2_coco_2018_01_28.pb"
,
false
);
Net
net
=
readNetFromTensorflow
(
model
,
proto
);
Mat
img
=
imread
(
findDataFile
(
"dnn/street.png"
,
false
));
Mat
refDetections
=
blobFromNPY
(
path
(
"mask_rcnn_inception_v2_coco_2018_01_28.detection_out.npy"
));
Mat
refMasks
=
blobFromNPY
(
path
(
"mask_rcnn_inception_v2_coco_2018_01_28.detection_masks.npy"
));
Mat
blob
=
blobFromImage
(
img
,
1.0
f
,
Size
(
800
,
800
),
Scalar
(),
true
,
false
);
net
.
setPreferableBackend
(
DNN_BACKEND_OPENCV
);
net
.
setInput
(
blob
);
// Mask-RCNN predicts bounding boxes and segmentation masks.
std
::
vector
<
String
>
outNames
(
2
);
outNames
[
0
]
=
"detection_out_final"
;
outNames
[
1
]
=
"detection_masks"
;
std
::
vector
<
Mat
>
outs
;
net
.
forward
(
outs
,
outNames
);
Mat
outDetections
=
outs
[
0
];
Mat
outMasks
=
outs
[
1
];
normAssertDetections
(
refDetections
,
outDetections
,
""
,
/*threshold for zero confidence*/
1e-5
);
// Output size of masks is NxCxHxW where
// N - number of detected boxes
// C - number of classes (excluding background)
// HxW - segmentation shape
const
int
numDetections
=
outDetections
.
size
[
2
];
int
masksSize
[]
=
{
1
,
numDetections
,
outMasks
.
size
[
2
],
outMasks
.
size
[
3
]};
Mat
masks
(
4
,
&
masksSize
[
0
],
CV_32F
);
std
::
vector
<
cv
::
Range
>
srcRanges
(
4
,
cv
::
Range
::
all
());
std
::
vector
<
cv
::
Range
>
dstRanges
(
4
,
cv
::
Range
::
all
());
outDetections
=
outDetections
.
reshape
(
1
,
outDetections
.
total
()
/
7
);
for
(
int
i
=
0
;
i
<
numDetections
;
++
i
)
{
// Get a class id for this bounding box and copy mask only for that class.
int
classId
=
static_cast
<
int
>
(
outDetections
.
at
<
float
>
(
i
,
1
));
srcRanges
[
0
]
=
dstRanges
[
1
]
=
cv
::
Range
(
i
,
i
+
1
);
srcRanges
[
1
]
=
cv
::
Range
(
classId
,
classId
+
1
);
outMasks
(
srcRanges
).
copyTo
(
masks
(
dstRanges
));
}
cv
::
Range
topRefMasks
[]
=
{
Range
::
all
(),
Range
(
0
,
numDetections
),
Range
::
all
(),
Range
::
all
()};
normAssert
(
masks
,
refMasks
(
&
topRefMasks
[
0
]));
}
}
samples/dnn/mask_rcnn.py
0 → 100644
浏览文件 @
472b71ec
import
cv2
as
cv
import
argparse
import
numpy
as
np
parser
=
argparse
.
ArgumentParser
(
description
=
'Use this script to run Mask-RCNN object detection and semantic '
'segmentation network from TensorFlow Object Detection API.'
)
parser
.
add_argument
(
'--input'
,
help
=
'Path to input image or video file. Skip this argument to capture frames from a camera.'
)
parser
.
add_argument
(
'--model'
,
required
=
True
,
help
=
'Path to a .pb file with weights.'
)
parser
.
add_argument
(
'--config'
,
required
=
True
,
help
=
'Path to a .pxtxt file contains network configuration.'
)
parser
.
add_argument
(
'--classes'
,
help
=
'Optional path to a text file with names of classes.'
)
parser
.
add_argument
(
'--colors'
,
help
=
'Optional path to a text file with colors for an every class. '
'An every color is represented with three values from 0 to 255 in BGR channels order.'
)
parser
.
add_argument
(
'--width'
,
type
=
int
,
default
=
800
,
help
=
'Preprocess input image by resizing to a specific width.'
)
parser
.
add_argument
(
'--height'
,
type
=
int
,
default
=
800
,
help
=
'Preprocess input image by resizing to a specific height.'
)
parser
.
add_argument
(
'--thr'
,
type
=
float
,
default
=
0.5
,
help
=
'Confidence threshold'
)
args
=
parser
.
parse_args
()
np
.
random
.
seed
(
324
)
# Load names of classes
classes
=
None
if
args
.
classes
:
with
open
(
args
.
classes
,
'rt'
)
as
f
:
classes
=
f
.
read
().
rstrip
(
'
\n
'
).
split
(
'
\n
'
)
# Load colors
colors
=
None
if
args
.
colors
:
with
open
(
args
.
colors
,
'rt'
)
as
f
:
colors
=
[
np
.
array
(
color
.
split
(
' '
),
np
.
uint8
)
for
color
in
f
.
read
().
rstrip
(
'
\n
'
).
split
(
'
\n
'
)]
legend
=
None
def
showLegend
(
classes
):
global
legend
if
not
classes
is
None
and
legend
is
None
:
blockHeight
=
30
assert
(
len
(
classes
)
==
len
(
colors
))
legend
=
np
.
zeros
((
blockHeight
*
len
(
colors
),
200
,
3
),
np
.
uint8
)
for
i
in
range
(
len
(
classes
)):
block
=
legend
[
i
*
blockHeight
:(
i
+
1
)
*
blockHeight
]
block
[:,:]
=
colors
[
i
]
cv
.
putText
(
block
,
classes
[
i
],
(
0
,
blockHeight
/
2
),
cv
.
FONT_HERSHEY_SIMPLEX
,
0.5
,
(
255
,
255
,
255
))
cv
.
namedWindow
(
'Legend'
,
cv
.
WINDOW_NORMAL
)
cv
.
imshow
(
'Legend'
,
legend
)
classes
=
None
def
drawBox
(
frame
,
classId
,
conf
,
left
,
top
,
right
,
bottom
):
# Draw a bounding box.
cv
.
rectangle
(
frame
,
(
left
,
top
),
(
right
,
bottom
),
(
0
,
255
,
0
))
label
=
'%.2f'
%
conf
# Print a label of class.
if
classes
:
assert
(
classId
<
len
(
classes
))
label
=
'%s: %s'
%
(
classes
[
classId
],
label
)
labelSize
,
baseLine
=
cv
.
getTextSize
(
label
,
cv
.
FONT_HERSHEY_SIMPLEX
,
0.5
,
1
)
top
=
max
(
top
,
labelSize
[
1
])
cv
.
rectangle
(
frame
,
(
left
,
top
-
labelSize
[
1
]),
(
left
+
labelSize
[
0
],
top
+
baseLine
),
(
255
,
255
,
255
),
cv
.
FILLED
)
cv
.
putText
(
frame
,
label
,
(
left
,
top
),
cv
.
FONT_HERSHEY_SIMPLEX
,
0.5
,
(
0
,
0
,
0
))
# Load a network
net
=
cv
.
dnn
.
readNet
(
args
.
model
,
args
.
config
)
net
.
setPreferableBackend
(
cv
.
dnn
.
DNN_BACKEND_OPENCV
)
winName
=
'Mask-RCNN in OpenCV'
cv
.
namedWindow
(
winName
,
cv
.
WINDOW_NORMAL
)
cap
=
cv
.
VideoCapture
(
args
.
input
if
args
.
input
else
0
)
legend
=
None
while
cv
.
waitKey
(
1
)
<
0
:
hasFrame
,
frame
=
cap
.
read
()
if
not
hasFrame
:
cv
.
waitKey
()
break
frameH
=
frame
.
shape
[
0
]
frameW
=
frame
.
shape
[
1
]
# Create a 4D blob from a frame.
blob
=
cv
.
dnn
.
blobFromImage
(
frame
,
size
=
(
args
.
width
,
args
.
height
),
swapRB
=
True
,
crop
=
False
)
# Run a model
net
.
setInput
(
blob
)
boxes
,
masks
=
net
.
forward
([
'detection_out_final'
,
'detection_masks'
])
numClasses
=
masks
.
shape
[
1
]
numDetections
=
boxes
.
shape
[
2
]
# Draw segmentation
if
not
colors
:
# Generate colors
colors
=
[
np
.
array
([
0
,
0
,
0
],
np
.
uint8
)]
for
i
in
range
(
1
,
numClasses
+
1
):
colors
.
append
((
colors
[
i
-
1
]
+
np
.
random
.
randint
(
0
,
256
,
[
3
],
np
.
uint8
))
/
2
)
del
colors
[
0
]
boxesToDraw
=
[]
for
i
in
range
(
numDetections
):
box
=
boxes
[
0
,
0
,
i
]
mask
=
masks
[
i
]
score
=
box
[
2
]
if
score
>
args
.
thr
:
classId
=
int
(
box
[
1
])
left
=
int
(
frameW
*
box
[
3
])
top
=
int
(
frameH
*
box
[
4
])
right
=
int
(
frameW
*
box
[
5
])
bottom
=
int
(
frameH
*
box
[
6
])
left
=
max
(
0
,
min
(
left
,
frameW
-
1
))
top
=
max
(
0
,
min
(
top
,
frameH
-
1
))
right
=
max
(
0
,
min
(
right
,
frameW
-
1
))
bottom
=
max
(
0
,
min
(
bottom
,
frameH
-
1
))
boxesToDraw
.
append
([
frame
,
classId
,
score
,
left
,
top
,
right
,
bottom
])
classMask
=
mask
[
classId
]
classMask
=
cv
.
resize
(
classMask
,
(
right
-
left
+
1
,
bottom
-
top
+
1
))
mask
=
(
classMask
>
0.5
)
roi
=
frame
[
top
:
bottom
+
1
,
left
:
right
+
1
][
mask
]
frame
[
top
:
bottom
+
1
,
left
:
right
+
1
][
mask
]
=
(
0.7
*
colors
[
classId
]
+
0.3
*
roi
).
astype
(
np
.
uint8
)
for
box
in
boxesToDraw
:
drawBox
(
*
box
)
# Put efficiency information.
t
,
_
=
net
.
getPerfProfile
()
label
=
'Inference time: %.2f ms'
%
(
t
*
1000.0
/
cv
.
getTickFrequency
())
cv
.
putText
(
frame
,
label
,
(
0
,
15
),
cv
.
FONT_HERSHEY_SIMPLEX
,
0.5
,
(
0
,
255
,
0
))
showLegend
(
classes
)
cv
.
imshow
(
winName
,
frame
)
samples/dnn/tf_text_graph_common.py
浏览文件 @
472b71ec
...
...
@@ -23,3 +23,98 @@ def addConstNode(name, values, graph_def):
node
.
op
=
'Const'
text_format
.
Merge
(
tensorMsg
(
values
),
node
.
attr
[
"value"
])
graph_def
.
node
.
extend
([
node
])
def
addSlice
(
inp
,
out
,
begins
,
sizes
,
graph_def
):
beginsNode
=
NodeDef
()
beginsNode
.
name
=
out
+
'/begins'
beginsNode
.
op
=
'Const'
text_format
.
Merge
(
tensorMsg
(
begins
),
beginsNode
.
attr
[
"value"
])
graph_def
.
node
.
extend
([
beginsNode
])
sizesNode
=
NodeDef
()
sizesNode
.
name
=
out
+
'/sizes'
sizesNode
.
op
=
'Const'
text_format
.
Merge
(
tensorMsg
(
sizes
),
sizesNode
.
attr
[
"value"
])
graph_def
.
node
.
extend
([
sizesNode
])
sliced
=
NodeDef
()
sliced
.
name
=
out
sliced
.
op
=
'Slice'
sliced
.
input
.
append
(
inp
)
sliced
.
input
.
append
(
beginsNode
.
name
)
sliced
.
input
.
append
(
sizesNode
.
name
)
graph_def
.
node
.
extend
([
sliced
])
def
addReshape
(
inp
,
out
,
shape
,
graph_def
):
shapeNode
=
NodeDef
()
shapeNode
.
name
=
out
+
'/shape'
shapeNode
.
op
=
'Const'
text_format
.
Merge
(
tensorMsg
(
shape
),
shapeNode
.
attr
[
"value"
])
graph_def
.
node
.
extend
([
shapeNode
])
reshape
=
NodeDef
()
reshape
.
name
=
out
reshape
.
op
=
'Reshape'
reshape
.
input
.
append
(
inp
)
reshape
.
input
.
append
(
shapeNode
.
name
)
graph_def
.
node
.
extend
([
reshape
])
def
addSoftMax
(
inp
,
out
,
graph_def
):
softmax
=
NodeDef
()
softmax
.
name
=
out
softmax
.
op
=
'Softmax'
text_format
.
Merge
(
'i: -1'
,
softmax
.
attr
[
'axis'
])
softmax
.
input
.
append
(
inp
)
graph_def
.
node
.
extend
([
softmax
])
def
addFlatten
(
inp
,
out
,
graph_def
):
flatten
=
NodeDef
()
flatten
.
name
=
out
flatten
.
op
=
'Flatten'
flatten
.
input
.
append
(
inp
)
graph_def
.
node
.
extend
([
flatten
])
# Removes Identity nodes
def
removeIdentity
(
graph_def
):
identities
=
{}
for
node
in
graph_def
.
node
:
if
node
.
op
==
'Identity'
:
identities
[
node
.
name
]
=
node
.
input
[
0
]
graph_def
.
node
.
remove
(
node
)
for
node
in
graph_def
.
node
:
for
i
in
range
(
len
(
node
.
input
)):
if
node
.
input
[
i
]
in
identities
:
node
.
input
[
i
]
=
identities
[
node
.
input
[
i
]]
def
removeUnusedNodesAndAttrs
(
to_remove
,
graph_def
):
unusedAttrs
=
[
'T'
,
'Tshape'
,
'N'
,
'Tidx'
,
'Tdim'
,
'use_cudnn_on_gpu'
,
'Index'
,
'Tperm'
,
'is_training'
,
'Tpaddings'
]
removedNodes
=
[]
for
i
in
reversed
(
range
(
len
(
graph_def
.
node
))):
op
=
graph_def
.
node
[
i
].
op
name
=
graph_def
.
node
[
i
].
name
if
op
==
'Const'
or
to_remove
(
name
,
op
):
if
op
!=
'Const'
:
removedNodes
.
append
(
name
)
del
graph_def
.
node
[
i
]
else
:
for
attr
in
unusedAttrs
:
if
attr
in
graph_def
.
node
[
i
].
attr
:
del
graph_def
.
node
[
i
].
attr
[
attr
]
# Remove references to removed nodes except Const nodes.
for
node
in
graph_def
.
node
:
for
i
in
reversed
(
range
(
len
(
node
.
input
))):
if
node
.
input
[
i
]
in
removedNodes
:
del
node
.
input
[
i
]
samples/dnn/tf_text_graph_faster_rcnn.py
浏览文件 @
472b71ec
...
...
@@ -6,7 +6,7 @@ from tensorflow.core.framework.node_def_pb2 import NodeDef
from
tensorflow.tools.graph_transforms
import
TransformGraph
from
google.protobuf
import
text_format
from
tf_text_graph_common
import
tensorMsg
,
addConstNode
from
tf_text_graph_common
import
*
parser
=
argparse
.
ArgumentParser
(
description
=
'Run this script to get a text graph of '
'SSD model from TensorFlow Object Detection API. '
...
...
@@ -37,50 +37,17 @@ scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
'FirstStageFeatureExtractor/GreaterEqual'
,
'FirstStageFeatureExtractor/LogicalAnd'
)
unusedAttrs
=
[
'T'
,
'Tshape'
,
'N'
,
'Tidx'
,
'Tdim'
,
'use_cudnn_on_gpu'
,
'Index'
,
'Tperm'
,
'is_training'
,
'Tpaddings'
]
# Read the graph.
with
tf
.
gfile
.
FastGFile
(
args
.
input
,
'rb'
)
as
f
:
graph_def
=
tf
.
GraphDef
()
graph_def
.
ParseFromString
(
f
.
read
())
# Removes Identity nodes
def
removeIdentity
():
identities
=
{}
for
node
in
graph_def
.
node
:
if
node
.
op
==
'Identity'
:
identities
[
node
.
name
]
=
node
.
input
[
0
]
graph_def
.
node
.
remove
(
node
)
for
node
in
graph_def
.
node
:
for
i
in
range
(
len
(
node
.
input
)):
if
node
.
input
[
i
]
in
identities
:
node
.
input
[
i
]
=
identities
[
node
.
input
[
i
]]
removeIdentity
()
removedNodes
=
[]
for
i
in
reversed
(
range
(
len
(
graph_def
.
node
))):
op
=
graph_def
.
node
[
i
].
op
name
=
graph_def
.
node
[
i
].
name
removeIdentity
(
graph_def
)
if
op
==
'Const'
or
name
.
startswith
(
scopesToIgnore
)
or
not
name
.
startswith
(
scopesToKeep
):
if
op
!=
'Const'
:
removedNodes
.
append
(
name
)
def
to_remove
(
name
,
op
):
return
name
.
startswith
(
scopesToIgnore
)
or
not
name
.
startswith
(
scopesToKeep
)
del
graph_def
.
node
[
i
]
else
:
for
attr
in
unusedAttrs
:
if
attr
in
graph_def
.
node
[
i
].
attr
:
del
graph_def
.
node
[
i
].
attr
[
attr
]
# Remove references to removed nodes except Const nodes.
for
node
in
graph_def
.
node
:
for
i
in
reversed
(
range
(
len
(
node
.
input
))):
if
node
.
input
[
i
]
in
removedNodes
:
del
node
.
input
[
i
]
removeUnusedNodesAndAttrs
(
to_remove
,
graph_def
)
# Connect input node to the first layer
...
...
@@ -95,68 +62,18 @@ while True:
if
node
.
op
==
'CropAndResize'
:
break
def
addSlice
(
inp
,
out
,
begins
,
sizes
):
beginsNode
=
NodeDef
()
beginsNode
.
name
=
out
+
'/begins'
beginsNode
.
op
=
'Const'
text_format
.
Merge
(
tensorMsg
(
begins
),
beginsNode
.
attr
[
"value"
])
graph_def
.
node
.
extend
([
beginsNode
])
sizesNode
=
NodeDef
()
sizesNode
.
name
=
out
+
'/sizes'
sizesNode
.
op
=
'Const'
text_format
.
Merge
(
tensorMsg
(
sizes
),
sizesNode
.
attr
[
"value"
])
graph_def
.
node
.
extend
([
sizesNode
])
sliced
=
NodeDef
()
sliced
.
name
=
out
sliced
.
op
=
'Slice'
sliced
.
input
.
append
(
inp
)
sliced
.
input
.
append
(
beginsNode
.
name
)
sliced
.
input
.
append
(
sizesNode
.
name
)
graph_def
.
node
.
extend
([
sliced
])
def
addReshape
(
inp
,
out
,
shape
):
shapeNode
=
NodeDef
()
shapeNode
.
name
=
out
+
'/shape'
shapeNode
.
op
=
'Const'
text_format
.
Merge
(
tensorMsg
(
shape
),
shapeNode
.
attr
[
"value"
])
graph_def
.
node
.
extend
([
shapeNode
])
reshape
=
NodeDef
()
reshape
.
name
=
out
reshape
.
op
=
'Reshape'
reshape
.
input
.
append
(
inp
)
reshape
.
input
.
append
(
shapeNode
.
name
)
graph_def
.
node
.
extend
([
reshape
])
def
addSoftMax
(
inp
,
out
):
softmax
=
NodeDef
()
softmax
.
name
=
out
softmax
.
op
=
'Softmax'
text_format
.
Merge
(
'i: -1'
,
softmax
.
attr
[
'axis'
])
softmax
.
input
.
append
(
inp
)
graph_def
.
node
.
extend
([
softmax
])
def
addFlatten
(
inp
,
out
):
flatten
=
NodeDef
()
flatten
.
name
=
out
flatten
.
op
=
'Flatten'
flatten
.
input
.
append
(
inp
)
graph_def
.
node
.
extend
([
flatten
])
addReshape
(
'FirstStageBoxPredictor/ClassPredictor/BiasAdd'
,
'FirstStageBoxPredictor/ClassPredictor/reshape_1'
,
[
0
,
-
1
,
2
])
'FirstStageBoxPredictor/ClassPredictor/reshape_1'
,
[
0
,
-
1
,
2
]
,
graph_def
)
addSoftMax
(
'FirstStageBoxPredictor/ClassPredictor/reshape_1'
,
'FirstStageBoxPredictor/ClassPredictor/softmax'
)
# Compare with Reshape_4
'FirstStageBoxPredictor/ClassPredictor/softmax'
,
graph_def
)
# Compare with Reshape_4
addFlatten
(
'FirstStageBoxPredictor/ClassPredictor/softmax'
,
'FirstStageBoxPredictor/ClassPredictor/softmax/flatten'
)
'FirstStageBoxPredictor/ClassPredictor/softmax/flatten'
,
graph_def
)
# Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
addFlatten
(
'FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd'
,
'FirstStageBoxPredictor/BoxEncodingPredictor/flatten'
)
'FirstStageBoxPredictor/BoxEncodingPredictor/flatten'
,
graph_def
)
proposals
=
NodeDef
()
proposals
.
name
=
'proposals'
# Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
...
...
@@ -218,14 +135,14 @@ graph_def.node.extend([clipByValueNode])
for
node
in
reversed
(
topNodes
):
graph_def
.
node
.
extend
([
node
])
addSoftMax
(
'SecondStageBoxPredictor/Reshape_1'
,
'SecondStageBoxPredictor/Reshape_1/softmax'
)
addSoftMax
(
'SecondStageBoxPredictor/Reshape_1'
,
'SecondStageBoxPredictor/Reshape_1/softmax'
,
graph_def
)
addSlice
(
'SecondStageBoxPredictor/Reshape_1/softmax'
,
'SecondStageBoxPredictor/Reshape_1/slice'
,
[
0
,
0
,
1
],
[
-
1
,
-
1
,
-
1
])
[
0
,
0
,
1
],
[
-
1
,
-
1
,
-
1
]
,
graph_def
)
addReshape
(
'SecondStageBoxPredictor/Reshape_1/slice'
,
'SecondStageBoxPredictor/Reshape_1/Reshape'
,
[
1
,
-
1
])
'SecondStageBoxPredictor/Reshape_1/Reshape'
,
[
1
,
-
1
]
,
graph_def
)
# Replace Flatten subgraph onto a single node.
for
i
in
reversed
(
range
(
len
(
graph_def
.
node
))):
...
...
@@ -255,7 +172,7 @@ for node in graph_def.node:
################################################################################
### Postprocessing
################################################################################
addSlice
(
'detection_out/clip_by_value'
,
'detection_out/slice'
,
[
0
,
0
,
0
,
3
],
[
-
1
,
-
1
,
-
1
,
4
])
addSlice
(
'detection_out/clip_by_value'
,
'detection_out/slice'
,
[
0
,
0
,
0
,
3
],
[
-
1
,
-
1
,
-
1
,
4
]
,
graph_def
)
variance
=
NodeDef
()
variance
.
name
=
'proposals/variance'
...
...
@@ -271,8 +188,8 @@ varianceEncoder.input.append(variance.name)
text_format
.
Merge
(
'i: 2'
,
varianceEncoder
.
attr
[
"axis"
])
graph_def
.
node
.
extend
([
varianceEncoder
])
addReshape
(
'detection_out/slice'
,
'detection_out/slice/reshape'
,
[
1
,
1
,
-
1
])
addFlatten
(
'variance_encoded'
,
'variance_encoded/flatten'
)
addReshape
(
'detection_out/slice'
,
'detection_out/slice/reshape'
,
[
1
,
1
,
-
1
]
,
graph_def
)
addFlatten
(
'variance_encoded'
,
'variance_encoded/flatten'
,
graph_def
)
detectionOut
=
NodeDef
()
detectionOut
.
name
=
'detection_out_final'
...
...
samples/dnn/tf_text_graph_mask_rcnn.py
0 → 100644
浏览文件 @
472b71ec
import
argparse
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.core.framework.node_def_pb2
import
NodeDef
from
tensorflow.tools.graph_transforms
import
TransformGraph
from
google.protobuf
import
text_format
from
tf_text_graph_common
import
*
parser
=
argparse
.
ArgumentParser
(
description
=
'Run this script to get a text graph of '
'Mask-RCNN model from TensorFlow Object Detection API. '
'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.'
)
parser
.
add_argument
(
'--input'
,
required
=
True
,
help
=
'Path to frozen TensorFlow graph.'
)
parser
.
add_argument
(
'--output'
,
required
=
True
,
help
=
'Path to output text graph.'
)
parser
.
add_argument
(
'--num_classes'
,
default
=
90
,
type
=
int
,
help
=
'Number of trained classes.'
)
parser
.
add_argument
(
'--scales'
,
default
=
[
0.25
,
0.5
,
1.0
,
2.0
],
type
=
float
,
nargs
=
'+'
,
help
=
'Hyper-parameter of grid_anchor_generator from a config file.'
)
parser
.
add_argument
(
'--aspect_ratios'
,
default
=
[
0.5
,
1.0
,
2.0
],
type
=
float
,
nargs
=
'+'
,
help
=
'Hyper-parameter of grid_anchor_generator from a config file.'
)
parser
.
add_argument
(
'--features_stride'
,
default
=
16
,
type
=
float
,
nargs
=
'+'
,
help
=
'Hyper-parameter from a config file.'
)
args
=
parser
.
parse_args
()
scopesToKeep
=
(
'FirstStageFeatureExtractor'
,
'Conv'
,
'FirstStageBoxPredictor/BoxEncodingPredictor'
,
'FirstStageBoxPredictor/ClassPredictor'
,
'CropAndResize'
,
'MaxPool2D'
,
'SecondStageFeatureExtractor'
,
'SecondStageBoxPredictor'
,
'Preprocessor/sub'
,
'Preprocessor/mul'
,
'image_tensor'
)
scopesToIgnore
=
(
'FirstStageFeatureExtractor/Assert'
,
'FirstStageFeatureExtractor/Shape'
,
'FirstStageFeatureExtractor/strided_slice'
,
'FirstStageFeatureExtractor/GreaterEqual'
,
'FirstStageFeatureExtractor/LogicalAnd'
)
# Read the graph.
with
tf
.
gfile
.
FastGFile
(
args
.
input
,
'rb'
)
as
f
:
graph_def
=
tf
.
GraphDef
()
graph_def
.
ParseFromString
(
f
.
read
())
removeIdentity
(
graph_def
)
def
to_remove
(
name
,
op
):
return
name
.
startswith
(
scopesToIgnore
)
or
not
name
.
startswith
(
scopesToKeep
)
removeUnusedNodesAndAttrs
(
to_remove
,
graph_def
)
# Connect input node to the first layer
assert
(
graph_def
.
node
[
0
].
op
==
'Placeholder'
)
graph_def
.
node
[
1
].
input
.
insert
(
0
,
graph_def
.
node
[
0
].
name
)
# Temporarily remove top nodes.
topNodes
=
[]
numCropAndResize
=
0
while
True
:
node
=
graph_def
.
node
.
pop
()
topNodes
.
append
(
node
)
if
node
.
op
==
'CropAndResize'
:
numCropAndResize
+=
1
if
numCropAndResize
==
2
:
break
addReshape
(
'FirstStageBoxPredictor/ClassPredictor/BiasAdd'
,
'FirstStageBoxPredictor/ClassPredictor/reshape_1'
,
[
0
,
-
1
,
2
],
graph_def
)
addSoftMax
(
'FirstStageBoxPredictor/ClassPredictor/reshape_1'
,
'FirstStageBoxPredictor/ClassPredictor/softmax'
,
graph_def
)
# Compare with Reshape_4
addFlatten
(
'FirstStageBoxPredictor/ClassPredictor/softmax'
,
'FirstStageBoxPredictor/ClassPredictor/softmax/flatten'
,
graph_def
)
# Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
addFlatten
(
'FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd'
,
'FirstStageBoxPredictor/BoxEncodingPredictor/flatten'
,
graph_def
)
proposals
=
NodeDef
()
proposals
.
name
=
'proposals'
# Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
proposals
.
op
=
'PriorBox'
proposals
.
input
.
append
(
'FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd'
)
proposals
.
input
.
append
(
graph_def
.
node
[
0
].
name
)
# image_tensor
text_format
.
Merge
(
'b: false'
,
proposals
.
attr
[
"flip"
])
text_format
.
Merge
(
'b: true'
,
proposals
.
attr
[
"clip"
])
text_format
.
Merge
(
'f: %f'
%
args
.
features_stride
,
proposals
.
attr
[
"step"
])
text_format
.
Merge
(
'f: 0.0'
,
proposals
.
attr
[
"offset"
])
text_format
.
Merge
(
tensorMsg
([
0.1
,
0.1
,
0.2
,
0.2
]),
proposals
.
attr
[
"variance"
])
widths
=
[]
heights
=
[]
for
a
in
args
.
aspect_ratios
:
for
s
in
args
.
scales
:
ar
=
np
.
sqrt
(
a
)
heights
.
append
((
args
.
features_stride
**
2
)
*
s
/
ar
)
widths
.
append
((
args
.
features_stride
**
2
)
*
s
*
ar
)
text_format
.
Merge
(
tensorMsg
(
widths
),
proposals
.
attr
[
"width"
])
text_format
.
Merge
(
tensorMsg
(
heights
),
proposals
.
attr
[
"height"
])
graph_def
.
node
.
extend
([
proposals
])
# Compare with Reshape_5
detectionOut
=
NodeDef
()
detectionOut
.
name
=
'detection_out'
detectionOut
.
op
=
'DetectionOutput'
detectionOut
.
input
.
append
(
'FirstStageBoxPredictor/BoxEncodingPredictor/flatten'
)
detectionOut
.
input
.
append
(
'FirstStageBoxPredictor/ClassPredictor/softmax/flatten'
)
detectionOut
.
input
.
append
(
'proposals'
)
text_format
.
Merge
(
'i: 2'
,
detectionOut
.
attr
[
'num_classes'
])
text_format
.
Merge
(
'b: true'
,
detectionOut
.
attr
[
'share_location'
])
text_format
.
Merge
(
'i: 0'
,
detectionOut
.
attr
[
'background_label_id'
])
text_format
.
Merge
(
'f: 0.7'
,
detectionOut
.
attr
[
'nms_threshold'
])
text_format
.
Merge
(
'i: 6000'
,
detectionOut
.
attr
[
'top_k'
])
text_format
.
Merge
(
's: "CENTER_SIZE"'
,
detectionOut
.
attr
[
'code_type'
])
text_format
.
Merge
(
'i: 100'
,
detectionOut
.
attr
[
'keep_top_k'
])
text_format
.
Merge
(
'b: true'
,
detectionOut
.
attr
[
'clip'
])
graph_def
.
node
.
extend
([
detectionOut
])
# Save as text.
for
node
in
reversed
(
topNodes
):
if
node
.
op
!=
'CropAndResize'
:
graph_def
.
node
.
extend
([
node
])
topNodes
.
pop
()
else
:
if
numCropAndResize
==
1
:
break
else
:
graph_def
.
node
.
extend
([
node
])
topNodes
.
pop
()
numCropAndResize
-=
1
addSoftMax
(
'SecondStageBoxPredictor/Reshape_1'
,
'SecondStageBoxPredictor/Reshape_1/softmax'
,
graph_def
)
addSlice
(
'SecondStageBoxPredictor/Reshape_1/softmax'
,
'SecondStageBoxPredictor/Reshape_1/slice'
,
[
0
,
0
,
1
],
[
-
1
,
-
1
,
-
1
],
graph_def
)
addReshape
(
'SecondStageBoxPredictor/Reshape_1/slice'
,
'SecondStageBoxPredictor/Reshape_1/Reshape'
,
[
1
,
-
1
],
graph_def
)
# Replace Flatten subgraph onto a single node.
for
i
in
reversed
(
range
(
len
(
graph_def
.
node
))):
if
graph_def
.
node
[
i
].
op
==
'CropAndResize'
:
graph_def
.
node
[
i
].
input
.
insert
(
1
,
'detection_out'
)
if
graph_def
.
node
[
i
].
name
==
'SecondStageBoxPredictor/Reshape'
:
addConstNode
(
'SecondStageBoxPredictor/Reshape/shape2'
,
[
1
,
-
1
,
4
],
graph_def
)
graph_def
.
node
[
i
].
input
.
pop
()
graph_def
.
node
[
i
].
input
.
append
(
'SecondStageBoxPredictor/Reshape/shape2'
)
if
graph_def
.
node
[
i
].
name
in
[
'SecondStageBoxPredictor/Flatten/flatten/Shape'
,
'SecondStageBoxPredictor/Flatten/flatten/strided_slice'
,
'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape'
]:
del
graph_def
.
node
[
i
]
for
node
in
graph_def
.
node
:
if
node
.
name
==
'SecondStageBoxPredictor/Flatten/flatten/Reshape'
:
node
.
op
=
'Flatten'
node
.
input
.
pop
()
if
node
.
name
in
[
'FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D'
,
'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul'
]:
text_format
.
Merge
(
'b: true'
,
node
.
attr
[
"loc_pred_transposed"
])
################################################################################
### Postprocessing
################################################################################
addSlice
(
'detection_out'
,
'detection_out/slice'
,
[
0
,
0
,
0
,
3
],
[
-
1
,
-
1
,
-
1
,
4
],
graph_def
)
variance
=
NodeDef
()
variance
.
name
=
'proposals/variance'
variance
.
op
=
'Const'
text_format
.
Merge
(
tensorMsg
([
0.1
,
0.1
,
0.2
,
0.2
]),
variance
.
attr
[
"value"
])
graph_def
.
node
.
extend
([
variance
])
varianceEncoder
=
NodeDef
()
varianceEncoder
.
name
=
'variance_encoded'
varianceEncoder
.
op
=
'Mul'
varianceEncoder
.
input
.
append
(
'SecondStageBoxPredictor/Reshape'
)
varianceEncoder
.
input
.
append
(
variance
.
name
)
text_format
.
Merge
(
'i: 2'
,
varianceEncoder
.
attr
[
"axis"
])
graph_def
.
node
.
extend
([
varianceEncoder
])
addReshape
(
'detection_out/slice'
,
'detection_out/slice/reshape'
,
[
1
,
1
,
-
1
],
graph_def
)
addFlatten
(
'variance_encoded'
,
'variance_encoded/flatten'
,
graph_def
)
detectionOut
=
NodeDef
()
detectionOut
.
name
=
'detection_out_final'
detectionOut
.
op
=
'DetectionOutput'
detectionOut
.
input
.
append
(
'variance_encoded/flatten'
)
detectionOut
.
input
.
append
(
'SecondStageBoxPredictor/Reshape_1/Reshape'
)
detectionOut
.
input
.
append
(
'detection_out/slice/reshape'
)
text_format
.
Merge
(
'i: %d'
%
args
.
num_classes
,
detectionOut
.
attr
[
'num_classes'
])
text_format
.
Merge
(
'b: false'
,
detectionOut
.
attr
[
'share_location'
])
text_format
.
Merge
(
'i: %d'
%
(
args
.
num_classes
+
1
),
detectionOut
.
attr
[
'background_label_id'
])
text_format
.
Merge
(
'f: 0.6'
,
detectionOut
.
attr
[
'nms_threshold'
])
text_format
.
Merge
(
's: "CENTER_SIZE"'
,
detectionOut
.
attr
[
'code_type'
])
text_format
.
Merge
(
'i: 100'
,
detectionOut
.
attr
[
'keep_top_k'
])
text_format
.
Merge
(
'b: true'
,
detectionOut
.
attr
[
'clip'
])
text_format
.
Merge
(
'b: true'
,
detectionOut
.
attr
[
'variance_encoded_in_target'
])
text_format
.
Merge
(
'f: 0.3'
,
detectionOut
.
attr
[
'confidence_threshold'
])
text_format
.
Merge
(
'b: false'
,
detectionOut
.
attr
[
'group_by_classes'
])
graph_def
.
node
.
extend
([
detectionOut
])
for
node
in
reversed
(
topNodes
):
graph_def
.
node
.
extend
([
node
])
for
i
in
reversed
(
range
(
len
(
graph_def
.
node
))):
if
graph_def
.
node
[
i
].
op
==
'CropAndResize'
:
graph_def
.
node
[
i
].
input
.
insert
(
1
,
'detection_out_final'
)
break
graph_def
.
node
[
-
1
].
name
=
'detection_masks'
graph_def
.
node
[
-
1
].
op
=
'Sigmoid'
graph_def
.
node
[
-
1
].
input
.
pop
()
tf
.
train
.
write_graph
(
graph_def
,
""
,
args
.
output
,
as_text
=
True
)
samples/dnn/tf_text_graph_ssd.py
浏览文件 @
472b71ec
...
...
@@ -15,7 +15,7 @@ from math import sqrt
from
tensorflow.core.framework.node_def_pb2
import
NodeDef
from
tensorflow.tools.graph_transforms
import
TransformGraph
from
google.protobuf
import
text_format
from
tf_text_graph_common
import
tensorMsg
,
addConstNode
from
tf_text_graph_common
import
*
parser
=
argparse
.
ArgumentParser
(
description
=
'Run this script to get a text graph of '
'SSD model from TensorFlow Object Detection API. '
...
...
@@ -41,10 +41,6 @@ args = parser.parse_args()
keepOps
=
[
'Conv2D'
,
'BiasAdd'
,
'Add'
,
'Relu6'
,
'Placeholder'
,
'FusedBatchNorm'
,
'DepthwiseConv2dNative'
,
'ConcatV2'
,
'Mul'
,
'MaxPool'
,
'AvgPool'
,
'Identity'
]
# Nodes attributes that could be removed because they are not used during import.
unusedAttrs
=
[
'T'
,
'data_format'
,
'Tshape'
,
'N'
,
'Tidx'
,
'Tdim'
,
'use_cudnn_on_gpu'
,
'Index'
,
'Tperm'
,
'is_training'
,
'Tpaddings'
]
# Node with which prefixes should be removed
prefixesToRemove
=
(
'MultipleGridAnchorGenerator/'
,
'Postprocessor/'
,
'Preprocessor/'
)
...
...
@@ -66,7 +62,6 @@ def getUnconnectedNodes():
unconnected
.
remove
(
inp
)
return
unconnected
removedNodes
=
[]
# Detect unfused batch normalization nodes and fuse them.
def
fuse_batch_normalization
():
...
...
@@ -118,41 +113,13 @@ def fuse_batch_normalization():
fuse_batch_normalization
()
# Removes Identity nodes
def
removeIdentity
():
identities
=
{}
for
node
in
graph_def
.
node
:
if
node
.
op
==
'Identity'
:
identities
[
node
.
name
]
=
node
.
input
[
0
]
graph_def
.
node
.
remove
(
node
)
for
node
in
graph_def
.
node
:
for
i
in
range
(
len
(
node
.
input
)):
if
node
.
input
[
i
]
in
identities
:
node
.
input
[
i
]
=
identities
[
node
.
input
[
i
]]
removeIdentity
()
# Remove extra nodes and attributes.
for
i
in
reversed
(
range
(
len
(
graph_def
.
node
))):
op
=
graph_def
.
node
[
i
].
op
name
=
graph_def
.
node
[
i
].
name
removeIdentity
(
graph_def
)
if
(
not
op
in
keepOps
)
or
name
.
startswith
(
prefixesToRemove
):
if
op
!=
'Const'
:
removedNodes
.
append
(
name
)
def
to_remove
(
name
,
op
):
return
(
not
op
in
keepOps
)
or
name
.
startswith
(
prefixesToRemove
)
del
graph_def
.
node
[
i
]
else
:
for
attr
in
unusedAttrs
:
if
attr
in
graph_def
.
node
[
i
].
attr
:
del
graph_def
.
node
[
i
].
attr
[
attr
]
removeUnusedNodesAndAttrs
(
to_remove
,
graph_def
)
# Remove references to removed nodes except Const nodes.
for
node
in
graph_def
.
node
:
for
i
in
reversed
(
range
(
len
(
node
.
input
))):
if
node
.
input
[
i
]
in
removedNodes
:
del
node
.
input
[
i
]
# Connect input node to the first layer
assert
(
graph_def
.
node
[
0
].
op
==
'Placeholder'
)
...
...
@@ -175,8 +142,8 @@ def addConcatNode(name, inputs, axisNodeName):
concat
.
input
.
append
(
axisNodeName
)
graph_def
.
node
.
extend
([
concat
])
addConstNode
(
'concat/axis_flatten'
,
[
-
1
])
addConstNode
(
'PriorBox/concat/axis'
,
[
-
2
])
addConstNode
(
'concat/axis_flatten'
,
[
-
1
]
,
graph_def
)
addConstNode
(
'PriorBox/concat/axis'
,
[
-
2
]
,
graph_def
)
for
label
in
[
'ClassPredictor'
,
'BoxEncodingPredictor'
if
args
.
box_predictor
is
'convolutional'
else
'BoxPredictor'
]:
concatInputs
=
[]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录