s920243400 / PaddleDetection
Forked from PaddlePaddle / PaddleDetection (in sync with the fork source)
Commit fbe1d120
Authored Sep 15, 2020 by still-wait
add solov2 model
Parent: a8103906
Showing 19 changed files with 1612 additions and 49 deletions
configs/solov2/README.md (+22, -0)
configs/solov2/solov2_r50_fpn_1x.yml (+62, -0)
configs/solov2/solov2_reader.yml (+99, -0)
deploy/python/infer.py (+38, -8)
deploy/python/visualize.py (+66, -16)
ppdet/data/transform/batch_operators.py (+160, -0)
ppdet/data/transform/operators.py (+77, -14)
ppdet/modeling/__init__.py (+2, -0)
ppdet/modeling/anchor_heads/__init__.py (+2, -0)
ppdet/modeling/anchor_heads/solov2_head.py (+531, -0)
ppdet/modeling/architectures/__init__.py (+2, -0)
ppdet/modeling/architectures/solov2.py (+187, -0)
ppdet/modeling/mask_head/__init__.py (+19, -0)
ppdet/modeling/mask_head/solo_mask_head.py (+152, -0)
ppdet/modeling/ops.py (+82, -5)
ppdet/utils/coco_eval.py (+78, -0)
ppdet/utils/eval_utils.py (+31, -2)
tools/eval.py (+0, -4)
tools/export_model.py (+2, -0)
configs/solov2/README.md
0 → 100644
# SOLOv2 (Segmenting Objects by Locations) for instance segmentation

## Introduction

- SOLOv2 is a fast instance segmentation framework with strong performance: [https://arxiv.org/abs/2003.10152](https://arxiv.org/abs/2003.10152)

```
@misc{wang2020solov2,
  title={SOLOv2: Dynamic, Faster and Stronger},
  author={Xinlong Wang and Rufeng Zhang and Tao Kong and Lei Li and Chunhua Shen},
  year={2020},
  eprint={2003.10152},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```

## Model Zoo

| Backbone | Multi-scale training | Lr schd | Inf time (fps) | Mask AP | Download | Configs |
| :---------------------: | :-------------------: | :-----: | :------------: | :-----: | :---------: | :------------------------: |
| R50-FPN | False | 1x | - | 34.7 | [model](https://paddlemodels.bj.bcebos.com/object_detection/solov2_r50_fpn_1x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/solov2/solov2_r50_fpn_1x.yml) |
configs/solov2/solov2_r50_fpn_1x.yml
0 → 100644
architecture: SOLOv2
use_gpu: true
max_iters: 90000
snapshot_iter: 10000
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO
weights: output/solov2_r50_fpn_1x/model_final
num_classes: 81

SOLOv2:
  backbone: ResNet
  fpn: FPN
  bbox_head: SOLOv2Head
  mask_head: SOLOv2MaskHead
  batch_size: 2

ResNet:
  depth: 50
  feature_maps: [2, 3, 4, 5]
  freeze_at: 2
  norm_type: bn

FPN:
  max_level: 6
  min_level: 2
  num_chan: 256
  spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
  reverse_out: True

SOLOv2Head:
  seg_feat_channels: 512
  stacked_convs: 4
  num_grids: [40, 36, 24, 16, 12]
  kernel_out_channels: 256

SOLOv2MaskHead:
  out_channels: 128
  start_level: 0
  end_level: 3
  num_classes: 256

LearningRate:
  base_lr: 0.01
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [60000, 80000]
  - !LinearWarmup
    start_factor: 0.
    steps: 1000

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0001
    type: L2

_READER_: 'solov2_reader.yml'
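The `LearningRate` block in the config above combines a piecewise decay with a linear warmup. The following is a minimal sketch of how such a schedule is commonly evaluated, not PaddleDetection's own implementation: the function name `learning_rate`, the inclusive milestone boundary, and the way the warmup factor multiplies the decayed rate are all assumptions made for illustration.

```python
def learning_rate(it, base_lr=0.01, gamma=0.1, milestones=(60000, 80000),
                  warmup_steps=1000, start_factor=0.0):
    """Approximate learning rate at iteration `it` (hypothetical helper)."""
    # Piecewise decay: multiply by gamma once per milestone already reached.
    # Treating the milestone itself as "reached" is an assumption.
    passed = sum(1 for m in milestones if it >= m)
    lr = base_lr * (gamma ** passed)
    # Linear warmup: ramp the factor from start_factor to 1 over warmup_steps.
    if it < warmup_steps:
        alpha = it / float(warmup_steps)
        lr *= start_factor * (1.0 - alpha) + alpha
    return lr


if __name__ == "__main__":
    for step in (0, 500, 1000, 60000, 80000):
        print(step, learning_rate(step))
```

With the values from the config, the rate starts at zero, reaches `base_lr` after 1000 iterations, and drops by a factor of ten at iterations 60000 and 80000.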
configs/solov2/solov2_reader.yml
0 → 100644
TrainReader:
  batch_size: 2
  worker_num: 2
  inputs_def:
    fields: ['image', 'im_id', 'gt_segm']
  dataset:
    !COCODataSet
    dataset_dir: dataset/coco
    anno_path: annotations/instances_train2017.json
    image_dir: train2017
  sample_transforms:
  - !DecodeImage
    to_rgb: true
  - !Poly2Mask {}
  - !ResizeImage
    target_size: 800
    max_size: 1333
    interp: 1
    use_cv2: true
    resize_box: true
  - !RandomFlipImage
    prob: 0.5
  - !NormalizeImage
    is_channel_first: false
    is_scale: true
    mean: [0.485, 0.456, 0.406]
    std: [0.229, 0.224, 0.225]
  - !Permute
    to_bgr: false
    channel_first: true
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
  - !Gt2Solov2Target
    num_grids: [40, 36, 24, 16, 12]
    scale_ranges: [[1, 96], [48, 192], [96, 384], [192, 768], [384, 2048]]
    coord_sigma: 0.2
  shuffle: True

EvalReader:
  inputs_def:
    fields: ['image', 'im_info', 'im_id']
  dataset:
    !COCODataSet
    image_dir: val2017
    anno_path: annotations/instances_val2017.json
    dataset_dir: dataset/coco
  sample_transforms:
  - !DecodeImage
    to_rgb: true
  - !ResizeImage
    interp: 1
    max_size: 1333
    target_size: 800
    use_cv2: true
  - !NormalizeImage
    is_channel_first: false
    is_scale: true
    mean: [0.485, 0.456, 0.406]
    std: [0.229, 0.224, 0.225]
  - !Permute
    channel_first: true
    to_bgr: false
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
    use_padded_im_info: false
  batch_size: 1
  shuffle: false
  drop_last: false
  drop_empty: false
  worker_num: 2

TestReader:
  inputs_def:
    fields: ['image', 'im_info', 'im_id', 'im_shape']
  dataset:
    !ImageFolder
    anno_path: dataset/coco/annotations/instances_val2017.json
  sample_transforms:
  - !DecodeImage
    to_rgb: true
  - !ResizeImage
    interp: 1
    max_size: 1333
    target_size: 800
    use_cv2: true
  - !NormalizeImage
    is_channel_first: false
    is_scale: true
    mean: [0.485, 0.456, 0.406]
    std: [0.229, 0.224, 0.225]
  - !Permute
    channel_first: true
    to_bgr: false
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
    use_padded_im_info: false
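All three readers pad each batch with `pad_to_stride: 32`. As a rough illustration of what that means for image shapes, the sketch below rounds height and width up to the next multiple of the stride, using the same ceiling arithmetic that `PadStride` in `deploy/python/infer.py` applies. The helper name `padded_shape` is hypothetical, and the sketch uses only the standard library rather than NumPy.

```python
import math


def padded_shape(height, width, stride=32):
    """Round (height, width) up to the nearest multiple of `stride`."""
    pad_h = int(math.ceil(float(height) / stride) * stride)
    pad_w = int(math.ceil(float(width) / stride) * stride)
    return pad_h, pad_w


if __name__ == "__main__":
    # An 800 x 1333 image (the resize limits used in the reader config).
    print(padded_shape(800, 1333))
```

Padding to a multiple of 32 keeps every FPN level at an integer spatial size, since the coarsest `spatial_scale` in the model config is 0.03125 (1/32).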
deploy/python/infer.py
@@ -30,6 +30,7 @@ RESIZE_SCALE_SET = {
     'RCNN',
     'RetinaNet',
     'FCOS',
+    'SOLOv2',
 }
 SUPPORT_MODELS = {
@@ -41,6 +42,7 @@ SUPPORT_MODELS = {
     'Face',
     'TTF',
     'FCOS',
+    'SOLOv2',
 }
@@ -85,7 +87,8 @@ class Resize(object):
                  max_size,
                  use_cv2=True,
                  image_shape=None,
-                 interp=cv2.INTER_LINEAR):
+                 interp=cv2.INTER_LINEAR,
+                 resize_box=False):
         self.target_size = target_size
         self.max_size = max_size
         self.image_shape = image_shape
@@ -251,7 +254,7 @@ class PadStride(object):
         pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
         padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
         padding_im[:, :im_h, :im_w] = im
-        im_info['resize_shape'] = padding_im.shape[1:]
+        im_info['pad_shape'] = padding_im.shape[1:]
         return padding_im, im_info
@@ -268,23 +271,29 @@ def create_inputs(im, im_info, model_arch='YOLO'):
     inputs['image'] = im
     origin_shape = list(im_info['origin_shape'])
     resize_shape = list(im_info['resize_shape'])
+    pad_shape = list(im_info['pad_shape']) if 'pad_shape' in im_info else list(
+        im_info['resize_shape'])
     scale_x, scale_y = im_info['scale']
     if 'YOLO' in model_arch:
         im_size = np.array([origin_shape]).astype('int32')
         inputs['im_size'] = im_size
-    elif 'RetinaNet' or 'EfficientDet' in model_arch:
+    elif 'RetinaNet' in model_arch or 'EfficientDet' in model_arch:
         scale = scale_x
-        im_info = np.array([resize_shape + [scale]]).astype('float32')
+        im_info = np.array([pad_shape + [scale]]).astype('float32')
         inputs['im_info'] = im_info
     elif ('RCNN' in model_arch) or ('FCOS' in model_arch):
         scale = scale_x
-        im_info = np.array([resize_shape + [scale]]).astype('float32')
+        im_info = np.array([pad_shape + [scale]]).astype('float32')
         im_shape = np.array([origin_shape + [1.]]).astype('float32')
         inputs['im_info'] = im_info
         inputs['im_shape'] = im_shape
     elif 'TTF' in model_arch:
         scale_factor = np.array([scale_x, scale_y] * 2).astype('float32')
         inputs['scale_factor'] = scale_factor
+    elif 'SOLOv2' in model_arch:
+        scale = scale_x
+        im_info = np.array([resize_shape + [scale]]).astype('float32')
+        inputs['im_info'] = im_info
     return inputs
@@ -405,10 +414,15 @@ def visualize(image_file,
               results,
               labels,
               mask_resolution=14,
-              output_dir='output/'):
+              output_dir='output/',
+              threshold=0.5):
     # visualize the predict result
     im = visualize_box_mask(
-        image_file, results, labels, mask_resolution=mask_resolution)
+        image_file,
+        results,
+        labels,
+        mask_resolution=mask_resolution,
+        threshold=threshold)
     img_name = os.path.split(image_file)[-1]
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
@@ -516,6 +530,11 @@ class Detector():
             ms = (t2 - t1) * 1000.0 / repeats
             print("Inference: {} ms per batch image".format(ms))
+            if self.config.arch == 'SOLOv2':
+                return dict(
+                    segm=np.array(outs[2]),
+                    label=np.array(outs[0]),
+                    score=np.array(outs[1]))
             np_boxes = np.array(outs[0])
             if self.config.mask_resolution is not None:
                 np_masks = np.array(outs[1])
@@ -539,6 +558,13 @@ class Detector():
             for i in range(repeats):
                 self.predictor.zero_copy_run()
             output_names = self.predictor.get_output_names()
+            if self.config.arch == 'SOLOv2':
+                np_label = self.predictor.get_output_tensor(output_names[
+                    0]).copy_to_cpu()
+                np_score = self.predictor.get_output_tensor(output_names[
+                    1]).copy_to_cpu()
+                np_segms = self.predictor.get_output_tensor(output_names[
+                    2]).copy_to_cpu()
             boxes_tensor = self.predictor.get_output_tensor(output_names[0])
             np_boxes = boxes_tensor.copy_to_cpu()
             if self.config.mask_resolution is not None:
@@ -552,6 +578,9 @@ class Detector():
         # do not perform postprocess in benchmark mode
         results = []
         if not run_benchmark:
+            if self.config.arch == 'SOLOv2':
+                return dict(segm=np_segms, label=np_label, score=np_score)
+
             if reduce(lambda x, y: x * y, np_boxes.shape) < 6:
                 print('[WARNNING] No object detected.')
                 results = {'boxes': np.array([])}
@@ -579,7 +608,8 @@ def predict_image():
         results,
         detector.config.labels,
         mask_resolution=detector.config.mask_resolution,
-        output_dir=FLAGS.output_dir)
+        output_dir=FLAGS.output_dir,
+        threshold=FLAGS.threshold)
 def predict_video(camera_id):
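One change in `create_inputs` is easy to miss: the old condition `'RetinaNet' or 'EfficientDet' in model_arch` is parsed as `'RetinaNet' or ('EfficientDet' in model_arch)`, and a non-empty string is always truthy. Every architecture name without `'YOLO'` in it therefore took that branch, and the later `RCNN`/`FCOS`/`TTF` branches could not be reached. The short sketch below demonstrates the difference; `old_check` and `new_check` are illustrative names, not functions from the repository.

```python
def old_check(model_arch):
    # Parsed as: 'RetinaNet' or ('EfficientDet' in model_arch).
    # The non-empty string 'RetinaNet' is truthy, so this is always True.
    return bool('RetinaNet' or 'EfficientDet' in model_arch)


def new_check(model_arch):
    # Each membership test is evaluated separately, as the commit now does.
    return 'RetinaNet' in model_arch or 'EfficientDet' in model_arch


if __name__ == "__main__":
    for arch in ('RetinaNet', 'EfficientDet', 'SOLOv2', 'TTF'):
        print(arch, old_check(arch), new_check(arch))
```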
deploy/python/visualize.py
@@ -18,20 +18,22 @@ from __future__ import division
 import cv2
 import numpy as np
 from PIL import Image, ImageDraw
+from scipy import ndimage
-def visualize_box_mask(im, results, labels, mask_resolution=14):
-    """
+def visualize_box_mask(im, results, labels, mask_resolution=14, threshold=0.5):
+    """
     Args:
         im (str/np.ndarray): path of image/np.ndarray read by cv2
-        results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
+        results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
                         matix element:[class, score, x_min, y_min, x_max, y_max]
-                        MaskRCNN's results include 'masks': np.ndarray:
-                        shape:[N, class_num, mask_resolution, mask_resolution]
+                        MaskRCNN's results include 'masks': np.ndarray:
+                        shape:[N, class_num, mask_resolution, mask_resolution]
         labels (list): labels:['class1', ..., 'classn']
         mask_resolution (int): shape of a mask is:[mask_resolution, mask_resolution]
+        threshold (float): Threshold of score.
     Returns:
-        im (PIL.Image.Image): visualized image
+        im (PIL.Image.Image): visualized image
     """
     if isinstance(im, str):
         im = Image.open(im).convert('RGB')
@@ -46,15 +48,23 @@ def visualize_box_mask(im, results, labels, mask_resolution=14):
             resolution=mask_resolution)
     if 'boxes' in results:
         im = draw_box(im, results['boxes'], labels)
+    if 'segm' in results:
+        im = draw_segm(
+            im,
+            results['segm'],
+            results['label'],
+            results['score'],
+            labels,
+            threshold=threshold)
     return im
 def get_color_map_list(num_classes):
-    """
+    """
     Args:
         num_classes (int): number of class
     Returns:
-        color_map (list): RGB color list
+        color_map (list): RGB color list
     """
     color_map = num_classes * [0, 0, 0]
     for i in range(0, num_classes):
@@ -71,9 +81,9 @@ def get_color_map_list(num_classes):
 def expand_boxes(boxes, scale=0.0):
-    """
+    """
     Args:
-        boxes (np.ndarray): shape:[N,4], N:number of box,
+        boxes (np.ndarray): shape:[N,4], N:number of box,
                             matix element:[x_min, y_min, x_max, y_max]
         scale (float): scale of boxes
     Returns:
@@ -94,17 +104,17 @@ def expand_boxes(boxes, scale=0.0):
 def draw_mask(im, np_boxes, np_masks, labels, resolution=14, threshold=0.5):
-    """
+    """
     Args:
         im (PIL.Image.Image): PIL image
-        np_boxes (np.ndarray): shape:[N,6], N: number of box,
+        np_boxes (np.ndarray): shape:[N,6], N: number of box,
                                matix element:[class, score, x_min, y_min, x_max, y_max]
         np_masks (np.ndarray): shape:[N, class_num, resolution, resolution]
         labels (list): labels:['class1', ..., 'classn']
         resolution (int): shape of a mask is:[resolution, resolution]
         threshold (float): threshold of mask
     Returns:
-        im (PIL.Image.Image): visualized image
+        im (PIL.Image.Image): visualized image
     """
     color_list = get_color_map_list(len(labels))
     scale = (resolution + 2.0) / resolution
@@ -149,14 +159,14 @@ def draw_mask(im, np_boxes, np_masks, labels, resolution=14, threshold=0.5):
 def draw_box(im, np_boxes, labels):
-    """
+    """
     Args:
         im (PIL.Image.Image): PIL image
-        np_boxes (np.ndarray): shape:[N,6], N: number of box,
+        np_boxes (np.ndarray): shape:[N,6], N: number of box,
                                matix element:[class, score, x_min, y_min, x_max, y_max]
         labels (list): labels:['class1', ..., 'classn']
     Returns:
-        im (PIL.Image.Image): visualized image
+        im (PIL.Image.Image): visualized image
     """
     draw_thickness = min(im.size) // 320
     draw = ImageDraw.Draw(im)
@@ -186,3 +196,43 @@ def draw_box(im, np_boxes, labels):
             [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
         draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
     return im
+
+
+def draw_segm(im,
+              np_segms,
+              np_label,
+              np_score,
+              labels,
+              threshold=0.5,
+              alpha=0.7):
+    """
+    Draw segmentation on image
+    """
+    mask_color_id = 0
+    w_ratio = .4
+    color_list = get_color_map_list(len(labels))
+    im = np.array(im).astype('float32')
+    clsid2color = {}
+    np_segms = np_segms.astype(np.uint8)
+    for i in range(np_segms.shape[0]):
+        mask, score, clsid = np_segms[i], np_score[i], np_label[i] + 1
+        if score < threshold:
+            continue
+        if clsid not in clsid2color:
+            clsid2color[clsid] = color_list[clsid]
+        color_mask = clsid2color[clsid]
+        for c in range(3):
+            color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
+        idx = np.nonzero(mask)
+        color_mask = np.array(color_mask)
+        im[idx[0], idx[1], :] *= 1.0 - alpha
+        im[idx[0], idx[1], :] += alpha * color_mask
+        center_y, center_x = ndimage.measurements.center_of_mass(mask)
+        label_text = "{}".format(labels[clsid])
+        print(label_text)
+        print(center_y, center_x)
+        vis_pos = (max(int(center_x) - 10, 0), int(center_y))
+        cv2.putText(im, label_text, vis_pos, cv2.FONT_HERSHEY_COMPLEX, 0.3,
+                    (255, 255, 255))
+    return Image.fromarray(im.astype('uint8'))
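The core of `draw_segm` is an alpha blend: pixels under a mask are scaled by `1 - alpha` and the class color, first lightened toward white by `w_ratio`, is added with weight `alpha`. The sketch below reproduces that arithmetic on plain nested lists so that it runs without NumPy. The helper names `lighten`, `blend_pixel` and `blend_mask` are hypothetical, and the real function works on whole arrays rather than pixel by pixel.

```python
def lighten(color, w_ratio=0.4):
    """Mix an RGB color toward white, like the per-channel loop in draw_segm."""
    return [c * (1 - w_ratio) + w_ratio * 255 for c in color]


def blend_pixel(pixel, color, alpha=0.7):
    """Alpha-blend one RGB pixel with a mask color."""
    return [p * (1.0 - alpha) + alpha * c for p, c in zip(pixel, color)]


def blend_mask(image, mask, color, alpha=0.7):
    """Return a blended copy of `image`; only pixels with nonzero mask change."""
    return [
        [blend_pixel(px, color, alpha) if m else list(px)
         for px, m in zip(img_row, mask_row)]
        for img_row, mask_row in zip(image, mask)
    ]


if __name__ == "__main__":
    image = [[[100, 100, 100], [0, 0, 0]]]  # one row, two RGB pixels
    mask = [[1, 0]]                          # only the first pixel is masked
    print(blend_mask(image, mask, lighten([200, 0, 50])))
```

Unlike `draw_segm`, these helpers return new lists instead of modifying their inputs, which keeps the example free of side effects.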
ppdet/data/transform/batch_operators.py
@@ -24,6 +24,7 @@ except Exception:
 import logging
 import cv2
 import numpy as np
+from scipy import ndimage
 from .operators import register_op, BaseOperator
 from .op_helper import jaccard_overlap, gaussian2D
@@ -37,6 +38,7 @@ __all__ = [
     'Gt2YoloTarget',
     'Gt2FCOSTarget',
     'Gt2TTFTarget',
+    'Gt2Solov2Target',
 ]
@@ -88,6 +90,13 @@ class PadBatch(BaseOperator):
                     (1, max_shape[1], max_shape[2]), dtype=np.float32)
                 padding_sem[:, :im_h, :im_w] = semantic
                 data['semantic'] = padding_sem
+            if 'gt_segm' in data.keys() and data['gt_segm'] is not None:
+                gt_segm = data['gt_segm']
+                padding_segm = np.zeros(
+                    (gt_segm.shape[0], max_shape[1], max_shape[2]),
+                    dtype=np.uint8)
+                padding_segm[:, :im_h, :im_w] = gt_segm
+                data['gt_segm'] = padding_segm
         return samples
@@ -590,3 +599,154 @@ class Gt2TTFTarget(BaseOperator):
         heatmap[y - top:y + bottom, x - left:x + right] = np.maximum(
             masked_heatmap, masked_gaussian)
         return heatmap
+
+
+@register_op
+class Gt2Solov2Target(BaseOperator):
+    """Assign mask target and labels in SOLOv2 network.
+    Args:
+        num_grids (list): The list of feature map grids size.
+        scale_ranges (list): The list of mask boundary range.
+        coord_sigma (float): The coefficient of coordinate area length.
+        sampling_ratio (float): The ratio of down sampling.
+    """
+
+    def __init__(self,
+                 num_grids=[40, 36, 24, 16, 12],
+                 scale_ranges=[[1, 96], [48, 192], [96, 384], [192, 768],
+                               [384, 2048]],
+                 coord_sigma=0.2,
+                 sampling_ratio=4.0):
+        super(Gt2Solov2Target, self).__init__()
+        self.num_grids = num_grids
+        self.scale_ranges = scale_ranges
+        self.coord_sigma = coord_sigma
+        self.sampling_ratio = sampling_ratio
+
+    def _scale_size(self, im, scale):
+        h, w = im.shape[:2]
+        new_size = (int(w * float(scale) + 0.5), int(h * float(scale) + 0.5))
+        resized_img = cv2.resize(
+            im, None, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
+        return resized_img
+
+    def __call__(self, samples, context=None):
+        for sample in samples:
+            gt_bboxes_raw = sample['gt_bbox']
+            gt_labels_raw = sample['gt_class']
+            im_c, im_h, im_w = sample['image'].shape[:]
+            gt_masks_raw = sample['gt_segm'].astype(np.uint8)
+            mask_feat_size = [
+                int(im_h / self.sampling_ratio), int(im_w / self.sampling_ratio)
+            ]
+            gt_areas = np.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) *
+                               (gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
+            ins_ind_label_list = []
+            grid_offset = []
+            idx = 0
+            for (lower_bound, upper_bound), num_grid \
+                    in zip(self.scale_ranges, self.num_grids):
+                hit_indices = ((gt_areas >= lower_bound) &
+                               (gt_areas <= upper_bound)).nonzero()[0]
+                num_ins = len(hit_indices)
+                ins_label = []
+                grid_order = []
+                cate_label = np.zeros([num_grid, num_grid], dtype=np.int64)
+                ins_ind_label = np.zeros([num_grid**2], dtype=np.bool)
+                if num_ins == 0:
+                    ins_label = np.zeros(
+                        [1, mask_feat_size[0], mask_feat_size[1]],
+                        dtype=np.uint8)
+                    ins_ind_label_list.append(ins_ind_label)
+                    sample['cate_label{}'.format(idx)] = cate_label.flatten()
+                    sample['ins_label{}'.format(idx)] = ins_label
+                    sample['grid_order{}'.format(idx)] = np.asarray([0])
+                    grid_offset.append(1)
+                    idx += 1
+                    continue
+                gt_bboxes = gt_bboxes_raw[hit_indices]
+                gt_labels = gt_labels_raw[hit_indices]
+                gt_masks = gt_masks_raw[hit_indices, ...]
+                half_ws = 0.5 * (
+                    gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.coord_sigma
+                half_hs = 0.5 * (
+                    gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.coord_sigma
+                for seg_mask, gt_label, half_h, half_w in zip(
+                        gt_masks, gt_labels, half_hs, half_ws):
+                    if seg_mask.sum() == 0:
+                        continue
+                    # mass center
+                    upsampled_size = (mask_feat_size[0] * 4,
+                                      mask_feat_size[1] * 4)
+                    center_h, center_w = ndimage.measurements.center_of_mass(
+                        seg_mask)
+                    coord_w = int(
+                        (center_w / upsampled_size[1]) // (1. / num_grid))
+                    coord_h = int(
+                        (center_h / upsampled_size[0]) // (1. / num_grid))
+                    # left, top, right, down
+                    top_box = max(0,
+                                  int(((center_h - half_h) / upsampled_size[0])
+                                      // (1. / num_grid)))
+                    down_box = min(num_grid - 1,
+                                   int(((center_h + half_h) / upsampled_size[0])
+                                       // (1. / num_grid)))
+                    left_box = max(0,
+                                   int(((center_w - half_w) / upsampled_size[1])
+                                       // (1. / num_grid)))
+                    right_box = min(num_grid - 1,
+                                    int(((center_w + half_w) /
+                                         upsampled_size[1]) // (1. / num_grid)))
+                    top = max(top_box, coord_h - 1)
+                    down = min(down_box, coord_h + 1)
+                    left = max(coord_w - 1, left_box)
+                    right = min(right_box, coord_w + 1)
+                    cate_label[top:(down + 1), left:(right + 1)] = gt_label
+                    seg_mask = self._scale_size(
+                        seg_mask, scale=1. / self.sampling_ratio)
+                    for i in range(top, down + 1):
+                        for j in range(left, right + 1):
+                            label = int(i * num_grid + j)
+                            cur_ins_label = np.zeros(
+                                [mask_feat_size[0], mask_feat_size[1]],
+                                dtype=np.uint8)
+                            cur_ins_label[:seg_mask.shape[0], :seg_mask.shape[
+                                1]] = seg_mask
+                            ins_label.append(cur_ins_label)
+                            ins_ind_label[label] = True
+                            grid_order.append(label)
+                if ins_label == []:
+                    ins_label = np.zeros(
+                        [1, mask_feat_size[0], mask_feat_size[1]],
+                        dtype=np.uint8)
+                    ins_ind_label_list.append(ins_ind_label)
+                    sample['cate_label{}'.format(idx)] = cate_label.flatten()
+                    sample['ins_label{}'.format(idx)] = ins_label
+                    sample['grid_order{}'.format(idx)] = np.asarray([0])
+                    grid_offset.append(1)
+                else:
+                    ins_label = np.stack(ins_label, axis=0)
+                    ins_ind_label_list.append(ins_ind_label)
+                    sample['cate_label{}'.format(idx)] = cate_label.flatten()
+                    sample['ins_label{}'.format(idx)] = ins_label
+                    sample['grid_order{}'.format(idx)] = np.asarray(grid_order)
+                    assert len(grid_order) > 0
+                    grid_offset.append(len(grid_order))
+                idx += 1
+            ins_ind_labels = np.concatenate([
+                ins_ind_labels_level_img
+                for ins_ind_labels_level_img in ins_ind_label_list
+            ])
+            fg_num = np.sum(ins_ind_labels)
+            sample['fg_num'] = fg_num
+            sample['grid_offset'] = np.asarray(grid_offset).astype(np.int32)
+        return samples
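`Gt2Solov2Target` makes two decisions per ground-truth instance: which pyramid levels it belongs to (its box scale, `sqrt(w * h)`, must fall inside a level's `scale_ranges` entry) and which grid cell on that level contains its mass center. The sketch below isolates both steps in plain Python. The helper names `assign_levels` and `center_cell` are hypothetical, and it divides by the image size directly, whereas the operator divides by `mask_feat_size * 4`, which can differ slightly because of integer truncation.

```python
import math

SCALE_RANGES = [[1, 96], [48, 192], [96, 384], [192, 768], [384, 2048]]
NUM_GRIDS = [40, 36, 24, 16, 12]


def assign_levels(box, scale_ranges=SCALE_RANGES):
    """Return indices of the levels whose scale range contains the box scale."""
    x1, y1, x2, y2 = box
    scale = math.sqrt((x2 - x1) * (y2 - y1))
    return [i for i, (lo, hi) in enumerate(scale_ranges) if lo <= scale <= hi]


def center_cell(center_h, center_w, im_h, im_w, num_grid):
    """Map a mass center in pixels to its (row, column) cell on the grid."""
    coord_h = int((center_h / float(im_h)) // (1.0 / num_grid))
    coord_w = int((center_w / float(im_w)) // (1.0 / num_grid))
    return coord_h, coord_w


if __name__ == "__main__":
    box = (0, 0, 100, 100)  # scale = 100, inside two overlapping ranges
    for level in assign_levels(box):
        print(level, center_cell(410, 610, 800, 1200, NUM_GRIDS[level]))
```

Because neighbouring ranges overlap, a mid-sized object can be assigned to two levels at once, which is why the operator loops over every level instead of choosing a single one.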
ppdet/data/transform/operators.py
@@ -272,7 +272,8 @@ class ResizeImage(BaseOperator):
                  target_size=0,
                  max_size=0,
                  interp=cv2.INTER_LINEAR,
-                 use_cv2=True):
+                 use_cv2=True,
+                 resize_box=False):
         """
         Rescale image to the specified target size, and capped at max_size
         if max_size != 0.
@@ -285,11 +286,13 @@ class ResizeImage(BaseOperator):
             interp (int): the interpolation method
             use_cv2 (bool): use the cv2 interpolation method or use PIL
                 interpolation method
+            resize_box (bool): whether resize ground truth bbox annotations.
         """
         super(ResizeImage, self).__init__()
         self.max_size = int(max_size)
         self.interp = int(interp)
         self.use_cv2 = use_cv2
+        self.resize_box = resize_box
         if not (isinstance(target_size, int) or isinstance(target_size, list)):
             raise TypeError(
                 "Type of target_size is invalid. Must be Integer or List, now is {}".
@@ -348,18 +351,6 @@ class ResizeImage(BaseOperator):
                 fx=im_scale_x,
                 fy=im_scale_y,
                 interpolation=self.interp)
-            if 'semantic' in sample.keys() and sample['semantic'] is not None:
-                semantic = sample['semantic']
-                semantic = cv2.resize(
-                    semantic.astype('float32'),
-                    None,
-                    None,
-                    fx=im_scale_x,
-                    fy=im_scale_y,
-                    interpolation=self.interp)
-                semantic = np.asarray(semantic).astype('int32')
-                semantic = np.expand_dims(semantic, 0)
-                sample['semantic'] = semantic
         else:
             if self.max_size != 0:
                 raise TypeError(
@@ -370,6 +361,38 @@ class ResizeImage(BaseOperator):
             im = im.resize((int(resize_w), int(resize_h)), self.interp)
             im = np.array(im)
         sample['image'] = im
+        sample['scale_factor'] = [im_scale_x, im_scale_y] * 2
+        if 'gt_bbox' in sample and self.resize_box and len(sample[
+                'gt_bbox']) > 0:
+            bboxes = sample['gt_bbox'] * sample['scale_factor']
+            bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, resize_w - 1)
+            bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, resize_h - 1)
+            sample['gt_bbox'] = bboxes
+        if 'semantic' in sample.keys() and sample['semantic'] is not None:
+            semantic = sample['semantic']
+            semantic = cv2.resize(
+                semantic.astype('float32'),
+                None,
+                None,
+                fx=im_scale_x,
+                fy=im_scale_y,
+                interpolation=self.interp)
+            semantic = np.asarray(semantic).astype('int32')
+            semantic = np.expand_dims(semantic, 0)
+            sample['semantic'] = semantic
+        if 'gt_segm' in sample and len(sample['gt_segm']) > 0:
+            masks = [
+                cv2.resize(
+                    gt_segm,
+                    None,
+                    None,
+                    fx=im_scale_x,
+                    fy=im_scale_y,
+                    interpolation=cv2.INTER_NEAREST)
+                for gt_segm in sample['gt_segm']
+            ]
+            sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
         return sample
@@ -473,7 +496,6 @@ class RandomFlipImage(BaseOperator):
                 if self.is_mask_flip and len(sample['gt_poly']) != 0:
                     sample['gt_poly'] = self.flip_segms(sample['gt_poly'],
                                                         height, width)
-
                 if 'gt_keypoint' in sample.keys():
                     sample['gt_keypoint'] = self.flip_keypoint(
                         sample['gt_keypoint'], width)
@@ -482,6 +504,9 @@ class RandomFlipImage(BaseOperator):
                         'semantic'] is not None:
                     sample['semantic'] = sample['semantic'][:, ::-1]
+                if 'gt_segm' in sample.keys() and sample['gt_segm'] is not None:
+                    sample['gt_segm'] = sample['gt_segm'][:, :, ::-1]
+
                 sample['flipped'] = True
                 sample['image'] = im
         sample = samples if batch_input else samples[0]
@@ -2557,3 +2582,41 @@ class DebugVisibleImage(BaseOperator):
         save_path = os.path.join(self.output_dir, out_file_name)
         image.save(save_path, quality=95)
         return sample
+
+
+@register_op
+class Poly2Mask(BaseOperator):
+    """
+    gt poly to mask annotations
+    """
+
+    def __init__(self):
+        super(Poly2Mask, self).__init__()
+        import pycocotools.mask as maskUtils
+        self.maskutils = maskUtils
+
+    def _poly2mask(self, mask_ann, img_h, img_w):
+        if isinstance(mask_ann, list):
+            # polygon -- a single object might consist of multiple parts
+            # we merge all parts into one mask rle code
+            rles = self.maskutils.frPyObjects(mask_ann, img_h, img_w)
+            rle = self.maskutils.merge(rles)
+        elif isinstance(mask_ann['counts'], list):
+            # uncompressed RLE
+            rle = self.maskutils.frPyObjects(mask_ann, img_h, img_w)
+        else:
+            # rle
+            rle = mask_ann
+        mask = self.maskutils.decode(rle)
+        return mask
+
+    def __call__(self, sample, context=None):
+        assert 'gt_poly' in sample
+        im_h = sample['h']
+        im_w = sample['w']
+        masks = [
+            self._poly2mask(gt_poly, im_h, im_w)
+            for gt_poly in sample['gt_poly']
+        ]
+        sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
+        return sample
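The new `resize_box` path in `ResizeImage` multiplies each `[x1, y1, x2, y2]` box by `[scale_x, scale_y, scale_x, scale_y]` and clips the result to the resized image. The following sketch performs the same arithmetic on plain Python lists. The function name `resize_boxes` is hypothetical, and the operator itself does this with NumPy slicing (`0::2` for x coordinates, `1::2` for y coordinates).

```python
def resize_boxes(boxes, scale_x, scale_y, resize_w, resize_h):
    """Scale [x1, y1, x2, y2] boxes and clip them to the resized image."""
    def clip(value, upper):
        return min(max(value, 0), upper)

    resized = []
    for x1, y1, x2, y2 in boxes:
        resized.append([
            clip(x1 * scale_x, resize_w - 1),
            clip(y1 * scale_y, resize_h - 1),
            clip(x2 * scale_x, resize_w - 1),
            clip(y2 * scale_y, resize_h - 1),
        ])
    return resized


if __name__ == "__main__":
    # A box whose right edge would land outside a 50-pixel-wide image.
    print(resize_boxes([[10, 20, 30, 40]], 2.0, 2.0, resize_w=50, resize_h=100))
```

Clipping to `resize_w - 1` and `resize_h - 1` keeps every coordinate on a valid pixel index after rounding errors in the scale factors.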
ppdet/modeling/__init__.py
@@ -22,6 +22,7 @@ from . import roi_extractors
 from . import roi_heads
 from . import ops
 from . import target_assigners
+from . import mask_head
 from .anchor_heads import *
 from .architectures import *
@@ -30,3 +31,4 @@ from .roi_extractors import *
 from .roi_heads import *
 from .ops import *
 from .target_assigners import *
+from .mask_head import *
ppdet/modeling/anchor_heads/__init__.py
@@ -21,6 +21,7 @@ from . import fcos_head
 from . import corner_head
 from . import efficient_head
 from . import ttf_head
+from . import solov2_head
 from .rpn_head import *
 from .yolo_head import *
@@ -29,3 +30,4 @@ from .fcos_head import *
 from .corner_head import *
 from .efficient_head import *
 from .ttf_head import *
+from .solov2_head import *
ppdet/modeling/anchor_heads/solov2_head.py
0 → 100644
This diff is collapsed.
ppdet/modeling/architectures/__init__.py
@@ -29,6 +29,7 @@ from . import fcos
 from . import cornernet_squeeze
 from . import ttfnet
 from . import htc
+from . import solov2
 from .faster_rcnn import *
 from .mask_rcnn import *
@@ -45,3 +46,4 @@ from .fcos import *
 from .cornernet_squeeze import *
 from .ttfnet import *
 from .htc import *
+from .solov2 import *
ppdet/modeling/architectures/solov2.py
0 → 100644
浏览文件 @
fbe1d120
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
collections
import
OrderedDict
from
paddle
import
fluid
from
ppdet.experimental
import
mixed_precision_global_state
from
ppdet.core.workspace
import
register
__all__
=
[
'SOLOv2'
]
@
register
class
SOLOv2
(
object
):
"""
SOLOv2 network, see https://arxiv.org/abs/2003.10152
Args:
backbone (object): an backbone instance
fpn (object): feature pyramid network instance
bbox_head (object): an `SOLOv2Head` instance
mask_head (object): an `SOLOv2MaskHead` instance
batch_size (int): batch size.
"""
__category__
=
'architecture'
__inject__
=
[
'backbone'
,
'fpn'
,
'bbox_head'
,
'mask_head'
]
def
__init__
(
self
,
backbone
,
fpn
=
None
,
bbox_head
=
'SOLOv2Head'
,
mask_head
=
'SOLOv2MaskHead'
,
batch_size
=
1
):
super
(
SOLOv2
,
self
).
__init__
()
self
.
backbone
=
backbone
self
.
fpn
=
fpn
self
.
bbox_head
=
bbox_head
self
.
mask_head
=
mask_head
self
.
batch_size
=
batch_size
def
build
(
self
,
feed_vars
,
mode
=
'train'
):
im
=
feed_vars
[
'image'
]
mixed_precision_enabled
=
mixed_precision_global_state
()
is
not
None
# cast inputs to FP16
if
mixed_precision_enabled
:
im
=
fluid
.
layers
.
cast
(
im
,
'float16'
)
body_feats
=
self
.
backbone
(
im
)
if
self
.
fpn
is
not
None
:
body_feats
,
spatial_scale
=
self
.
fpn
.
get_output
(
body_feats
)
if
isinstance
(
body_feats
,
OrderedDict
):
body_feat_names
=
list
(
body_feats
.
keys
())
body_feats
=
[
body_feats
[
name
]
for
name
in
body_feat_names
]
# cast features back to FP32
if
mixed_precision_enabled
:
body_feats
=
[
fluid
.
layers
.
cast
(
v
,
'float32'
)
for
v
in
body_feats
]
if
not
mode
==
'train'
:
self
.
batch_size
=
1
mask_feat_pred
=
self
.
mask_head
.
get_output
(
body_feats
,
self
.
batch_size
)
if
mode
==
'train'
:
ins_labels
=
[]
cate_labels
=
[]
grid_orders
=
[]
fg_num
=
feed_vars
[
'fg_num'
]
grid_offset
=
feed_vars
[
'grid_offset'
]
for
i
in
range
(
5
):
ins_label
=
'ins_label{}'
.
format
(
i
)
if
ins_label
in
feed_vars
:
ins_labels
.
append
(
feed_vars
[
ins_label
])
cate_label
=
'cate_label{}'
.
format
(
i
)
if
cate_label
in
feed_vars
:
cate_labels
.
append
(
feed_vars
[
cate_label
])
grid_order
=
'grid_order{}'
.
format
(
i
)
if
grid_order
in
feed_vars
:
grid_orders
.
append
(
feed_vars
[
grid_order
])
cate_preds
,
kernel_preds
=
self
.
bbox_head
.
get_outputs
(
body_feats
,
batch_size
=
self
.
batch_size
)
losses
=
self
.
bbox_head
.
get_loss
(
cate_preds
,
kernel_preds
,
mask_feat_pred
,
ins_labels
,
cate_labels
,
grid_orders
,
fg_num
,
grid_offset
,
self
.
batch_size
)
total_loss
=
fluid
.
layers
.
sum
(
list
(
losses
.
values
()))
losses
.
update
({
'loss'
:
total_loss
})
return
losses
else
:
im_info
=
feed_vars
[
'im_info'
]
outs
=
self
.
bbox_head
.
get_outputs
(
body_feats
,
is_eval
=
True
,
batch_size
=
self
.
batch_size
)
seg_inputs
=
outs
+
(
mask_feat_pred
,
im_info
)
return
self
.
bbox_head
.
get_prediction
(
*
seg_inputs
)

    def _inputs_def(self, image_shape, fields):
        im_shape = [None] + image_shape
        # yapf: disable
        inputs_def = {
            'image':    {'shape': im_shape,  'dtype': 'float32', 'lod_level': 0},
            'im_info':  {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
            'im_id':    {'shape': [None, 1], 'dtype': 'int64',   'lod_level': 0},
            'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
        }

        if 'gt_segm' in fields:
            targets_def = {
                'ins_label0':  {'shape': [None, None, None], 'dtype': 'int32', 'lod_level': 1},
                'ins_label1':  {'shape': [None, None, None], 'dtype': 'int32', 'lod_level': 1},
                'ins_label2':  {'shape': [None, None, None], 'dtype': 'int32', 'lod_level': 1},
                'ins_label3':  {'shape': [None, None, None], 'dtype': 'int32', 'lod_level': 1},
                'ins_label4':  {'shape': [None, None, None], 'dtype': 'int32', 'lod_level': 1},
                'cate_label0': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
                'cate_label1': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
                'cate_label2': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
                'cate_label3': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
                'cate_label4': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
                'grid_order0': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
                'grid_order1': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
                'grid_order2': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
                'grid_order3': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
                'grid_order4': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
                'fg_num':      {'shape': [None], 'dtype': 'int32', 'lod_level': 0},
                'grid_offset': {'shape': [None, 5], 'dtype': 'int32', 'lod_level': 0},
            }
            # yapf: enable
            inputs_def.update(targets_def)
        return inputs_def

    def build_inputs(
            self,
            image_shape=[3, None, None],
            fields=['image', 'im_id', 'gt_segm'],  # for train
            use_dataloader=True,
            iterable=False):
        inputs_def = self._inputs_def(image_shape, fields)
        if 'gt_segm' in fields:
            fields.remove('gt_segm')
            fields.extend(['fg_num', 'grid_offset'])
            for i in range(5):
                fields.extend([
                    'ins_label%d' % i, 'cate_label%d' % i, 'grid_order%d' % i
                ])

        feed_vars = OrderedDict([(key, fluid.data(
            name=key,
            shape=inputs_def[key]['shape'],
            dtype=inputs_def[key]['dtype'],
            lod_level=inputs_def[key]['lod_level'])) for key in fields])
        loader = fluid.io.DataLoader.from_generator(
            feed_list=list(feed_vars.values()),
            capacity=16,
            use_double_buffer=True,
            iterable=iterable) if use_dataloader else None
        return feed_vars, loader

    def train(self, feed_vars):
        return self.build(feed_vars, mode='train')

    def eval(self, feed_vars):
        return self.build(feed_vars, mode='test')

    def test(self, feed_vars):
        return self.build(feed_vars, mode='test')
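
The `build_inputs` method in this file swaps the `'gt_segm'` field for the per-level SOLOv2 targets (`fg_num`, `grid_offset`, and `ins_label{i}`, `cate_label{i}`, `grid_order{i}` for five levels), and it does so by editing the `fields` list in place. Since `fields` has a list default, an in-place change persists across calls that rely on that default; whether this matters depends on how callers use the method. The sketch below shows the same expansion applied to a copy. `expand_fields` and `NUM_LEVELS` are hypothetical names used only for illustration, not part of PaddleDetection.

```python
NUM_LEVELS = 5  # assumption: mirrors the range(5) loops in the diff


def expand_fields(fields, num_levels=NUM_LEVELS):
    """Return a new list with 'gt_segm' replaced by the SOLOv2 target names."""
    fields = list(fields)  # work on a copy so the caller's list is untouched
    if 'gt_segm' in fields:
        fields.remove('gt_segm')
        fields.extend(['fg_num', 'grid_offset'])
        for i in range(num_levels):
            fields.extend([
                'ins_label%d' % i, 'cate_label%d' % i, 'grid_order%d' % i
            ])
    return fields


original = ['image', 'im_id', 'gt_segm']
expanded = expand_fields(original)
print(expanded[:4])
print(len(expanded))
```

With the default of five levels the result holds 2 image fields, 2 global targets, and 15 per-level targets, and `original` still contains `'gt_segm'` afterwards.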
ppdet/modeling/mask_head/__init__.py
0 → 100644
View file @
fbe1d120
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

from . import solo_mask_head
from .solo_mask_head import *
ppdet/modeling/mask_head/solo_mask_head.py
0 → 100644
View file @
fbe1d120
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import fluid

from ppdet.core.workspace import register
from ppdet.modeling.ops import ConvNorm, DeformConvNorm

__all__ = ['SOLOv2MaskHead']


@register
class SOLOv2MaskHead(object):
    """
    SOLOv2MaskHead
    Args:
        out_channels (int): The channel number of output variable.
        start_level (int): The position where the input starts.
        end_level (int): The position where the input ends.
        num_classes (int): Number of classes in SOLOv2MaskHead output.
        use_dcn_in_tower: Whether to use dcn in tower or not.
    """
    __shared__ = ['num_classes']

    def __init__(self,
                 out_channels=128,
                 start_level=0,
                 end_level=3,
                 num_classes=81,
                 use_dcn_in_tower=False):
        super(SOLOv2MaskHead, self).__init__()
        assert start_level >= 0 and end_level >= start_level
        self.out_channels = out_channels
        self.start_level = start_level
        self.end_level = end_level
        self.num_classes = num_classes
        self.use_dcn_in_tower = use_dcn_in_tower
        self.conv_type = [ConvNorm, DeformConvNorm]

    def _convs_levels(self, conv_feat, level, name=None):
        conv_func = self.conv_type[0]
        if self.use_dcn_in_tower:
            conv_func = self.conv_type[1]

        if level == 0:
            return conv_func(
                input=conv_feat,
                num_filters=self.out_channels,
                filter_size=3,
                stride=1,
                norm_type='gn',
                norm_groups=32,
                freeze_norm=False,
                act='relu',
                initializer=fluid.initializer.NormalInitializer(scale=0.01),
                norm_name=name + '.conv' + str(level) + '.gn',
                name=name + '.conv' + str(level))

        for j in range(level):
            conv_feat = conv_func(
                input=conv_feat,
                num_filters=self.out_channels,
                filter_size=3,
                stride=1,
                norm_type='gn',
                norm_groups=32,
                freeze_norm=False,
                act='relu',
                initializer=fluid.initializer.NormalInitializer(scale=0.01),
                norm_name=name + '.conv' + str(j) + '.gn',
                name=name + '.conv' + str(j))
            conv_feat = fluid.layers.resize_bilinear(
                conv_feat,
                scale=2,
                name='upsample' + str(level) + str(j),
                align_corners=False,
                align_mode=0)
        return conv_feat

    def _conv_pred(self, conv_feat):
        conv_func = self.conv_type[0]
        if self.use_dcn_in_tower:
            conv_func = self.conv_type[1]
        conv_feat = conv_func(
            input=conv_feat,
            num_filters=self.num_classes,
            filter_size=1,
            stride=1,
            norm_type='gn',
            norm_groups=32,
            freeze_norm=False,
            act='relu',
            initializer=fluid.initializer.NormalInitializer(scale=0.01),
            norm_name='mask_feat_head.conv_pred.0.gn',
            name='mask_feat_head.conv_pred.0')
        return conv_feat

    def get_output(self, inputs, batch_size=1):
        """
        Get SOLOv2MaskHead output.
        Args:
            inputs (list[Variable]): Feature maps from the neck, one per level, each with shape [N, C, H, W]
            batch_size (int): batch size
        Returns:
            ins_pred (Variable): Output of SOLOv2MaskHead
        """
        range_level = self.end_level - self.start_level + 1
        feature_add_all_level = self._convs_levels(
            inputs[0], 0, name='mask_feat_head.convs_all_levels.0')
        for i in range(1, range_level):
            input_p = inputs[i]
            if i == 3:
                input_feat = input_p
                x_range = paddle.linspace(
                    -1, 1, fluid.layers.shape(input_feat)[-1], dtype='float32')
                y_range = paddle.linspace(
                    -1, 1, fluid.layers.shape(input_feat)[-2], dtype='float32')
                y, x = paddle.tensor.meshgrid([y_range, x_range])
                x = fluid.layers.unsqueeze(x, [0, 1])
                y = fluid.layers.unsqueeze(y, [0, 1])
                y = fluid.layers.expand(y, expand_times=[batch_size, 1, 1, 1])
                x = fluid.layers.expand(x, expand_times=[batch_size, 1, 1, 1])
                coord_feat = fluid.layers.concat([x, y], axis=1)
                input_p = fluid.layers.concat([input_p, coord_feat], axis=1)
            feature_add_all_level = fluid.layers.elementwise_add(
                feature_add_all_level,
                self._convs_levels(
                    input_p,
                    i,
                    name='mask_feat_head.convs_all_levels.{}'.format(i)))
        ins_pred = self._conv_pred(feature_add_all_level)

        return ins_pred
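
In `get_output`, the level with `i == 3` gets two extra channels holding normalized x and y coordinates in [-1, 1] before it goes through the conv tower. The sketch below reproduces just that coordinate grid in plain Python, without Paddle, to show what the `linspace`/`meshgrid` pair produces. `linspace` and `coord_channels` here are hypothetical stand-ins written for illustration, not the Paddle functions.

```python
def linspace(start, stop, num):
    """Evenly spaced values from start to stop inclusive (num >= 1)."""
    if num == 1:
        return [float(start)]
    step = (stop - start) / (num - 1)
    return [start + step * k for k in range(num)]


def coord_channels(height, width):
    """Return (x, y) grids of shape [height][width] with values in [-1, 1]."""
    x_range = linspace(-1.0, 1.0, width)
    y_range = linspace(-1.0, 1.0, height)
    x = [list(x_range) for _ in range(height)]  # x varies along the width
    y = [[v] * width for v in y_range]          # y varies along the height
    return x, y


x, y = coord_channels(2, 3)
print(x)
print(y)
```

In the diff the two grids are additionally given batch and channel axes and tiled `batch_size` times, so they can be concatenated with the feature map along the channel axis.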
ppdet/modeling/ops.py
View file @
fbe1d120
...
...
@@ -17,6 +17,7 @@ from numbers import Integral
 import math
 import six

+import paddle
 from paddle import fluid
 from paddle.fluid.layer_helper import LayerHelper
 from paddle.fluid.initializer import NumpyArrayInitializer
...
...
@@ -1263,27 +1264,27 @@ class LibraBBoxAssigner(object):
         rois = create_tmp_var(
             fluid.default_main_program(),
-            name=None, #'rois',
+            name=None,
             dtype='float32',
             shape=[-1, 4], )
         bbox_inside_weights = create_tmp_var(
             fluid.default_main_program(),
-            name=None, #'bbox_inside_weights',
+            name=None,
             dtype='float32',
             shape=[-1, 8 if self.is_cls_agnostic else self.class_nums * 4], )
         bbox_outside_weights = create_tmp_var(
             fluid.default_main_program(),
-            name=None, #'bbox_outside_weights',
+            name=None,
             dtype='float32',
             shape=[-1, 8 if self.is_cls_agnostic else self.class_nums * 4], )
         bbox_targets = create_tmp_var(
             fluid.default_main_program(),
-            name=None, #'bbox_targets',
+            name=None,
             dtype='float32',
             shape=[-1, 8 if self.is_cls_agnostic else self.class_nums * 4], )
         labels_int32 = create_tmp_var(
             fluid.default_main_program(),
-            name=None, #'labels_int32',
+            name=None,
             dtype='int32',
             shape=[-1, 1], )
...
...
@@ -1565,3 +1566,79 @@ class RetinaOutputDecoder(object):
         self.nms_top_k = pre_nms_top_n
         self.keep_top_k = detections_per_im
         self.nms_eta = nms_eta
+
+
+@register
+@serializable
+class MaskMatrixNMS(object):
+    """
+    Matrix NMS for multi-class masks.
+    Args:
+        kernel (str):  'linear' or 'gaussian'
+        sigma (float): std in gaussian method
+    Input:
+        seg_masks (Variable): shape (n, h, w), segmentation feature maps
+        cate_labels (Variable): shape (n), mask labels in descending order
+        cate_scores (Variable): shape (n), mask scores in descending order
+        sum_masks (Variable): The sum of seg_masks
+    Returns:
+        Variable: cate_scores_update, tensors of shape (n)
+    """
+
+    def __init__(self, kernel='gaussian', sigma=2.0):
+        super(MaskMatrixNMS, self).__init__()
+        self.kernel = kernel
+        self.sigma = sigma
+
+    def __call__(self, seg_masks, cate_labels, cate_scores, sum_masks=None):
+        n_samples = fluid.layers.shape(cate_labels)
+        seg_masks = fluid.layers.reshape(seg_masks, shape=(n_samples, -1))
+        # inter.
+        inter_matrix = paddle.mm(seg_masks,
+                                 fluid.layers.transpose(seg_masks, [1, 0]))
+        # union.
+        sum_masks_x = fluid.layers.reshape(
+            fluid.layers.expand(
+                sum_masks, expand_times=[n_samples]),
+            shape=[n_samples, n_samples])
+        # iou.
+        iou_matrix = (inter_matrix / (sum_masks_x + fluid.layers.transpose(
+            sum_masks_x, [1, 0]) - inter_matrix))
+        iou_matrix = paddle.tensor.triu(iou_matrix, diagonal=1)
+        # label_specific matrix.
+        cate_labels_x = fluid.layers.reshape(
+            fluid.layers.expand(
+                cate_labels, expand_times=[n_samples]),
+            shape=[n_samples, n_samples])
+        label_matrix = fluid.layers.cast(
+            (cate_labels_x == fluid.layers.transpose(cate_labels_x, [1, 0])),
+            'float32')
+        label_matrix = paddle.tensor.triu(label_matrix, diagonal=1)
+
+        # IoU compensation
+        compensate_iou = paddle.max((iou_matrix * label_matrix), axis=0)
+        compensate_iou = fluid.layers.reshape(
+            fluid.layers.expand(
+                compensate_iou, expand_times=[n_samples]),
+            shape=[n_samples, n_samples])
+        compensate_iou = fluid.layers.transpose(compensate_iou, [1, 0])
+
+        # IoU decay
+        decay_iou = iou_matrix * label_matrix
+
+        # matrix nms
+        if self.kernel == 'gaussian':
+            decay_matrix = fluid.layers.exp(-1 * self.sigma * (decay_iou**2))
+            compensate_matrix = fluid.layers.exp(-1 * self.sigma *
+                                                 (compensate_iou**2))
+            decay_coefficient = paddle.min(decay_matrix / compensate_matrix,
+                                           axis=0)
+        elif self.kernel == 'linear':
+            decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
+            decay_coefficient = paddle.min(decay_matrix, axis=0)
+        else:
+            raise NotImplementedError
+
+        # update the score.
+        cate_scores_update = cate_scores * decay_coefficient
+        return cate_scores_update
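
The `MaskMatrixNMS` class added above rescales each mask score by a decay factor computed from its IoU with higher-scored masks of the same label, rather than discarding masks one by one. As a rough aid to reading the tensor code, here is a minimal pure-Python sketch of the same computation on flattened binary masks that are already sorted by descending score. The function name `matrix_nms` and the list-based inputs are assumptions made for illustration; the zero-union guard is also an addition that the Paddle code does not have.

```python
import math


def matrix_nms(masks, labels, scores, kernel='gaussian', sigma=2.0):
    """Decay scores of overlapping same-label masks (inputs sorted by score)."""
    n = len(masks)
    areas = [sum(m) for m in masks]

    # Pairwise IoU, kept only above the diagonal and only for equal labels.
    iou = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            if labels[i] != labels[j]:
                continue
            inter = sum(a * b for a, b in zip(masks[i], masks[j]))
            union = areas[i] + areas[j] - inter
            iou[i][j] = inter / union if union > 0 else 0.0

    # compensate[i]: largest IoU between mask i and any higher-scored mask.
    compensate = [max(iou[k][i] for k in range(n)) for i in range(n)]

    def decay(i, j):
        if kernel == 'gaussian':
            return (math.exp(-sigma * iou[i][j] ** 2) /
                    math.exp(-sigma * compensate[i] ** 2))
        if kernel == 'linear':
            # Like the original, this divides by zero if compensate[i] == 1.
            return (1 - iou[i][j]) / (1 - compensate[i])
        raise NotImplementedError(kernel)

    # Each score is scaled by the smallest decay over all candidate suppressors.
    return [scores[j] * min(decay(i, j) for i in range(n)) for j in range(n)]


masks = [[1, 1, 1, 1], [1, 1, 0, 0], [1, 1, 1, 1]]
labels = [0, 0, 1]
scores = [0.9, 0.8, 0.7]
print(matrix_nms(masks, labels, scores))
```

In this example only the second mask is decayed: it shares a label with the first and overlaps it with IoU 0.5, while the third mask has a different label and keeps its score.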
ppdet/utils/coco_eval.py
View file @
fbe1d120
...
...
@@ -164,6 +164,47 @@ def mask_eval(results,
     cocoapi_eval(outfile, 'segm', coco_gt=coco_gt)


+def segm_eval(results, anno_file, outfile, save_only=False):
+    assert 'segm' in results[0]
+    assert outfile.endswith('.json')
+    from pycocotools.coco import COCO
+    coco_gt = COCO(anno_file)
+    clsid2catid = {i: v for i, v in enumerate(coco_gt.getCatIds())}
+    segm_results = []
+    for t in results:
+        im_id = int(t['im_id'][0][0])
+        segs = t['segm']
+        for mask in segs:
+            catid = int(clsid2catid[mask[0]])
+            masks = mask[1]
+            mask_score = masks[1]
+            segm = masks[0]
+            segm['counts'] = segm['counts'].decode('utf8')
+            coco_res = {
+                'image_id': im_id,
+                'category_id': catid,
+                'segmentation': segm,
+                'score': mask_score
+            }
+            segm_results.append(coco_res)
+
+    if len(segm_results) == 0:
+        logger.warning("The number of valid mask detected is zero.\n \
+            Please use reasonable model and check input data.")
+        return
+
+    with open(outfile, 'w') as f:
+        json.dump(segm_results, f)
+
+    if save_only:
+        logger.info('The mask result is saved to {} and do not '
+                    'evaluate the mAP.'.format(outfile))
+        return
+
+    map_stats = cocoapi_eval(outfile, 'segm', coco_gt=coco_gt)
+    return map_stats
+
+
 def cocoapi_eval(jsonfile,
                  style,
                  coco_gt=None,
...
...
@@ -374,6 +415,43 @@ def mask2out(results, clsid2catid, resolution, thresh_binarize=0.5):
     return segm_res


+def segm2out(results, clsid2catid, thresh_binarize=0.5):
+    import pycocotools.mask as mask_util
+    segm_res = []
+
+    # for each batch
+    for t in results:
+        segms = t['segm'][0]
+        clsid_labels = t['cate_label'][0]
+        clsid_scores = t['cate_score'][0]
+        lengths = segms.shape[0]
+        im_id = int(t['im_id'][0][0])
+        im_shape = t['im_shape'][0][0]
+        if lengths == 0 or segms is None:
+            continue
+        # for each sample
+        for i in range(lengths - 1):
+            im_h = int(im_shape[0])
+            im_w = int(im_shape[1])
+            clsid = int(clsid_labels[i])
+            catid = clsid2catid[clsid]
+            score = clsid_scores[i]
+            mask = segms[i]
+            segm = mask_util.encode(
+                np.array(
+                    mask[:, :, np.newaxis], order='F'))[0]
+            segm['counts'] = segm['counts'].decode('utf8')
+
+            coco_res = {
+                'image_id': im_id,
+                'category_id': catid,
+                'segmentation': segm,
+                'score': score
+            }
+            segm_res.append(coco_res)
+    return segm_res
+
+
 def expand_boxes(boxes, scale):
     """
     Expand an array of boxes by a given scale.
...
...
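
Both `segm_eval` and `segm2out` rely on `pycocotools.mask.encode`, which expects a Fortran-ordered (column-major) array; that is why the diff builds it with `order='F'`. To illustrate what run-length encoding in column-major order means, the sketch below computes the uncompressed run lengths for a small binary mask in plain Python. It is only an illustration under stated assumptions: `rle_counts` is a hypothetical helper, and the real library additionally compresses `counts` into a byte string, which this sketch does not attempt.

```python
def rle_counts(mask):
    """Uncompressed COCO-style run lengths for a 2-D binary mask (list of rows)."""
    height = len(mask)
    width = len(mask[0]) if height else 0
    counts = []
    current, run = 0, 0  # by convention the first run counts zeros
    for col in range(width):       # column-major traversal
        for row in range(height):
            value = 1 if mask[row][col] else 0
            if value != current:
                counts.append(run)
                current, run = value, 0
            run += 1
    counts.append(run)
    return {'size': [height, width], 'counts': counts}


print(rle_counts([[0, 1],
                  [1, 1]]))
```

Reading the 2x2 mask column by column gives the pixel sequence 0, 1, 1, 1, so the runs are one zero followed by three ones. A mask that starts with a 1 gets a leading run of length 0.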
ppdet/utils/eval_utils.py
View file @
fbe1d120
...
...
@@ -94,6 +94,25 @@ def clean_res(result, keep_name_list):
     return clean_result


+def get_masks(result):
+    import pycocotools.mask as mask_util
+    if result is None:
+        return {}
+    seg_pred = result['segm'][0].astype(np.uint8)
+    cate_label = result['cate_label'][0].astype(np.int)
+    cate_score = result['cate_score'][0].astype(np.float)
+    num_ins = seg_pred.shape[0]
+    masks = []
+    for idx in range(num_ins - 1):
+        cur_mask = seg_pred[idx, ...]
+        rle = mask_util.encode(
+            np.array(
+                cur_mask[:, :, np.newaxis], order='F'))[0]
+        rst = (rle, cate_score[idx])
+        masks.append([cate_label[idx], rst])
+    return masks
+
+
 def eval_run(exe,
              compile_program,
              loader,
...
...
@@ -163,11 +182,13 @@ def eval_run(exe,
                 corner_post_process(res, post_config, cfg.num_classes)
             if 'TTFNet' in cfg.architecture:
                 res['bbox'][1].append([len(res['bbox'][0])])
+            if 'segm' in res:
+                res['segm'] = get_masks(res)
             results.append(res)
             if iter_id % 100 == 0:
                 logger.info('Test iter {}'.format(iter_id))
             iter_id += 1
-            if len(res['bbox'][1]) == 0:
+            if 'bbox' not in res or len(res['bbox'][1]) == 0:
                 has_bbox = False
             images_num += len(res['bbox'][1][0]) if has_bbox else 1
     except (StopIteration, fluid.core.EOFException):
...
...
@@ -198,7 +219,7 @@ def eval_results(results,
     """Evaluation for evaluation program results"""
     box_ap_stats = []
     if metric == 'COCO':
-        from ppdet.utils.coco_eval import proposal_eval, bbox_eval, mask_eval
+        from ppdet.utils.coco_eval import proposal_eval, bbox_eval, mask_eval, segm_eval
         anno_file = dataset.get_anno()
         with_background = dataset.with_background
         if 'proposal' in results[0]:
...
...
@@ -225,6 +246,14 @@ def eval_results(results,
             output = os.path.join(output_directory, 'mask.json')
             mask_eval(
                 results, anno_file, output, resolution, save_only=save_only)
+        if 'segm' in results[0]:
+            output = 'segm.json'
+            if output_directory:
+                output = os.path.join(output_directory, output)
+            mask_ap_stats = segm_eval(
+                results, anno_file, output, save_only=save_only)
+            if len(box_ap_stats) == 0:
+                box_ap_stats = mask_ap_stats
     else:
         if 'accum_map' in results[-1]:
             res = np.mean(results[-1]['accum_map'][0])
...
...
tools/eval.py
View file @
fbe1d120
...
...
@@ -132,9 +132,6 @@ def main():
                                                          extra_keys)
         sub_eval_prog = sub_eval_prog.clone(True)

-    #if 'weights' in cfg:
-    # checkpoint.load_params(exe, sub_eval_prog, cfg.weights)
-
     # load model
     exe.run(startup_prog)
     if 'weights' in cfg:
...
...
@@ -146,7 +143,6 @@ def main():
     results = eval_run(exe, compile_program, loader, keys, values, cls, cfg,
                        sub_eval_prog, sub_keys, sub_values, resolution)

-    #print(cfg['EvalReader']['dataset'].__dict__)
     # evaluation
     # if map_type not set, use default 11point, only use in VOC eval
     map_type = cfg.map_type if 'map_type' in cfg else '11point'
...
...
tools/export_model.py
View file @
fbe1d120
...
...
@@ -46,11 +46,13 @@ TRT_MIN_SUBGRAPH = {
     'Face': 3,
     'TTFNet': 3,
     'FCOS': 3,
+    'SOLOv2': 3,
 }

 RESIZE_SCALE_SET = {
     'RCNN',
     'RetinaNet',
     'FCOS',
+    'SOLOv2',
 }
...
...