未验证 提交 f4618562 编写于 作者: G Guanghua Yu 提交者: GitHub

support infer for SOLOv2 (#1552)

* support infer for SOLOv2
上级 c72fea6a
...@@ -77,7 +77,7 @@ PaddleDetection模块化地实现了多种主流目标检测算法,提供了 ...@@ -77,7 +77,7 @@ PaddleDetection模块化地实现了多种主流目标检测算法,提供了
<li><b>Instance Segmentation</b></li> <li><b>Instance Segmentation</b></li>
<ul> <ul>
<li>Mask RCNN</li> <li>Mask RCNN</li>
<li>SOLOv2 is coming soon</li> <li>SOLOv2</li>
</ul> </ul>
</ul> </ul>
<ul> <ul>
......
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
SOLOv2 (Segmenting Objects by Locations) is a fast instance segmentation framework with strong performance. We reproduced the model of the paper, and improved and optimized the accuracy and speed of the SOLOv2. SOLOv2 (Segmenting Objects by Locations) is a fast instance segmentation framework with strong performance. We reproduced the model of the paper, and improved and optimized the accuracy and speed of the SOLOv2.
** Highlights: ** **Highlights:**
- Performance: `Light-R50-VD-DCN-FPN` model reached 38.6 FPS on single Tesla V100, and mask ap on the COCO-val dataset reached 38.8, which increased inference speed by 24%, mAP increased by 2.4 percentage points. - Performance: `Light-R50-VD-DCN-FPN` model reached 38.6 FPS on single Tesla V100, and mask ap on the COCO-val dataset reached 38.8, which increased inference speed by 24%, mAP increased by 2.4 percentage points.
- Training Time: The training time of the model of `solov2_r50_fpn_1x` on Tesla v100 with 8 GPU is only 10 hours. - Training Time: The training time of the model of `solov2_r50_fpn_1x` on Tesla v100 with 8 GPU is only 10 hours.
...@@ -15,7 +16,7 @@ SOLOv2 (Segmenting Objects by Locations) is a fast instance segmentation framewo ...@@ -15,7 +16,7 @@ SOLOv2 (Segmenting Objects by Locations) is a fast instance segmentation framewo
| :---------------------: | :-------------------: | :-----: | :------------: | :-----: | :---------: | :------------------------: | | :---------------------: | :-------------------: | :-----: | :------------: | :-----: | :---------: | :------------------------: |
| R50-FPN | False | 1x | 45.7ms | 35.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/solov2_r50_fpn_1x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/solov2/solov2_r50_fpn_1x.yml) | | R50-FPN | False | 1x | 45.7ms | 35.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/solov2_r50_fpn_1x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/solov2/solov2_r50_fpn_1x.yml) |
| R50-FPN | True | 3x | 45.7ms | 37.9 | [model](https://paddlemodels.bj.bcebos.com/object_detection/solov2_r50_fpn_3x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/solov2/solov2_r50_fpn_3x.yml) | | R50-FPN | True | 3x | 45.7ms | 37.9 | [model](https://paddlemodels.bj.bcebos.com/object_detection/solov2_r50_fpn_3x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/solov2/solov2_r50_fpn_3x.yml) |
| R101-VD-FPN | True | 3x | - | 42.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/solov2_r101_vd_fpn_3x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/solov2/solov2_r101_vd_fpn_3x.yml) | | R101-VD-FPN | True | 3x | 82.6ms | 42.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/solov2_r101_vd_fpn_3x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/solov2/solov2_r101_vd_fpn_3x.yml) |
## Enhanced model ## Enhanced model
| Backbone | Input size | Lr schd | Inf time (V100) | Mask AP | Download | Configs | | Backbone | Input size | Lr schd | Inf time (V100) | Mask AP | Download | Configs |
......
architecture: SOLOv2
use_gpu: true
max_iters: 270000
snapshot_iter: 30000
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
metric: COCO
weights: output/solov2_light_r50_vd_fpn_dcn_512_3x/model_final
num_classes: 81
use_ema: true
ema_decay: 0.9998
SOLOv2:
backbone: ResNet
fpn: FPN
bbox_head: SOLOv2Head
mask_head: SOLOv2MaskHead
ResNet:
depth: 50
feature_maps: [2, 3, 4, 5]
freeze_at: 2
norm_type: bn
dcn_v2_stages: [3, 4, 5]
variant: d
lr_mult_list: [0.05, 0.05, 0.1, 0.15]
FPN:
max_level: 6
min_level: 2
num_chan: 256
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
reverse_out: True
SOLOv2Head:
seg_feat_channels: 256
stacked_convs: 3
num_grids: [40, 36, 24, 16, 12]
kernel_out_channels: 128
solov2_loss: SOLOv2Loss
mask_nms: MaskMatrixNMS
dcn_v2_stages: [2,]
drop_block: True
SOLOv2MaskHead:
in_channels: 128
out_channels: 128
start_level: 0
end_level: 3
SOLOv2Loss:
ins_loss_weight: 3.0
focal_loss_gamma: 2.0
focal_loss_alpha: 0.25
MaskMatrixNMS:
pre_nms_top_n: 500
post_nms_top_n: 100
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [180000, 240000]
- !LinearWarmup
start_factor: 0.
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
_READER_: 'solov2_light_reader.yml'
...@@ -77,6 +77,7 @@ OptimizerBuilder: ...@@ -77,6 +77,7 @@ OptimizerBuilder:
_READER_: 'solov2_reader.yml' _READER_: 'solov2_reader.yml'
TrainReader: TrainReader:
batch_size: 2
sample_transforms: sample_transforms:
- !DecodeImage - !DecodeImage
to_rgb: true to_rgb: true
......
...@@ -71,6 +71,7 @@ OptimizerBuilder: ...@@ -71,6 +71,7 @@ OptimizerBuilder:
_READER_: 'solov2_reader.yml' _READER_: 'solov2_reader.yml'
TrainReader: TrainReader:
batch_size: 2
sample_transforms: sample_transforms:
- !DecodeImage - !DecodeImage
to_rgb: true to_rgb: true
......
...@@ -430,7 +430,7 @@ def segm2out(results, clsid2catid, thresh_binarize=0.5): ...@@ -430,7 +430,7 @@ def segm2out(results, clsid2catid, thresh_binarize=0.5):
# for each batch # for each batch
for t in results: for t in results:
segms = t['segm'][0] segms = t['segm'][0].astype(np.uint8)
clsid_labels = t['cate_label'][0] clsid_labels = t['cate_label'][0]
clsid_scores = t['cate_score'][0] clsid_scores = t['cate_score'][0]
lengths = segms.shape[0] lengths = segms.shape[0]
...@@ -443,7 +443,7 @@ def segm2out(results, clsid2catid, thresh_binarize=0.5): ...@@ -443,7 +443,7 @@ def segm2out(results, clsid2catid, thresh_binarize=0.5):
im_h = int(im_shape[0]) im_h = int(im_shape[0])
im_w = int(im_shape[1]) im_w = int(im_shape[1])
clsid = int(clsid_labels[i]) clsid = int(clsid_labels[i]) + 1
catid = clsid2catid[clsid] catid = clsid2catid[clsid]
score = clsid_scores[i] score = clsid_scores[i]
mask = segms[i] mask = segms[i]
......
...@@ -19,6 +19,8 @@ from __future__ import unicode_literals ...@@ -19,6 +19,8 @@ from __future__ import unicode_literals
import numpy as np import numpy as np
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
from scipy import ndimage
import cv2
from .colormap import colormap from .colormap import colormap
...@@ -31,6 +33,7 @@ def visualize_results(image, ...@@ -31,6 +33,7 @@ def visualize_results(image,
threshold=0.5, threshold=0.5,
bbox_results=None, bbox_results=None,
mask_results=None, mask_results=None,
segm_results=None,
lmk_results=None): lmk_results=None):
""" """
Visualize bbox and mask results Visualize bbox and mask results
...@@ -41,6 +44,8 @@ def visualize_results(image, ...@@ -41,6 +44,8 @@ def visualize_results(image,
image = draw_bbox(image, im_id, catid2name, bbox_results, threshold) image = draw_bbox(image, im_id, catid2name, bbox_results, threshold)
if lmk_results: if lmk_results:
image = draw_lmk(image, im_id, lmk_results, threshold) image = draw_lmk(image, im_id, lmk_results, threshold)
if segm_results:
image = draw_segm(image, im_id, catid2name, segm_results, threshold)
return image return image
...@@ -70,6 +75,67 @@ def draw_mask(image, im_id, segms, threshold, alpha=0.7): ...@@ -70,6 +75,67 @@ def draw_mask(image, im_id, segms, threshold, alpha=0.7):
return Image.fromarray(img_array.astype('uint8')) return Image.fromarray(img_array.astype('uint8'))
def draw_segm(image,
im_id,
catid2name,
segms,
threshold,
alpha=0.7,
draw_box=True):
"""
Draw segmentation on image
"""
mask_color_id = 0
w_ratio = .4
color_list = colormap(rgb=True)
img_array = np.array(image).astype('float32')
for dt in np.array(segms):
if im_id != dt['image_id']:
continue
segm, score, catid = dt['segmentation'], dt['score'], dt['category_id']
if score < threshold:
continue
import pycocotools.mask as mask_util
mask = mask_util.decode(segm) * 255
color_mask = color_list[mask_color_id % len(color_list), 0:3]
mask_color_id += 1
for c in range(3):
color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
idx = np.nonzero(mask)
img_array[idx[0], idx[1], :] *= 1.0 - alpha
img_array[idx[0], idx[1], :] += alpha * color_mask
if not draw_box:
center_y, center_x = ndimage.measurements.center_of_mass(mask)
label_text = "{}".format(catid2name[catid])
vis_pos = (max(int(center_x) - 10, 0), int(center_y))
cv2.putText(img_array, label_text, vis_pos,
cv2.FONT_HERSHEY_COMPLEX, 0.3, (255, 255, 255))
else:
mask = mask_util.decode(segm) * 255
sum_x = np.sum(mask, axis=0)
x = np.where(sum_x > 0.5)[0]
sum_y = np.sum(mask, axis=1)
y = np.where(sum_y > 0.5)[0]
x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1]
cv2.rectangle(img_array, (x0, y0), (x1, y1),
tuple(color_mask.astype('int32').tolist()), 1)
bbox_text = '%s %.2f' % (catid2name[catid], score)
t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
cv2.rectangle(img_array, (x0, y0), (x0 + t_size[0],
y0 - t_size[1] - 3),
tuple(color_mask.astype('int32').tolist()), -1)
cv2.putText(
img_array,
bbox_text, (x0, y0 - 2),
cv2.FONT_HERSHEY_SIMPLEX,
0.3, (0, 0, 0),
1,
lineType=cv2.LINE_AA)
return Image.fromarray(img_array.astype('uint8'))
def draw_bbox(image, im_id, catid2name, bboxes, threshold): def draw_bbox(image, im_id, catid2name, bboxes, threshold):
""" """
Draw bbox on image Draw bbox on image
......
...@@ -138,7 +138,7 @@ def main(): ...@@ -138,7 +138,7 @@ def main():
# parse dataset category # parse dataset category
if cfg.metric == 'COCO': if cfg.metric == 'COCO':
from ppdet.utils.coco_eval import bbox2out, mask2out, get_category_info from ppdet.utils.coco_eval import bbox2out, mask2out, segm2out, get_category_info
if cfg.metric == 'OID': if cfg.metric == 'OID':
from ppdet.utils.oid_eval import bbox2out, get_category_info from ppdet.utils.oid_eval import bbox2out, get_category_info
if cfg.metric == "VOC": if cfg.metric == "VOC":
...@@ -187,13 +187,15 @@ def main(): ...@@ -187,13 +187,15 @@ def main():
bbox_results = None bbox_results = None
mask_results = None mask_results = None
segm_results = None
lmk_results = None lmk_results = None
if 'bbox' in res: if 'bbox' in res:
bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized) bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized)
if 'mask' in res: if 'mask' in res:
mask_results = mask2out([res], clsid2catid, mask_results = mask2out([res], clsid2catid,
model.mask_head.resolution) model.mask_head.resolution)
if 'segm' in res:
segm_results = segm2out([res], clsid2catid)
if 'landmark' in res: if 'landmark' in res:
lmk_results = lmk2out([res], is_bbox_normalized) lmk_results = lmk2out([res], is_bbox_normalized)
...@@ -213,7 +215,7 @@ def main(): ...@@ -213,7 +215,7 @@ def main():
image = visualize_results(image, image = visualize_results(image,
int(im_id), catid2name, int(im_id), catid2name,
FLAGS.draw_threshold, bbox_results, FLAGS.draw_threshold, bbox_results,
mask_results, lmk_results) mask_results, segm_results, lmk_results)
# use VisualDL to log image with bbox # use VisualDL to log image with bbox
if FLAGS.use_vdl: if FLAGS.use_vdl:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册