Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
fbe1d120
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fbe1d120
编写于
9月 15, 2020
作者:
S
still-wait
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add solov2 model
上级
a8103906
变更
19
展开全部
隐藏空白更改
内联
并排
Showing
19 changed file
with
1612 addition
and
49 deletion
+1612
-49
configs/solov2/README.md
configs/solov2/README.md
+22
-0
configs/solov2/solov2_r50_fpn_1x.yml
configs/solov2/solov2_r50_fpn_1x.yml
+62
-0
configs/solov2/solov2_reader.yml
configs/solov2/solov2_reader.yml
+99
-0
deploy/python/infer.py
deploy/python/infer.py
+38
-8
deploy/python/visualize.py
deploy/python/visualize.py
+66
-16
ppdet/data/transform/batch_operators.py
ppdet/data/transform/batch_operators.py
+160
-0
ppdet/data/transform/operators.py
ppdet/data/transform/operators.py
+77
-14
ppdet/modeling/__init__.py
ppdet/modeling/__init__.py
+2
-0
ppdet/modeling/anchor_heads/__init__.py
ppdet/modeling/anchor_heads/__init__.py
+2
-0
ppdet/modeling/anchor_heads/solov2_head.py
ppdet/modeling/anchor_heads/solov2_head.py
+531
-0
ppdet/modeling/architectures/__init__.py
ppdet/modeling/architectures/__init__.py
+2
-0
ppdet/modeling/architectures/solov2.py
ppdet/modeling/architectures/solov2.py
+187
-0
ppdet/modeling/mask_head/__init__.py
ppdet/modeling/mask_head/__init__.py
+19
-0
ppdet/modeling/mask_head/solo_mask_head.py
ppdet/modeling/mask_head/solo_mask_head.py
+152
-0
ppdet/modeling/ops.py
ppdet/modeling/ops.py
+82
-5
ppdet/utils/coco_eval.py
ppdet/utils/coco_eval.py
+78
-0
ppdet/utils/eval_utils.py
ppdet/utils/eval_utils.py
+31
-2
tools/eval.py
tools/eval.py
+0
-4
tools/export_model.py
tools/export_model.py
+2
-0
未找到文件。
configs/solov2/README.md
0 → 100644
浏览文件 @
fbe1d120
# SOLOv2 (Segmenting Objects by Locations) for instance segmentation
## Introduction
-
SOLOv2 is a fast instance segmentation framework with strong performance:
[
https://arxiv.org/abs/2003.10152
](
https://arxiv.org/abs/2003.10152
)
```
@misc{wang2020solov2,
title={SOLOv2: Dynamic, Faster and Stronger},
author={Xinlong Wang and Rufeng Zhang and Tao Kong and Lei Li and Chunhua Shen},
year={2020},
eprint={2003.10152},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
## Model Zoo
| Backbone | Multi-scale training | Lr schd | Inf time (fps) | Mask AP | Download | Configs |
| :---------------------: | :-------------------: | :-----: | :------------: | :-----: | :---------: | :------------------------: |
| R50-FPN | False | 1x | - | 34.7 |
[
model
](
https://paddlemodels.bj.bcebos.com/object_detection/solov2_r50_fpn_1x.pdparams
)
|
[
config
](
https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/solov2/solov2_r50_fpn_1x.yml
)
|
configs/solov2/solov2_r50_fpn_1x.yml
0 → 100644
浏览文件 @
fbe1d120
architecture
:
SOLOv2
use_gpu
:
true
max_iters
:
90000
snapshot_iter
:
10000
log_smooth_window
:
20
save_dir
:
output
pretrain_weights
:
https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric
:
COCO
weights
:
output/solov2_r50_fpn_1x/model_final
num_classes
:
81
SOLOv2
:
backbone
:
ResNet
fpn
:
FPN
bbox_head
:
SOLOv2Head
mask_head
:
SOLOv2MaskHead
batch_size
:
2
ResNet
:
depth
:
50
feature_maps
:
[
2
,
3
,
4
,
5
]
freeze_at
:
2
norm_type
:
bn
FPN
:
max_level
:
6
min_level
:
2
num_chan
:
256
spatial_scale
:
[
0.03125
,
0.0625
,
0.125
,
0.25
]
reverse_out
:
True
SOLOv2Head
:
seg_feat_channels
:
512
stacked_convs
:
4
num_grids
:
[
40
,
36
,
24
,
16
,
12
]
kernel_out_channels
:
256
SOLOv2MaskHead
:
out_channels
:
128
start_level
:
0
end_level
:
3
num_classes
:
256
LearningRate
:
base_lr
:
0.01
schedulers
:
-
!PiecewiseDecay
gamma
:
0.1
milestones
:
[
60000
,
80000
]
-
!LinearWarmup
start_factor
:
0.
steps
:
1000
OptimizerBuilder
:
optimizer
:
momentum
:
0.9
type
:
Momentum
regularizer
:
factor
:
0.0001
type
:
L2
_READER_
:
'
solov2_reader.yml'
configs/solov2/solov2_reader.yml
0 → 100644
浏览文件 @
fbe1d120
TrainReader
:
batch_size
:
2
worker_num
:
2
inputs_def
:
fields
:
[
'
image'
,
'
im_id'
,
'
gt_segm'
]
dataset
:
!COCODataSet
dataset_dir
:
dataset/coco
anno_path
:
annotations/instances_train2017.json
image_dir
:
train2017
sample_transforms
:
-
!DecodeImage
to_rgb
:
true
-
!Poly2Mask
{}
-
!ResizeImage
target_size
:
800
max_size
:
1333
interp
:
1
use_cv2
:
true
resize_box
:
true
-
!RandomFlipImage
prob
:
0.5
-
!NormalizeImage
is_channel_first
:
false
is_scale
:
true
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
-
!Permute
to_bgr
:
false
channel_first
:
true
batch_transforms
:
-
!PadBatch
pad_to_stride
:
32
-
!Gt2Solov2Target
num_grids
:
[
40
,
36
,
24
,
16
,
12
]
scale_ranges
:
[[
1
,
96
],
[
48
,
192
],
[
96
,
384
],
[
192
,
768
],
[
384
,
2048
]]
coord_sigma
:
0.2
shuffle
:
True
EvalReader
:
inputs_def
:
fields
:
[
'
image'
,
'
im_info'
,
'
im_id'
]
dataset
:
!COCODataSet
image_dir
:
val2017
anno_path
:
annotations/instances_val2017.json
dataset_dir
:
dataset/coco
sample_transforms
:
-
!DecodeImage
to_rgb
:
true
-
!ResizeImage
interp
:
1
max_size
:
1333
target_size
:
800
use_cv2
:
true
-
!NormalizeImage
is_channel_first
:
false
is_scale
:
true
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
-
!Permute
channel_first
:
true
to_bgr
:
false
batch_transforms
:
-
!PadBatch
pad_to_stride
:
32
use_padded_im_info
:
false
batch_size
:
1
shuffle
:
false
drop_last
:
false
drop_empty
:
false
worker_num
:
2
TestReader
:
inputs_def
:
fields
:
[
'
image'
,
'
im_info'
,
'
im_id'
,
'
im_shape'
]
dataset
:
!ImageFolder
anno_path
:
dataset/coco/annotations/instances_val2017.json
sample_transforms
:
-
!DecodeImage
to_rgb
:
true
-
!ResizeImage
interp
:
1
max_size
:
1333
target_size
:
800
use_cv2
:
true
-
!NormalizeImage
is_channel_first
:
false
is_scale
:
true
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
-
!Permute
channel_first
:
true
to_bgr
:
false
batch_transforms
:
-
!PadBatch
pad_to_stride
:
32
use_padded_im_info
:
false
deploy/python/infer.py
浏览文件 @
fbe1d120
...
...
@@ -30,6 +30,7 @@ RESIZE_SCALE_SET = {
'RCNN'
,
'RetinaNet'
,
'FCOS'
,
'SOLOv2'
,
}
SUPPORT_MODELS
=
{
...
...
@@ -41,6 +42,7 @@ SUPPORT_MODELS = {
'Face'
,
'TTF'
,
'FCOS'
,
'SOLOv2'
,
}
...
...
@@ -85,7 +87,8 @@ class Resize(object):
max_size
,
use_cv2
=
True
,
image_shape
=
None
,
interp
=
cv2
.
INTER_LINEAR
):
interp
=
cv2
.
INTER_LINEAR
,
resize_box
=
False
):
self
.
target_size
=
target_size
self
.
max_size
=
max_size
self
.
image_shape
=
image_shape
...
...
@@ -251,7 +254,7 @@ class PadStride(object):
pad_w
=
int
(
np
.
ceil
(
float
(
im_w
)
/
coarsest_stride
)
*
coarsest_stride
)
padding_im
=
np
.
zeros
((
im_c
,
pad_h
,
pad_w
),
dtype
=
np
.
float32
)
padding_im
[:,
:
im_h
,
:
im_w
]
=
im
im_info
[
'
resize
_shape'
]
=
padding_im
.
shape
[
1
:]
im_info
[
'
pad
_shape'
]
=
padding_im
.
shape
[
1
:]
return
padding_im
,
im_info
...
...
@@ -268,23 +271,29 @@ def create_inputs(im, im_info, model_arch='YOLO'):
inputs
[
'image'
]
=
im
origin_shape
=
list
(
im_info
[
'origin_shape'
])
resize_shape
=
list
(
im_info
[
'resize_shape'
])
pad_shape
=
list
(
im_info
[
'pad_shape'
])
if
'pad_shape'
in
im_info
else
list
(
im_info
[
'resize_shape'
])
scale_x
,
scale_y
=
im_info
[
'scale'
]
if
'YOLO'
in
model_arch
:
im_size
=
np
.
array
([
origin_shape
]).
astype
(
'int32'
)
inputs
[
'im_size'
]
=
im_size
elif
'RetinaNet'
or
'EfficientDet'
in
model_arch
:
elif
'RetinaNet'
in
model_arch
or
'EfficientDet'
in
model_arch
:
scale
=
scale_x
im_info
=
np
.
array
([
resize
_shape
+
[
scale
]]).
astype
(
'float32'
)
im_info
=
np
.
array
([
pad
_shape
+
[
scale
]]).
astype
(
'float32'
)
inputs
[
'im_info'
]
=
im_info
elif
(
'RCNN'
in
model_arch
)
or
(
'FCOS'
in
model_arch
):
scale
=
scale_x
im_info
=
np
.
array
([
resize
_shape
+
[
scale
]]).
astype
(
'float32'
)
im_info
=
np
.
array
([
pad
_shape
+
[
scale
]]).
astype
(
'float32'
)
im_shape
=
np
.
array
([
origin_shape
+
[
1.
]]).
astype
(
'float32'
)
inputs
[
'im_info'
]
=
im_info
inputs
[
'im_shape'
]
=
im_shape
elif
'TTF'
in
model_arch
:
scale_factor
=
np
.
array
([
scale_x
,
scale_y
]
*
2
).
astype
(
'float32'
)
inputs
[
'scale_factor'
]
=
scale_factor
elif
'SOLOv2'
in
model_arch
:
scale
=
scale_x
im_info
=
np
.
array
([
resize_shape
+
[
scale
]]).
astype
(
'float32'
)
inputs
[
'im_info'
]
=
im_info
return
inputs
...
...
@@ -405,10 +414,15 @@ def visualize(image_file,
results
,
labels
,
mask_resolution
=
14
,
output_dir
=
'output/'
):
output_dir
=
'output/'
,
threshold
=
0.5
):
# visualize the predict result
im
=
visualize_box_mask
(
image_file
,
results
,
labels
,
mask_resolution
=
mask_resolution
)
image_file
,
results
,
labels
,
mask_resolution
=
mask_resolution
,
threshold
=
threshold
)
img_name
=
os
.
path
.
split
(
image_file
)[
-
1
]
if
not
os
.
path
.
exists
(
output_dir
):
os
.
makedirs
(
output_dir
)
...
...
@@ -516,6 +530,11 @@ class Detector():
ms
=
(
t2
-
t1
)
*
1000.0
/
repeats
print
(
"Inference: {} ms per batch image"
.
format
(
ms
))
if
self
.
config
.
arch
==
'SOLOv2'
:
return
dict
(
segm
=
np
.
array
(
outs
[
2
]),
label
=
np
.
array
(
outs
[
0
]),
score
=
np
.
array
(
outs
[
1
]))
np_boxes
=
np
.
array
(
outs
[
0
])
if
self
.
config
.
mask_resolution
is
not
None
:
np_masks
=
np
.
array
(
outs
[
1
])
...
...
@@ -539,6 +558,13 @@ class Detector():
for
i
in
range
(
repeats
):
self
.
predictor
.
zero_copy_run
()
output_names
=
self
.
predictor
.
get_output_names
()
if
self
.
config
.
arch
==
'SOLOv2'
:
np_label
=
self
.
predictor
.
get_output_tensor
(
output_names
[
0
]).
copy_to_cpu
()
np_score
=
self
.
predictor
.
get_output_tensor
(
output_names
[
1
]).
copy_to_cpu
()
np_segms
=
self
.
predictor
.
get_output_tensor
(
output_names
[
2
]).
copy_to_cpu
()
boxes_tensor
=
self
.
predictor
.
get_output_tensor
(
output_names
[
0
])
np_boxes
=
boxes_tensor
.
copy_to_cpu
()
if
self
.
config
.
mask_resolution
is
not
None
:
...
...
@@ -552,6 +578,9 @@ class Detector():
# do not perform postprocess in benchmark mode
results
=
[]
if
not
run_benchmark
:
if
self
.
config
.
arch
==
'SOLOv2'
:
return
dict
(
segm
=
np_segms
,
label
=
np_label
,
score
=
np_score
)
if
reduce
(
lambda
x
,
y
:
x
*
y
,
np_boxes
.
shape
)
<
6
:
print
(
'[WARNNING] No object detected.'
)
results
=
{
'boxes'
:
np
.
array
([])}
...
...
@@ -579,7 +608,8 @@ def predict_image():
results
,
detector
.
config
.
labels
,
mask_resolution
=
detector
.
config
.
mask_resolution
,
output_dir
=
FLAGS
.
output_dir
)
output_dir
=
FLAGS
.
output_dir
,
threshold
=
FLAGS
.
threshold
)
def
predict_video
(
camera_id
):
...
...
deploy/python/visualize.py
浏览文件 @
fbe1d120
...
...
@@ -18,20 +18,22 @@ from __future__ import division
import
cv2
import
numpy
as
np
from
PIL
import
Image
,
ImageDraw
from
scipy
import
ndimage
def
visualize_box_mask
(
im
,
results
,
labels
,
mask_resolution
=
14
):
"""
def
visualize_box_mask
(
im
,
results
,
labels
,
mask_resolution
=
14
,
threshold
=
0.5
):
"""
Args:
im (str/np.ndarray): path of image/np.ndarray read by cv2
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box
,
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box
,
matix element:[class, score, x_min, y_min, x_max, y_max]
MaskRCNN's results include 'masks': np.ndarray:
shape:[N, class_num, mask_resolution, mask_resolution]
MaskRCNN's results include 'masks': np.ndarray:
shape:[N, class_num, mask_resolution, mask_resolution]
labels (list): labels:['class1', ..., 'classn']
mask_resolution (int): shape of a mask is:[mask_resolution, mask_resolution]
threshold (float): Threshold of score.
Returns:
im (PIL.Image.Image): visualized image
im (PIL.Image.Image): visualized image
"""
if
isinstance
(
im
,
str
):
im
=
Image
.
open
(
im
).
convert
(
'RGB'
)
...
...
@@ -46,15 +48,23 @@ def visualize_box_mask(im, results, labels, mask_resolution=14):
resolution
=
mask_resolution
)
if
'boxes'
in
results
:
im
=
draw_box
(
im
,
results
[
'boxes'
],
labels
)
if
'segm'
in
results
:
im
=
draw_segm
(
im
,
results
[
'segm'
],
results
[
'label'
],
results
[
'score'
],
labels
,
threshold
=
threshold
)
return
im
def
get_color_map_list
(
num_classes
):
"""
"""
Args:
num_classes (int): number of class
Returns:
color_map (list): RGB color list
color_map (list): RGB color list
"""
color_map
=
num_classes
*
[
0
,
0
,
0
]
for
i
in
range
(
0
,
num_classes
):
...
...
@@ -71,9 +81,9 @@ def get_color_map_list(num_classes):
def
expand_boxes
(
boxes
,
scale
=
0.0
):
"""
"""
Args:
boxes (np.ndarray): shape:[N,4], N:number of box
,
boxes (np.ndarray): shape:[N,4], N:number of box
,
matix element:[x_min, y_min, x_max, y_max]
scale (float): scale of boxes
Returns:
...
...
@@ -94,17 +104,17 @@ def expand_boxes(boxes, scale=0.0):
def
draw_mask
(
im
,
np_boxes
,
np_masks
,
labels
,
resolution
=
14
,
threshold
=
0.5
):
"""
"""
Args:
im (PIL.Image.Image): PIL image
np_boxes (np.ndarray): shape:[N,6], N: number of box
,
np_boxes (np.ndarray): shape:[N,6], N: number of box
,
matix element:[class, score, x_min, y_min, x_max, y_max]
np_masks (np.ndarray): shape:[N, class_num, resolution, resolution]
labels (list): labels:['class1', ..., 'classn']
resolution (int): shape of a mask is:[resolution, resolution]
threshold (float): threshold of mask
Returns:
im (PIL.Image.Image): visualized image
im (PIL.Image.Image): visualized image
"""
color_list
=
get_color_map_list
(
len
(
labels
))
scale
=
(
resolution
+
2.0
)
/
resolution
...
...
@@ -149,14 +159,14 @@ def draw_mask(im, np_boxes, np_masks, labels, resolution=14, threshold=0.5):
def
draw_box
(
im
,
np_boxes
,
labels
):
"""
"""
Args:
im (PIL.Image.Image): PIL image
np_boxes (np.ndarray): shape:[N,6], N: number of box
,
np_boxes (np.ndarray): shape:[N,6], N: number of box
,
matix element:[class, score, x_min, y_min, x_max, y_max]
labels (list): labels:['class1', ..., 'classn']
Returns:
im (PIL.Image.Image): visualized image
im (PIL.Image.Image): visualized image
"""
draw_thickness
=
min
(
im
.
size
)
//
320
draw
=
ImageDraw
.
Draw
(
im
)
...
...
@@ -186,3 +196,43 @@ def draw_box(im, np_boxes, labels):
[(
xmin
+
1
,
ymin
-
th
),
(
xmin
+
tw
+
1
,
ymin
)],
fill
=
color
)
draw
.
text
((
xmin
+
1
,
ymin
-
th
),
text
,
fill
=
(
255
,
255
,
255
))
return
im
def
draw_segm
(
im
,
np_segms
,
np_label
,
np_score
,
labels
,
threshold
=
0.5
,
alpha
=
0.7
):
"""
Draw segmentation on image
"""
mask_color_id
=
0
w_ratio
=
.
4
color_list
=
get_color_map_list
(
len
(
labels
))
im
=
np
.
array
(
im
).
astype
(
'float32'
)
clsid2color
=
{}
np_segms
=
np_segms
.
astype
(
np
.
uint8
)
for
i
in
range
(
np_segms
.
shape
[
0
]):
mask
,
score
,
clsid
=
np_segms
[
i
],
np_score
[
i
],
np_label
[
i
]
+
1
if
score
<
threshold
:
continue
if
clsid
not
in
clsid2color
:
clsid2color
[
clsid
]
=
color_list
[
clsid
]
color_mask
=
clsid2color
[
clsid
]
for
c
in
range
(
3
):
color_mask
[
c
]
=
color_mask
[
c
]
*
(
1
-
w_ratio
)
+
w_ratio
*
255
idx
=
np
.
nonzero
(
mask
)
color_mask
=
np
.
array
(
color_mask
)
im
[
idx
[
0
],
idx
[
1
],
:]
*=
1.0
-
alpha
im
[
idx
[
0
],
idx
[
1
],
:]
+=
alpha
*
color_mask
center_y
,
center_x
=
ndimage
.
measurements
.
center_of_mass
(
mask
)
label_text
=
"{}"
.
format
(
labels
[
clsid
])
print
(
label_text
)
print
(
center_y
,
center_x
)
vis_pos
=
(
max
(
int
(
center_x
)
-
10
,
0
),
int
(
center_y
))
cv2
.
putText
(
im
,
label_text
,
vis_pos
,
cv2
.
FONT_HERSHEY_COMPLEX
,
0.3
,
(
255
,
255
,
255
))
return
Image
.
fromarray
(
im
.
astype
(
'uint8'
))
ppdet/data/transform/batch_operators.py
浏览文件 @
fbe1d120
...
...
@@ -24,6 +24,7 @@ except Exception:
import
logging
import
cv2
import
numpy
as
np
from
scipy
import
ndimage
from
.operators
import
register_op
,
BaseOperator
from
.op_helper
import
jaccard_overlap
,
gaussian2D
...
...
@@ -37,6 +38,7 @@ __all__ = [
'Gt2YoloTarget'
,
'Gt2FCOSTarget'
,
'Gt2TTFTarget'
,
'Gt2Solov2Target'
,
]
...
...
@@ -88,6 +90,13 @@ class PadBatch(BaseOperator):
(
1
,
max_shape
[
1
],
max_shape
[
2
]),
dtype
=
np
.
float32
)
padding_sem
[:,
:
im_h
,
:
im_w
]
=
semantic
data
[
'semantic'
]
=
padding_sem
if
'gt_segm'
in
data
.
keys
()
and
data
[
'gt_segm'
]
is
not
None
:
gt_segm
=
data
[
'gt_segm'
]
padding_segm
=
np
.
zeros
(
(
gt_segm
.
shape
[
0
],
max_shape
[
1
],
max_shape
[
2
]),
dtype
=
np
.
uint8
)
padding_segm
[:,
:
im_h
,
:
im_w
]
=
gt_segm
data
[
'gt_segm'
]
=
padding_segm
return
samples
...
...
@@ -590,3 +599,154 @@ class Gt2TTFTarget(BaseOperator):
heatmap
[
y
-
top
:
y
+
bottom
,
x
-
left
:
x
+
right
]
=
np
.
maximum
(
masked_heatmap
,
masked_gaussian
)
return
heatmap
@
register_op
class
Gt2Solov2Target
(
BaseOperator
):
"""Assign mask target and labels in SOLOv2 network.
Args:
num_grids (list): The list of feature map grids size.
scale_ranges (list): The list of mask boundary range.
coord_sigma (float): The coefficient of coordinate area length.
sampling_ratio (float): The ratio of down sampling.
"""
def
__init__
(
self
,
num_grids
=
[
40
,
36
,
24
,
16
,
12
],
scale_ranges
=
[[
1
,
96
],
[
48
,
192
],
[
96
,
384
],
[
192
,
768
],
[
384
,
2048
]],
coord_sigma
=
0.2
,
sampling_ratio
=
4.0
):
super
(
Gt2Solov2Target
,
self
).
__init__
()
self
.
num_grids
=
num_grids
self
.
scale_ranges
=
scale_ranges
self
.
coord_sigma
=
coord_sigma
self
.
sampling_ratio
=
sampling_ratio
def
_scale_size
(
self
,
im
,
scale
):
h
,
w
=
im
.
shape
[:
2
]
new_size
=
(
int
(
w
*
float
(
scale
)
+
0.5
),
int
(
h
*
float
(
scale
)
+
0.5
))
resized_img
=
cv2
.
resize
(
im
,
None
,
None
,
fx
=
scale
,
fy
=
scale
,
interpolation
=
cv2
.
INTER_LINEAR
)
return
resized_img
def
__call__
(
self
,
samples
,
context
=
None
):
for
sample
in
samples
:
gt_bboxes_raw
=
sample
[
'gt_bbox'
]
gt_labels_raw
=
sample
[
'gt_class'
]
im_c
,
im_h
,
im_w
=
sample
[
'image'
].
shape
[:]
gt_masks_raw
=
sample
[
'gt_segm'
].
astype
(
np
.
uint8
)
mask_feat_size
=
[
int
(
im_h
/
self
.
sampling_ratio
),
int
(
im_w
/
self
.
sampling_ratio
)
]
gt_areas
=
np
.
sqrt
((
gt_bboxes_raw
[:,
2
]
-
gt_bboxes_raw
[:,
0
])
*
(
gt_bboxes_raw
[:,
3
]
-
gt_bboxes_raw
[:,
1
]))
ins_ind_label_list
=
[]
grid_offset
=
[]
idx
=
0
for
(
lower_bound
,
upper_bound
),
num_grid
\
in
zip
(
self
.
scale_ranges
,
self
.
num_grids
):
hit_indices
=
((
gt_areas
>=
lower_bound
)
&
(
gt_areas
<=
upper_bound
)).
nonzero
()[
0
]
num_ins
=
len
(
hit_indices
)
ins_label
=
[]
grid_order
=
[]
cate_label
=
np
.
zeros
([
num_grid
,
num_grid
],
dtype
=
np
.
int64
)
ins_ind_label
=
np
.
zeros
([
num_grid
**
2
],
dtype
=
np
.
bool
)
if
num_ins
==
0
:
ins_label
=
np
.
zeros
(
[
1
,
mask_feat_size
[
0
],
mask_feat_size
[
1
]],
dtype
=
np
.
uint8
)
ins_ind_label_list
.
append
(
ins_ind_label
)
sample
[
'cate_label{}'
.
format
(
idx
)]
=
cate_label
.
flatten
()
sample
[
'ins_label{}'
.
format
(
idx
)]
=
ins_label
sample
[
'grid_order{}'
.
format
(
idx
)]
=
np
.
asarray
([
0
])
grid_offset
.
append
(
1
)
idx
+=
1
continue
gt_bboxes
=
gt_bboxes_raw
[
hit_indices
]
gt_labels
=
gt_labels_raw
[
hit_indices
]
gt_masks
=
gt_masks_raw
[
hit_indices
,
...]
half_ws
=
0.5
*
(
gt_bboxes
[:,
2
]
-
gt_bboxes
[:,
0
])
*
self
.
coord_sigma
half_hs
=
0.5
*
(
gt_bboxes
[:,
3
]
-
gt_bboxes
[:,
1
])
*
self
.
coord_sigma
for
seg_mask
,
gt_label
,
half_h
,
half_w
in
zip
(
gt_masks
,
gt_labels
,
half_hs
,
half_ws
):
if
seg_mask
.
sum
()
==
0
:
continue
# mass center
upsampled_size
=
(
mask_feat_size
[
0
]
*
4
,
mask_feat_size
[
1
]
*
4
)
center_h
,
center_w
=
ndimage
.
measurements
.
center_of_mass
(
seg_mask
)
coord_w
=
int
(
(
center_w
/
upsampled_size
[
1
])
//
(
1.
/
num_grid
))
coord_h
=
int
(
(
center_h
/
upsampled_size
[
0
])
//
(
1.
/
num_grid
))
# left, top, right, down
top_box
=
max
(
0
,
int
(((
center_h
-
half_h
)
/
upsampled_size
[
0
])
//
(
1.
/
num_grid
)))
down_box
=
min
(
num_grid
-
1
,
int
(((
center_h
+
half_h
)
/
upsampled_size
[
0
])
//
(
1.
/
num_grid
)))
left_box
=
max
(
0
,
int
(((
center_w
-
half_w
)
/
upsampled_size
[
1
])
//
(
1.
/
num_grid
)))
right_box
=
min
(
num_grid
-
1
,
int
(((
center_w
+
half_w
)
/
upsampled_size
[
1
])
//
(
1.
/
num_grid
)))
top
=
max
(
top_box
,
coord_h
-
1
)
down
=
min
(
down_box
,
coord_h
+
1
)
left
=
max
(
coord_w
-
1
,
left_box
)
right
=
min
(
right_box
,
coord_w
+
1
)
cate_label
[
top
:(
down
+
1
),
left
:(
right
+
1
)]
=
gt_label
seg_mask
=
self
.
_scale_size
(
seg_mask
,
scale
=
1.
/
self
.
sampling_ratio
)
for
i
in
range
(
top
,
down
+
1
):
for
j
in
range
(
left
,
right
+
1
):
label
=
int
(
i
*
num_grid
+
j
)
cur_ins_label
=
np
.
zeros
(
[
mask_feat_size
[
0
],
mask_feat_size
[
1
]],
dtype
=
np
.
uint8
)
cur_ins_label
[:
seg_mask
.
shape
[
0
],
:
seg_mask
.
shape
[
1
]]
=
seg_mask
ins_label
.
append
(
cur_ins_label
)
ins_ind_label
[
label
]
=
True
grid_order
.
append
(
label
)
if
ins_label
==
[]:
ins_label
=
np
.
zeros
(
[
1
,
mask_feat_size
[
0
],
mask_feat_size
[
1
]],
dtype
=
np
.
uint8
)
ins_ind_label_list
.
append
(
ins_ind_label
)
sample
[
'cate_label{}'
.
format
(
idx
)]
=
cate_label
.
flatten
()
sample
[
'ins_label{}'
.
format
(
idx
)]
=
ins_label
sample
[
'grid_order{}'
.
format
(
idx
)]
=
np
.
asarray
([
0
])
grid_offset
.
append
(
1
)
else
:
ins_label
=
np
.
stack
(
ins_label
,
axis
=
0
)
ins_ind_label_list
.
append
(
ins_ind_label
)
sample
[
'cate_label{}'
.
format
(
idx
)]
=
cate_label
.
flatten
()
sample
[
'ins_label{}'
.
format
(
idx
)]
=
ins_label
sample
[
'grid_order{}'
.
format
(
idx
)]
=
np
.
asarray
(
grid_order
)
assert
len
(
grid_order
)
>
0
grid_offset
.
append
(
len
(
grid_order
))
idx
+=
1
ins_ind_labels
=
np
.
concatenate
([
ins_ind_labels_level_img
for
ins_ind_labels_level_img
in
ins_ind_label_list
])
fg_num
=
np
.
sum
(
ins_ind_labels
)
sample
[
'fg_num'
]
=
fg_num
sample
[
'grid_offset'
]
=
np
.
asarray
(
grid_offset
).
astype
(
np
.
int32
)
return
samples
ppdet/data/transform/operators.py
浏览文件 @
fbe1d120
...
...
@@ -272,7 +272,8 @@ class ResizeImage(BaseOperator):
target_size
=
0
,
max_size
=
0
,
interp
=
cv2
.
INTER_LINEAR
,
use_cv2
=
True
):
use_cv2
=
True
,
resize_box
=
False
):
"""
Rescale image to the specified target size, and capped at max_size
if max_size != 0.
...
...
@@ -285,11 +286,13 @@ class ResizeImage(BaseOperator):
interp (int): the interpolation method
use_cv2 (bool): use the cv2 interpolation method or use PIL
interpolation method
resize_box (bool): whether resize ground truth bbox annotations.
"""
super
(
ResizeImage
,
self
).
__init__
()
self
.
max_size
=
int
(
max_size
)
self
.
interp
=
int
(
interp
)
self
.
use_cv2
=
use_cv2
self
.
resize_box
=
resize_box
if
not
(
isinstance
(
target_size
,
int
)
or
isinstance
(
target_size
,
list
)):
raise
TypeError
(
"Type of target_size is invalid. Must be Integer or List, now is {}"
.
...
...
@@ -348,18 +351,6 @@ class ResizeImage(BaseOperator):
fx
=
im_scale_x
,
fy
=
im_scale_y
,
interpolation
=
self
.
interp
)
if
'semantic'
in
sample
.
keys
()
and
sample
[
'semantic'
]
is
not
None
:
semantic
=
sample
[
'semantic'
]
semantic
=
cv2
.
resize
(
semantic
.
astype
(
'float32'
),
None
,
None
,
fx
=
im_scale_x
,
fy
=
im_scale_y
,
interpolation
=
self
.
interp
)
semantic
=
np
.
asarray
(
semantic
).
astype
(
'int32'
)
semantic
=
np
.
expand_dims
(
semantic
,
0
)
sample
[
'semantic'
]
=
semantic
else
:
if
self
.
max_size
!=
0
:
raise
TypeError
(
...
...
@@ -370,6 +361,38 @@ class ResizeImage(BaseOperator):
im
=
im
.
resize
((
int
(
resize_w
),
int
(
resize_h
)),
self
.
interp
)
im
=
np
.
array
(
im
)
sample
[
'image'
]
=
im
sample
[
'scale_factor'
]
=
[
im_scale_x
,
im_scale_y
]
*
2
if
'gt_bbox'
in
sample
and
self
.
resize_box
and
len
(
sample
[
'gt_bbox'
])
>
0
:
bboxes
=
sample
[
'gt_bbox'
]
*
sample
[
'scale_factor'
]
bboxes
[:,
0
::
2
]
=
np
.
clip
(
bboxes
[:,
0
::
2
],
0
,
resize_w
-
1
)
bboxes
[:,
1
::
2
]
=
np
.
clip
(
bboxes
[:,
1
::
2
],
0
,
resize_h
-
1
)
sample
[
'gt_bbox'
]
=
bboxes
if
'semantic'
in
sample
.
keys
()
and
sample
[
'semantic'
]
is
not
None
:
semantic
=
sample
[
'semantic'
]
semantic
=
cv2
.
resize
(
semantic
.
astype
(
'float32'
),
None
,
None
,
fx
=
im_scale_x
,
fy
=
im_scale_y
,
interpolation
=
self
.
interp
)
semantic
=
np
.
asarray
(
semantic
).
astype
(
'int32'
)
semantic
=
np
.
expand_dims
(
semantic
,
0
)
sample
[
'semantic'
]
=
semantic
if
'gt_segm'
in
sample
and
len
(
sample
[
'gt_segm'
])
>
0
:
masks
=
[
cv2
.
resize
(
gt_segm
,
None
,
None
,
fx
=
im_scale_x
,
fy
=
im_scale_y
,
interpolation
=
cv2
.
INTER_NEAREST
)
for
gt_segm
in
sample
[
'gt_segm'
]
]
sample
[
'gt_segm'
]
=
np
.
asarray
(
masks
).
astype
(
np
.
uint8
)
return
sample
...
...
@@ -473,7 +496,6 @@ class RandomFlipImage(BaseOperator):
if
self
.
is_mask_flip
and
len
(
sample
[
'gt_poly'
])
!=
0
:
sample
[
'gt_poly'
]
=
self
.
flip_segms
(
sample
[
'gt_poly'
],
height
,
width
)
if
'gt_keypoint'
in
sample
.
keys
():
sample
[
'gt_keypoint'
]
=
self
.
flip_keypoint
(
sample
[
'gt_keypoint'
],
width
)
...
...
@@ -482,6 +504,9 @@ class RandomFlipImage(BaseOperator):
'semantic'
]
is
not
None
:
sample
[
'semantic'
]
=
sample
[
'semantic'
][:,
::
-
1
]
if
'gt_segm'
in
sample
.
keys
()
and
sample
[
'gt_segm'
]
is
not
None
:
sample
[
'gt_segm'
]
=
sample
[
'gt_segm'
][:,
:,
::
-
1
]
sample
[
'flipped'
]
=
True
sample
[
'image'
]
=
im
sample
=
samples
if
batch_input
else
samples
[
0
]
...
...
@@ -2557,3 +2582,41 @@ class DebugVisibleImage(BaseOperator):
save_path
=
os
.
path
.
join
(
self
.
output_dir
,
out_file_name
)
image
.
save
(
save_path
,
quality
=
95
)
return
sample
@
register_op
class
Poly2Mask
(
BaseOperator
):
"""
gt poly to mask annotations
"""
def
__init__
(
self
):
super
(
Poly2Mask
,
self
).
__init__
()
import
pycocotools.mask
as
maskUtils
self
.
maskutils
=
maskUtils
def
_poly2mask
(
self
,
mask_ann
,
img_h
,
img_w
):
if
isinstance
(
mask_ann
,
list
):
# polygon -- a single object might consist of multiple parts
# we merge all parts into one mask rle code
rles
=
self
.
maskutils
.
frPyObjects
(
mask_ann
,
img_h
,
img_w
)
rle
=
self
.
maskutils
.
merge
(
rles
)
elif
isinstance
(
mask_ann
[
'counts'
],
list
):
# uncompressed RLE
rle
=
self
.
maskutils
.
frPyObjects
(
mask_ann
,
img_h
,
img_w
)
else
:
# rle
rle
=
mask_ann
mask
=
self
.
maskutils
.
decode
(
rle
)
return
mask
def
__call__
(
self
,
sample
,
context
=
None
):
assert
'gt_poly'
in
sample
im_h
=
sample
[
'h'
]
im_w
=
sample
[
'w'
]
masks
=
[
self
.
_poly2mask
(
gt_poly
,
im_h
,
im_w
)
for
gt_poly
in
sample
[
'gt_poly'
]
]
sample
[
'gt_segm'
]
=
np
.
asarray
(
masks
).
astype
(
np
.
uint8
)
return
sample
ppdet/modeling/__init__.py
浏览文件 @
fbe1d120
...
...
@@ -22,6 +22,7 @@ from . import roi_extractors
from
.
import
roi_heads
from
.
import
ops
from
.
import
target_assigners
from
.
import
mask_head
from
.anchor_heads
import
*
from
.architectures
import
*
...
...
@@ -30,3 +31,4 @@ from .roi_extractors import *
from
.roi_heads
import
*
from
.ops
import
*
from
.target_assigners
import
*
from
.mask_head
import
*
ppdet/modeling/anchor_heads/__init__.py
浏览文件 @
fbe1d120
...
...
@@ -21,6 +21,7 @@ from . import fcos_head
from
.
import
corner_head
from
.
import
efficient_head
from
.
import
ttf_head
from
.
import
solov2_head
from
.rpn_head
import
*
from
.yolo_head
import
*
...
...
@@ -29,3 +30,4 @@ from .fcos_head import *
from
.corner_head
import
*
from
.efficient_head
import
*
from
.ttf_head
import
*
from
.solov2_head
import
*
ppdet/modeling/anchor_heads/solov2_head.py
0 → 100644
浏览文件 @
fbe1d120
此差异已折叠。
点击以展开。
ppdet/modeling/architectures/__init__.py
浏览文件 @
fbe1d120
...
...
@@ -29,6 +29,7 @@ from . import fcos
from
.
import
cornernet_squeeze
from
.
import
ttfnet
from
.
import
htc
from
.
import
solov2
from
.faster_rcnn
import
*
from
.mask_rcnn
import
*
...
...
@@ -45,3 +46,4 @@ from .fcos import *
from
.cornernet_squeeze
import
*
from
.ttfnet
import
*
from
.htc
import
*
from
.solov2
import
*
ppdet/modeling/architectures/solov2.py
0 → 100644
浏览文件 @
fbe1d120
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
collections
import
OrderedDict
from
paddle
import
fluid
from
ppdet.experimental
import
mixed_precision_global_state
from
ppdet.core.workspace
import
register
__all__
=
[
'SOLOv2'
]
@
register
class
SOLOv2
(
object
):
"""
SOLOv2 network, see https://arxiv.org/abs/2003.10152
Args:
backbone (object): an backbone instance
fpn (object): feature pyramid network instance
bbox_head (object): an `SOLOv2Head` instance
mask_head (object): an `SOLOv2MaskHead` instance
batch_size (int): batch size.
"""
__category__
=
'architecture'
__inject__
=
[
'backbone'
,
'fpn'
,
'bbox_head'
,
'mask_head'
]
def
__init__
(
self
,
backbone
,
fpn
=
None
,
bbox_head
=
'SOLOv2Head'
,
mask_head
=
'SOLOv2MaskHead'
,
batch_size
=
1
):
super
(
SOLOv2
,
self
).
__init__
()
self
.
backbone
=
backbone
self
.
fpn
=
fpn
self
.
bbox_head
=
bbox_head
self
.
mask_head
=
mask_head
self
.
batch_size
=
batch_size
def
build
(
self
,
feed_vars
,
mode
=
'train'
):
im
=
feed_vars
[
'image'
]
mixed_precision_enabled
=
mixed_precision_global_state
()
is
not
None
# cast inputs to FP16
if
mixed_precision_enabled
:
im
=
fluid
.
layers
.
cast
(
im
,
'float16'
)
body_feats
=
self
.
backbone
(
im
)
if
self
.
fpn
is
not
None
:
body_feats
,
spatial_scale
=
self
.
fpn
.
get_output
(
body_feats
)
if
isinstance
(
body_feats
,
OrderedDict
):
body_feat_names
=
list
(
body_feats
.
keys
())
body_feats
=
[
body_feats
[
name
]
for
name
in
body_feat_names
]
# cast features back to FP32
if
mixed_precision_enabled
:
body_feats
=
[
fluid
.
layers
.
cast
(
v
,
'float32'
)
for
v
in
body_feats
]
if
not
mode
==
'train'
:
self
.
batch_size
=
1
mask_feat_pred
=
self
.
mask_head
.
get_output
(
body_feats
,
self
.
batch_size
)
if
mode
==
'train'
:
ins_labels
=
[]
cate_labels
=
[]
grid_orders
=
[]
fg_num
=
feed_vars
[
'fg_num'
]
grid_offset
=
feed_vars
[
'grid_offset'
]
for
i
in
range
(
5
):
ins_label
=
'ins_label{}'
.
format
(
i
)
if
ins_label
in
feed_vars
:
ins_labels
.
append
(
feed_vars
[
ins_label
])
cate_label
=
'cate_label{}'
.
format
(
i
)
if
cate_label
in
feed_vars
:
cate_labels
.
append
(
feed_vars
[
cate_label
])
grid_order
=
'grid_order{}'
.
format
(
i
)
if
grid_order
in
feed_vars
:
grid_orders
.
append
(
feed_vars
[
grid_order
])
cate_preds
,
kernel_preds
=
self
.
bbox_head
.
get_outputs
(
body_feats
,
batch_size
=
self
.
batch_size
)
losses
=
self
.
bbox_head
.
get_loss
(
cate_preds
,
kernel_preds
,
mask_feat_pred
,
ins_labels
,
cate_labels
,
grid_orders
,
fg_num
,
grid_offset
,
self
.
batch_size
)
total_loss
=
fluid
.
layers
.
sum
(
list
(
losses
.
values
()))
losses
.
update
({
'loss'
:
total_loss
})
return
losses
else
:
im_info
=
feed_vars
[
'im_info'
]
outs
=
self
.
bbox_head
.
get_outputs
(
body_feats
,
is_eval
=
True
,
batch_size
=
self
.
batch_size
)
seg_inputs
=
outs
+
(
mask_feat_pred
,
im_info
)
return
self
.
bbox_head
.
get_prediction
(
*
seg_inputs
)
def
_inputs_def
(
self
,
image_shape
,
fields
):
im_shape
=
[
None
]
+
image_shape
# yapf: disable
inputs_def
=
{
'image'
:
{
'shape'
:
im_shape
,
'dtype'
:
'float32'
,
'lod_level'
:
0
},
'im_info'
:
{
'shape'
:
[
None
,
3
],
'dtype'
:
'float32'
,
'lod_level'
:
0
},
'im_id'
:
{
'shape'
:
[
None
,
1
],
'dtype'
:
'int64'
,
'lod_level'
:
0
},
'im_shape'
:
{
'shape'
:
[
None
,
3
],
'dtype'
:
'float32'
,
'lod_level'
:
0
},
}
if
'gt_segm'
in
fields
:
targets_def
=
{
'ins_label0'
:
{
'shape'
:
[
None
,
None
,
None
],
'dtype'
:
'int32'
,
'lod_level'
:
1
},
'ins_label1'
:
{
'shape'
:
[
None
,
None
,
None
],
'dtype'
:
'int32'
,
'lod_level'
:
1
},
'ins_label2'
:
{
'shape'
:
[
None
,
None
,
None
],
'dtype'
:
'int32'
,
'lod_level'
:
1
},
'ins_label3'
:
{
'shape'
:
[
None
,
None
,
None
],
'dtype'
:
'int32'
,
'lod_level'
:
1
},
'ins_label4'
:
{
'shape'
:
[
None
,
None
,
None
],
'dtype'
:
'int32'
,
'lod_level'
:
1
},
'cate_label0'
:
{
'shape'
:
[
None
],
'dtype'
:
'int32'
,
'lod_level'
:
1
},
'cate_label1'
:
{
'shape'
:
[
None
],
'dtype'
:
'int32'
,
'lod_level'
:
1
},
'cate_label2'
:
{
'shape'
:
[
None
],
'dtype'
:
'int32'
,
'lod_level'
:
1
},
'cate_label3'
:
{
'shape'
:
[
None
],
'dtype'
:
'int32'
,
'lod_level'
:
1
},
'cate_label4'
:
{
'shape'
:
[
None
],
'dtype'
:
'int32'
,
'lod_level'
:
1
},
'grid_order0'
:
{
'shape'
:
[
None
],
'dtype'
:
'int32'
,
'lod_level'
:
1
},
'grid_order1'
:
{
'shape'
:
[
None
],
'dtype'
:
'int32'
,
'lod_level'
:
1
},
'grid_order2'
:
{
'shape'
:
[
None
],
'dtype'
:
'int32'
,
'lod_level'
:
1
},
'grid_order3'
:
{
'shape'
:
[
None
],
'dtype'
:
'int32'
,
'lod_level'
:
1
},
'grid_order4'
:
{
'shape'
:
[
None
],
'dtype'
:
'int32'
,
'lod_level'
:
1
},
'fg_num'
:
{
'shape'
:
[
None
],
'dtype'
:
'int32'
,
'lod_level'
:
0
},
'grid_offset'
:
{
'shape'
:
[
None
,
5
],
'dtype'
:
'int32'
,
'lod_level'
:
0
},
}
# yapf: enable
inputs_def
.
update
(
targets_def
)
return
inputs_def
def
build_inputs
(
self
,
image_shape
=
[
3
,
None
,
None
],
fields
=
[
'image'
,
'im_id'
,
'gt_segm'
],
# for train
use_dataloader
=
True
,
iterable
=
False
):
inputs_def
=
self
.
_inputs_def
(
image_shape
,
fields
)
if
'gt_segm'
in
fields
:
fields
.
remove
(
'gt_segm'
)
fields
.
extend
([
'fg_num'
,
'grid_offset'
])
for
i
in
range
(
5
):
fields
.
extend
([
'ins_label%d'
%
i
,
'cate_label%d'
%
i
,
'grid_order%d'
%
i
])
feed_vars
=
OrderedDict
([(
key
,
fluid
.
data
(
name
=
key
,
shape
=
inputs_def
[
key
][
'shape'
],
dtype
=
inputs_def
[
key
][
'dtype'
],
lod_level
=
inputs_def
[
key
][
'lod_level'
]))
for
key
in
fields
])
loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
feed_list
=
list
(
feed_vars
.
values
()),
capacity
=
16
,
use_double_buffer
=
True
,
iterable
=
iterable
)
if
use_dataloader
else
None
return
feed_vars
,
loader
def
train
(
self
,
feed_vars
):
return
self
.
build
(
feed_vars
,
mode
=
'train'
)
def
eval
(
self
,
feed_vars
):
return
self
.
build
(
feed_vars
,
mode
=
'test'
)
def
test
(
self
,
feed_vars
):
return
self
.
build
(
feed_vars
,
mode
=
'test'
)
ppdet/modeling/mask_head/__init__.py
0 → 100644
浏览文件 @
fbe1d120
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
.
import
solo_mask_head
from
.solo_mask_head
import
*
ppdet/modeling/mask_head/solo_mask_head.py
0 → 100644
浏览文件 @
fbe1d120
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle
from
paddle
import
fluid
from
ppdet.core.workspace
import
register
from
ppdet.modeling.ops
import
ConvNorm
,
DeformConvNorm
__all__
=
[
'SOLOv2MaskHead'
]
@
register
class
SOLOv2MaskHead
(
object
):
"""
SOLOv2MaskHead
Args:
out_channels (int): The channel number of output variable.
start_level (int): The position where the input starts.
end_level (int): The position where the input ends.
num_classes (int): Number of classes in SOLOv2MaskHead output.
use_dcn_in_tower: Whether to use dcn in tower or not.
"""
__shared__
=
[
'num_classes'
]
def
__init__
(
self
,
out_channels
=
128
,
start_level
=
0
,
end_level
=
3
,
num_classes
=
81
,
use_dcn_in_tower
=
False
):
super
(
SOLOv2MaskHead
,
self
).
__init__
()
assert
start_level
>=
0
and
end_level
>=
start_level
self
.
out_channels
=
out_channels
self
.
start_level
=
start_level
self
.
end_level
=
end_level
self
.
num_classes
=
num_classes
self
.
use_dcn_in_tower
=
use_dcn_in_tower
self
.
conv_type
=
[
ConvNorm
,
DeformConvNorm
]
def
_convs_levels
(
self
,
conv_feat
,
level
,
name
=
None
):
conv_func
=
self
.
conv_type
[
0
]
if
self
.
use_dcn_in_tower
:
conv_func
=
self
.
conv_type
[
1
]
if
level
==
0
:
return
conv_func
(
input
=
conv_feat
,
num_filters
=
self
.
out_channels
,
filter_size
=
3
,
stride
=
1
,
norm_type
=
'gn'
,
norm_groups
=
32
,
freeze_norm
=
False
,
act
=
'relu'
,
initializer
=
fluid
.
initializer
.
NormalInitializer
(
scale
=
0.01
),
norm_name
=
name
+
'.conv'
+
str
(
level
)
+
'.gn'
,
name
=
name
+
'.conv'
+
str
(
level
))
for
j
in
range
(
level
):
conv_feat
=
conv_func
(
input
=
conv_feat
,
num_filters
=
self
.
out_channels
,
filter_size
=
3
,
stride
=
1
,
norm_type
=
'gn'
,
norm_groups
=
32
,
freeze_norm
=
False
,
act
=
'relu'
,
initializer
=
fluid
.
initializer
.
NormalInitializer
(
scale
=
0.01
),
norm_name
=
name
+
'.conv'
+
str
(
j
)
+
'.gn'
,
name
=
name
+
'.conv'
+
str
(
j
))
conv_feat
=
fluid
.
layers
.
resize_bilinear
(
conv_feat
,
scale
=
2
,
name
=
'upsample'
+
str
(
level
)
+
str
(
j
),
align_corners
=
False
,
align_mode
=
0
)
return
conv_feat
def
_conv_pred
(
self
,
conv_feat
):
conv_func
=
self
.
conv_type
[
0
]
if
self
.
use_dcn_in_tower
:
conv_func
=
self
.
conv_type
[
1
]
conv_feat
=
conv_func
(
input
=
conv_feat
,
num_filters
=
self
.
num_classes
,
filter_size
=
1
,
stride
=
1
,
norm_type
=
'gn'
,
norm_groups
=
32
,
freeze_norm
=
False
,
act
=
'relu'
,
initializer
=
fluid
.
initializer
.
NormalInitializer
(
scale
=
0.01
),
norm_name
=
'mask_feat_head.conv_pred.0.gn'
,
name
=
'mask_feat_head.conv_pred.0'
)
return
conv_feat
def
get_output
(
self
,
inputs
,
batch_size
=
1
):
"""
Get SOLOv2MaskHead output.
Args:
inputs(list[Variable]): feature map from each necks with shape of [N, C, H, W]
batch_size (int): batch size
Returns:
ins_pred(Variable): Output of SOLOv2MaskHead head
"""
range_level
=
self
.
end_level
-
self
.
start_level
+
1
feature_add_all_level
=
self
.
_convs_levels
(
inputs
[
0
],
0
,
name
=
'mask_feat_head.convs_all_levels.0'
)
for
i
in
range
(
1
,
range_level
):
input_p
=
inputs
[
i
]
if
i
==
3
:
input_feat
=
input_p
x_range
=
paddle
.
linspace
(
-
1
,
1
,
fluid
.
layers
.
shape
(
input_feat
)[
-
1
],
dtype
=
'float32'
)
y_range
=
paddle
.
linspace
(
-
1
,
1
,
fluid
.
layers
.
shape
(
input_feat
)[
-
2
],
dtype
=
'float32'
)
y
,
x
=
paddle
.
tensor
.
meshgrid
([
y_range
,
x_range
])
x
=
fluid
.
layers
.
unsqueeze
(
x
,
[
0
,
1
])
y
=
fluid
.
layers
.
unsqueeze
(
y
,
[
0
,
1
])
y
=
fluid
.
layers
.
expand
(
y
,
expand_times
=
[
batch_size
,
1
,
1
,
1
])
x
=
fluid
.
layers
.
expand
(
x
,
expand_times
=
[
batch_size
,
1
,
1
,
1
])
coord_feat
=
fluid
.
layers
.
concat
([
x
,
y
],
axis
=
1
)
input_p
=
fluid
.
layers
.
concat
([
input_p
,
coord_feat
],
axis
=
1
)
feature_add_all_level
=
fluid
.
layers
.
elementwise_add
(
feature_add_all_level
,
self
.
_convs_levels
(
input_p
,
i
,
name
=
'mask_feat_head.convs_all_levels.{}'
.
format
(
i
)))
ins_pred
=
self
.
_conv_pred
(
feature_add_all_level
)
return
ins_pred
ppdet/modeling/ops.py
浏览文件 @
fbe1d120
...
...
@@ -17,6 +17,7 @@ from numbers import Integral
import
math
import
six
import
paddle
from
paddle
import
fluid
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid.initializer
import
NumpyArrayInitializer
...
...
@@ -1263,27 +1264,27 @@ class LibraBBoxAssigner(object):
rois
=
create_tmp_var
(
fluid
.
default_main_program
(),
name
=
None
,
#'rois',
name
=
None
,
dtype
=
'float32'
,
shape
=
[
-
1
,
4
],
)
bbox_inside_weights
=
create_tmp_var
(
fluid
.
default_main_program
(),
name
=
None
,
#'bbox_inside_weights',
name
=
None
,
dtype
=
'float32'
,
shape
=
[
-
1
,
8
if
self
.
is_cls_agnostic
else
self
.
class_nums
*
4
],
)
bbox_outside_weights
=
create_tmp_var
(
fluid
.
default_main_program
(),
name
=
None
,
#'bbox_outside_weights',
name
=
None
,
dtype
=
'float32'
,
shape
=
[
-
1
,
8
if
self
.
is_cls_agnostic
else
self
.
class_nums
*
4
],
)
bbox_targets
=
create_tmp_var
(
fluid
.
default_main_program
(),
name
=
None
,
#'bbox_targets',
name
=
None
,
dtype
=
'float32'
,
shape
=
[
-
1
,
8
if
self
.
is_cls_agnostic
else
self
.
class_nums
*
4
],
)
labels_int32
=
create_tmp_var
(
fluid
.
default_main_program
(),
name
=
None
,
#'labels_int32',
name
=
None
,
dtype
=
'int32'
,
shape
=
[
-
1
,
1
],
)
...
...
@@ -1565,3 +1566,79 @@ class RetinaOutputDecoder(object):
self
.
nms_top_k
=
pre_nms_top_n
self
.
keep_top_k
=
detections_per_im
self
.
nms_eta
=
nms_eta
@
register
@
serializable
class
MaskMatrixNMS
(
object
):
"""
Matrix NMS for multi-class masks.
Args:
kernel (str): 'linear' or 'gaussian'
sigma (float): std in gaussian method
Input:
seg_masks (Variable): shape (n, h, w), segmentation feature maps
cate_labels (Variable): shape (n), mask labels in descending order
cate_scores (Variable): shape (n), mask scores in descending order
sum_masks (Variable): The sum of seg_masks
Returns:
Variable: cate_scores_update, tensors of shape (n)
"""
def
__init__
(
self
,
kernel
=
'gaussian'
,
sigma
=
2.0
):
super
(
MaskMatrixNMS
,
self
).
__init__
()
self
.
kernel
=
kernel
self
.
sigma
=
sigma
def
__call__
(
self
,
seg_masks
,
cate_labels
,
cate_scores
,
sum_masks
=
None
):
n_samples
=
fluid
.
layers
.
shape
(
cate_labels
)
seg_masks
=
fluid
.
layers
.
reshape
(
seg_masks
,
shape
=
(
n_samples
,
-
1
))
# inter.
inter_matrix
=
paddle
.
mm
(
seg_masks
,
fluid
.
layers
.
transpose
(
seg_masks
,
[
1
,
0
]))
# union.
sum_masks_x
=
fluid
.
layers
.
reshape
(
fluid
.
layers
.
expand
(
sum_masks
,
expand_times
=
[
n_samples
]),
shape
=
[
n_samples
,
n_samples
])
# iou.
iou_matrix
=
(
inter_matrix
/
(
sum_masks_x
+
fluid
.
layers
.
transpose
(
sum_masks_x
,
[
1
,
0
])
-
inter_matrix
))
iou_matrix
=
paddle
.
tensor
.
triu
(
iou_matrix
,
diagonal
=
1
)
# label_specific matrix.
cate_labels_x
=
fluid
.
layers
.
reshape
(
fluid
.
layers
.
expand
(
cate_labels
,
expand_times
=
[
n_samples
]),
shape
=
[
n_samples
,
n_samples
])
label_matrix
=
fluid
.
layers
.
cast
(
(
cate_labels_x
==
fluid
.
layers
.
transpose
(
cate_labels_x
,
[
1
,
0
])),
'float32'
)
label_matrix
=
paddle
.
tensor
.
triu
(
label_matrix
,
diagonal
=
1
)
# IoU compensation
compensate_iou
=
paddle
.
max
((
iou_matrix
*
label_matrix
),
axis
=
0
)
compensate_iou
=
fluid
.
layers
.
reshape
(
fluid
.
layers
.
expand
(
compensate_iou
,
expand_times
=
[
n_samples
]),
shape
=
[
n_samples
,
n_samples
])
compensate_iou
=
fluid
.
layers
.
transpose
(
compensate_iou
,
[
1
,
0
])
# IoU decay
decay_iou
=
iou_matrix
*
label_matrix
# matrix nms
if
self
.
kernel
==
'gaussian'
:
decay_matrix
=
fluid
.
layers
.
exp
(
-
1
*
self
.
sigma
*
(
decay_iou
**
2
))
compensate_matrix
=
fluid
.
layers
.
exp
(
-
1
*
self
.
sigma
*
(
compensate_iou
**
2
))
decay_coefficient
=
paddle
.
min
(
decay_matrix
/
compensate_matrix
,
axis
=
0
)
elif
self
.
kernel
==
'linear'
:
decay_matrix
=
(
1
-
decay_iou
)
/
(
1
-
compensate_iou
)
decay_coefficient
=
paddle
.
min
(
decay_matrix
,
axis
=
0
)
else
:
raise
NotImplementedError
# update the score.
cate_scores_update
=
cate_scores
*
decay_coefficient
return
cate_scores_update
ppdet/utils/coco_eval.py
浏览文件 @
fbe1d120
...
...
@@ -164,6 +164,47 @@ def mask_eval(results,
cocoapi_eval
(
outfile
,
'segm'
,
coco_gt
=
coco_gt
)
def
segm_eval
(
results
,
anno_file
,
outfile
,
save_only
=
False
):
assert
'segm'
in
results
[
0
]
assert
outfile
.
endswith
(
'.json'
)
from
pycocotools.coco
import
COCO
coco_gt
=
COCO
(
anno_file
)
clsid2catid
=
{
i
:
v
for
i
,
v
in
enumerate
(
coco_gt
.
getCatIds
())}
segm_results
=
[]
for
t
in
results
:
im_id
=
int
(
t
[
'im_id'
][
0
][
0
])
segs
=
t
[
'segm'
]
for
mask
in
segs
:
catid
=
int
(
clsid2catid
[
mask
[
0
]])
masks
=
mask
[
1
]
mask_score
=
masks
[
1
]
segm
=
masks
[
0
]
segm
[
'counts'
]
=
segm
[
'counts'
].
decode
(
'utf8'
)
coco_res
=
{
'image_id'
:
im_id
,
'category_id'
:
catid
,
'segmentation'
:
segm
,
'score'
:
mask_score
}
segm_results
.
append
(
coco_res
)
if
len
(
segm_results
)
==
0
:
logger
.
warning
(
"The number of valid mask detected is zero.
\n
\
Please use reasonable model and check input data."
)
return
with
open
(
outfile
,
'w'
)
as
f
:
json
.
dump
(
segm_results
,
f
)
if
save_only
:
logger
.
info
(
'The mask result is saved to {} and do not '
'evaluate the mAP.'
.
format
(
outfile
))
return
map_stats
=
cocoapi_eval
(
outfile
,
'segm'
,
coco_gt
=
coco_gt
)
return
map_stats
def
cocoapi_eval
(
jsonfile
,
style
,
coco_gt
=
None
,
...
...
@@ -374,6 +415,43 @@ def mask2out(results, clsid2catid, resolution, thresh_binarize=0.5):
return
segm_res
def
segm2out
(
results
,
clsid2catid
,
thresh_binarize
=
0.5
):
import
pycocotools.mask
as
mask_util
segm_res
=
[]
# for each batch
for
t
in
results
:
segms
=
t
[
'segm'
][
0
]
clsid_labels
=
t
[
'cate_label'
][
0
]
clsid_scores
=
t
[
'cate_score'
][
0
]
lengths
=
segms
.
shape
[
0
]
im_id
=
int
(
t
[
'im_id'
][
0
][
0
])
im_shape
=
t
[
'im_shape'
][
0
][
0
]
if
lengths
==
0
or
segms
is
None
:
continue
# for each sample
for
i
in
range
(
lengths
-
1
):
im_h
=
int
(
im_shape
[
0
])
im_w
=
int
(
im_shape
[
1
])
clsid
=
int
(
clsid_labels
[
i
])
catid
=
clsid2catid
[
clsid
]
score
=
clsid_scores
[
i
]
mask
=
segms
[
i
]
segm
=
mask_util
.
encode
(
np
.
array
(
mask
[:,
:,
np
.
newaxis
],
order
=
'F'
))[
0
]
segm
[
'counts'
]
=
segm
[
'counts'
].
decode
(
'utf8'
)
coco_res
=
{
'image_id'
:
im_id
,
'category_id'
:
catid
,
'segmentation'
:
segm
,
'score'
:
score
}
segm_res
.
append
(
coco_res
)
return
segm_res
def
expand_boxes
(
boxes
,
scale
):
"""
Expand an array of boxes by a given scale.
...
...
ppdet/utils/eval_utils.py
浏览文件 @
fbe1d120
...
...
@@ -94,6 +94,25 @@ def clean_res(result, keep_name_list):
return
clean_result
def
get_masks
(
result
):
import
pycocotools.mask
as
mask_util
if
result
is
None
:
return
{}
seg_pred
=
result
[
'segm'
][
0
].
astype
(
np
.
uint8
)
cate_label
=
result
[
'cate_label'
][
0
].
astype
(
np
.
int
)
cate_score
=
result
[
'cate_score'
][
0
].
astype
(
np
.
float
)
num_ins
=
seg_pred
.
shape
[
0
]
masks
=
[]
for
idx
in
range
(
num_ins
-
1
):
cur_mask
=
seg_pred
[
idx
,
...]
rle
=
mask_util
.
encode
(
np
.
array
(
cur_mask
[:,
:,
np
.
newaxis
],
order
=
'F'
))[
0
]
rst
=
(
rle
,
cate_score
[
idx
])
masks
.
append
([
cate_label
[
idx
],
rst
])
return
masks
def
eval_run
(
exe
,
compile_program
,
loader
,
...
...
@@ -163,11 +182,13 @@ def eval_run(exe,
corner_post_process
(
res
,
post_config
,
cfg
.
num_classes
)
if
'TTFNet'
in
cfg
.
architecture
:
res
[
'bbox'
][
1
].
append
([
len
(
res
[
'bbox'
][
0
])])
if
'segm'
in
res
:
res
[
'segm'
]
=
get_masks
(
res
)
results
.
append
(
res
)
if
iter_id
%
100
==
0
:
logger
.
info
(
'Test iter {}'
.
format
(
iter_id
))
iter_id
+=
1
if
len
(
res
[
'bbox'
][
1
])
==
0
:
if
'bbox'
not
in
res
or
len
(
res
[
'bbox'
][
1
])
==
0
:
has_bbox
=
False
images_num
+=
len
(
res
[
'bbox'
][
1
][
0
])
if
has_bbox
else
1
except
(
StopIteration
,
fluid
.
core
.
EOFException
):
...
...
@@ -198,7 +219,7 @@ def eval_results(results,
"""Evaluation for evaluation program results"""
box_ap_stats
=
[]
if
metric
==
'COCO'
:
from
ppdet.utils.coco_eval
import
proposal_eval
,
bbox_eval
,
mask_eval
from
ppdet.utils.coco_eval
import
proposal_eval
,
bbox_eval
,
mask_eval
,
segm_eval
anno_file
=
dataset
.
get_anno
()
with_background
=
dataset
.
with_background
if
'proposal'
in
results
[
0
]:
...
...
@@ -225,6 +246,14 @@ def eval_results(results,
output
=
os
.
path
.
join
(
output_directory
,
'mask.json'
)
mask_eval
(
results
,
anno_file
,
output
,
resolution
,
save_only
=
save_only
)
if
'segm'
in
results
[
0
]:
output
=
'segm.json'
if
output_directory
:
output
=
os
.
path
.
join
(
output_directory
,
output
)
mask_ap_stats
=
segm_eval
(
results
,
anno_file
,
output
,
save_only
=
save_only
)
if
len
(
box_ap_stats
)
==
0
:
box_ap_stats
=
mask_ap_stats
else
:
if
'accum_map'
in
results
[
-
1
]:
res
=
np
.
mean
(
results
[
-
1
][
'accum_map'
][
0
])
...
...
tools/eval.py
浏览文件 @
fbe1d120
...
...
@@ -132,9 +132,6 @@ def main():
extra_keys
)
sub_eval_prog
=
sub_eval_prog
.
clone
(
True
)
#if 'weights' in cfg:
# checkpoint.load_params(exe, sub_eval_prog, cfg.weights)
# load model
exe
.
run
(
startup_prog
)
if
'weights'
in
cfg
:
...
...
@@ -146,7 +143,6 @@ def main():
results
=
eval_run
(
exe
,
compile_program
,
loader
,
keys
,
values
,
cls
,
cfg
,
sub_eval_prog
,
sub_keys
,
sub_values
,
resolution
)
#print(cfg['EvalReader']['dataset'].__dict__)
# evaluation
# if map_type not set, use default 11point, only use in VOC eval
map_type
=
cfg
.
map_type
if
'map_type'
in
cfg
else
'11point'
...
...
tools/export_model.py
浏览文件 @
fbe1d120
...
...
@@ -46,11 +46,13 @@ TRT_MIN_SUBGRAPH = {
'Face'
:
3
,
'TTFNet'
:
3
,
'FCOS'
:
3
,
'SOLOv2'
:
3
,
}
RESIZE_SCALE_SET
=
{
'RCNN'
,
'RetinaNet'
,
'FCOS'
,
'SOLOv2'
,
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录