Unverified commit f4618562 authored by Guanghua Yu, committed by GitHub

support infer for SOLOv2 (#1552)

* support infer for SOLOv2
Parent c72fea6a
......@@ -77,7 +77,7 @@ PaddleDetection modularly implements a variety of mainstream object detection algorithms and provides
<li><b>Instance Segmentation</b></li>
<ul>
<li>Mask RCNN</li>
- <li>SOLOv2 is coming soon</li>
+ <li>SOLOv2</li>
</ul>
</ul>
<ul>
......
......@@ -4,7 +4,8 @@
SOLOv2 (Segmenting Objects by Locations) is a fast instance segmentation framework with strong performance. We reproduced the model from the paper and further optimized the accuracy and speed of SOLOv2.
- ** Highlights: **
+ **Highlights:**
- Performance: The `Light-R50-VD-DCN-FPN` model reaches 38.6 FPS on a single Tesla V100 and 38.8 mask AP on the COCO val dataset, improving inference speed by 24% and mAP by 2.4 percentage points.
- Training Time: Training the `solov2_r50_fpn_1x` model takes only 10 hours on 8 Tesla V100 GPUs.
......@@ -15,7 +16,7 @@ SOLOv2 (Segmenting Objects by Locations) is a fast instance segmentation framewo
| :---------------------: | :-------------------: | :-----: | :------------: | :-----: | :---------: | :------------------------: |
| R50-FPN | False | 1x | 45.7ms | 35.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/solov2_r50_fpn_1x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/solov2/solov2_r50_fpn_1x.yml) |
| R50-FPN | True | 3x | 45.7ms | 37.9 | [model](https://paddlemodels.bj.bcebos.com/object_detection/solov2_r50_fpn_3x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/solov2/solov2_r50_fpn_3x.yml) |
- | R101-VD-FPN | True | 3x | - | 42.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/solov2_r101_vd_fpn_3x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/solov2/solov2_r101_vd_fpn_3x.yml) |
+ | R101-VD-FPN | True | 3x | 82.6ms | 42.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/solov2_r101_vd_fpn_3x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/solov2/solov2_r101_vd_fpn_3x.yml) |
## Enhanced model
| Backbone | Input size | Lr schd | Inf time (V100) | Mask AP | Download | Configs |
......
architecture: SOLOv2
use_gpu: true
max_iters: 270000
snapshot_iter: 30000
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
metric: COCO
weights: output/solov2_light_r50_vd_fpn_dcn_512_3x/model_final
num_classes: 81
use_ema: true
ema_decay: 0.9998

SOLOv2:
  backbone: ResNet
  fpn: FPN
  bbox_head: SOLOv2Head
  mask_head: SOLOv2MaskHead

ResNet:
  depth: 50
  feature_maps: [2, 3, 4, 5]
  freeze_at: 2
  norm_type: bn
  dcn_v2_stages: [3, 4, 5]
  variant: d
  lr_mult_list: [0.05, 0.05, 0.1, 0.15]

FPN:
  max_level: 6
  min_level: 2
  num_chan: 256
  spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
  reverse_out: True

SOLOv2Head:
  seg_feat_channels: 256
  stacked_convs: 3
  num_grids: [40, 36, 24, 16, 12]
  kernel_out_channels: 128
  solov2_loss: SOLOv2Loss
  mask_nms: MaskMatrixNMS
  dcn_v2_stages: [2,]
  drop_block: True

SOLOv2MaskHead:
  in_channels: 128
  out_channels: 128
  start_level: 0
  end_level: 3

SOLOv2Loss:
  ins_loss_weight: 3.0
  focal_loss_gamma: 2.0
  focal_loss_alpha: 0.25

MaskMatrixNMS:
  pre_nms_top_n: 500
  post_nms_top_n: 100

LearningRate:
  base_lr: 0.01
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [180000, 240000]
  - !LinearWarmup
    start_factor: 0.
    steps: 1000

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0001
    type: L2

_READER_: 'solov2_light_reader.yml'
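The `LearningRate` block above combines a 1000-step `LinearWarmup` (from a factor of 0) with a `PiecewiseDecay` that scales the base LR of 0.01 by 0.1 at iterations 180000 and 240000. Below is a minimal Python sketch of the schedule those values describe; it is only a reading of the config, not PaddleDetection's own scheduler code, and the function name is illustrative.

```python
# Sketch of the schedule described by the LearningRate block above
# (illustrative only; PaddleDetection builds this scheduler internally).
def solov2_light_lr(iteration,
                    base_lr=0.01,
                    warmup_steps=1000,
                    start_factor=0.0,
                    milestones=(180000, 240000),
                    gamma=0.1):
    """Return the learning rate used at a given training iteration."""
    if iteration < warmup_steps:
        # LinearWarmup: ramp from start_factor * base_lr up to base_lr.
        alpha = iteration / float(warmup_steps)
        return base_lr * (start_factor + (1.0 - start_factor) * alpha)
    # PiecewiseDecay: multiply by gamma at each milestone already passed.
    lr = base_lr
    for milestone in milestones:
        if iteration >= milestone:
            lr *= gamma
    return lr

# e.g. solov2_light_lr(500) -> 0.005, solov2_light_lr(200000) -> 0.001
```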
......@@ -77,6 +77,7 @@ OptimizerBuilder:
_READER_: 'solov2_reader.yml'
TrainReader:
  batch_size: 2
  sample_transforms:
  - !DecodeImage
    to_rgb: true
......
......@@ -71,6 +71,7 @@ OptimizerBuilder:
_READER_: 'solov2_reader.yml'
TrainReader:
  batch_size: 2
  sample_transforms:
  - !DecodeImage
    to_rgb: true
......
......@@ -430,7 +430,7 @@ def segm2out(results, clsid2catid, thresh_binarize=0.5):
    # for each batch
    for t in results:
-        segms = t['segm'][0]
+        segms = t['segm'][0].astype(np.uint8)
        clsid_labels = t['cate_label'][0]
        clsid_scores = t['cate_score'][0]
        lengths = segms.shape[0]
......@@ -443,7 +443,7 @@ def segm2out(results, clsid2catid, thresh_binarize=0.5):
            im_h = int(im_shape[0])
            im_w = int(im_shape[1])
-            clsid = int(clsid_labels[i])
+            clsid = int(clsid_labels[i]) + 1
            catid = clsid2catid[clsid]
            score = clsid_scores[i]
            mask = segms[i]
......
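These two changes make `segm2out` emit COCO-style instance results: masks are cast to `np.uint8` before RLE encoding, and the class index is shifted by one, presumably because index 0 is reserved for background in the 81-class setup (`num_classes: 81`). The following is a hedged sketch of how each instance is packed into a result dict; `pack_instance` is a hypothetical helper for illustration, and the real loop in `segm2out` also uses `im_h`/`im_w` and `thresh_binarize`, which are omitted here.

```python
# Hedged sketch of the per-instance packing inside segm2out (simplified;
# pack_instance is a hypothetical name, not a function from the PR).
import numpy as np
import pycocotools.mask as mask_util

def pack_instance(mask, clsid, score, im_id, clsid2catid):
    """Encode one uint8 binary mask as a COCO-style RLE result dict."""
    rle = mask_util.encode(np.asfortranarray(mask.astype(np.uint8)))
    rle['counts'] = rle['counts'].decode('utf-8')   # make it JSON-serializable
    return {
        'image_id': im_id,
        'category_id': clsid2catid[clsid + 1],      # shift past background index 0
        'segmentation': rle,
        'score': float(score),
    }
```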
......@@ -19,6 +19,8 @@ from __future__ import unicode_literals
import numpy as np
from PIL import Image, ImageDraw
from scipy import ndimage
import cv2
from .colormap import colormap
......@@ -31,6 +33,7 @@ def visualize_results(image,
                      threshold=0.5,
                      bbox_results=None,
                      mask_results=None,
                      segm_results=None,
                      lmk_results=None):
    """
    Visualize bbox and mask results
......@@ -41,6 +44,8 @@ def visualize_results(image,
        image = draw_bbox(image, im_id, catid2name, bbox_results, threshold)
    if lmk_results:
        image = draw_lmk(image, im_id, lmk_results, threshold)
    if segm_results:
        image = draw_segm(image, im_id, catid2name, segm_results, threshold)
    return image
......@@ -70,6 +75,67 @@ def draw_mask(image, im_id, segms, threshold, alpha=0.7):
    return Image.fromarray(img_array.astype('uint8'))


def draw_segm(image,
              im_id,
              catid2name,
              segms,
              threshold,
              alpha=0.7,
              draw_box=True):
    """
    Draw segmentation on image
    """
    mask_color_id = 0
    w_ratio = .4
    color_list = colormap(rgb=True)
    img_array = np.array(image).astype('float32')
    for dt in np.array(segms):
        if im_id != dt['image_id']:
            continue
        segm, score, catid = dt['segmentation'], dt['score'], dt['category_id']
        if score < threshold:
            continue
        import pycocotools.mask as mask_util
        mask = mask_util.decode(segm) * 255
        color_mask = color_list[mask_color_id % len(color_list), 0:3]
        mask_color_id += 1
        for c in range(3):
            color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
        idx = np.nonzero(mask)
        img_array[idx[0], idx[1], :] *= 1.0 - alpha
        img_array[idx[0], idx[1], :] += alpha * color_mask
        if not draw_box:
            center_y, center_x = ndimage.measurements.center_of_mass(mask)
            label_text = "{}".format(catid2name[catid])
            vis_pos = (max(int(center_x) - 10, 0), int(center_y))
            cv2.putText(img_array, label_text, vis_pos,
                        cv2.FONT_HERSHEY_COMPLEX, 0.3, (255, 255, 255))
        else:
            mask = mask_util.decode(segm) * 255
            sum_x = np.sum(mask, axis=0)
            x = np.where(sum_x > 0.5)[0]
            sum_y = np.sum(mask, axis=1)
            y = np.where(sum_y > 0.5)[0]
            x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1]
            cv2.rectangle(img_array, (x0, y0), (x1, y1),
                          tuple(color_mask.astype('int32').tolist()), 1)
            bbox_text = '%s %.2f' % (catid2name[catid], score)
            t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
            cv2.rectangle(img_array, (x0, y0), (x0 + t_size[0],
                                                y0 - t_size[1] - 3),
                          tuple(color_mask.astype('int32').tolist()), -1)
            cv2.putText(
                img_array,
                bbox_text, (x0, y0 - 2),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.3, (0, 0, 0),
                1,
                lineType=cv2.LINE_AA)
    return Image.fromarray(img_array.astype('uint8'))


def draw_bbox(image, im_id, catid2name, bboxes, threshold):
    """
    Draw bbox on image
......
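To exercise the new `draw_segm` path in isolation, here is a self-contained usage sketch with fabricated data: the dummy mask, label map, and output file name are placeholders, and the `ppdet.utils.visualizer` module path is assumed rather than taken from this diff.

```python
# Self-contained smoke test for the new segm visualization path (dummy data only).
import numpy as np
import pycocotools.mask as mask_util
from PIL import Image
from ppdet.utils.visualizer import visualize_results  # assumed module path

# Fabricate one square "instance" and RLE-encode it, as segm2out would.
mask = np.zeros((128, 128), dtype=np.uint8)
mask[32:96, 32:96] = 1
segm_results = [{
    'image_id': 0,
    'category_id': 1,
    'segmentation': mask_util.encode(np.asfortranarray(mask)),
    'score': 0.9,
}]
catid2name = {1: 'person'}  # placeholder label map

image = Image.new('RGB', (128, 128), (128, 128, 128))
image = visualize_results(image, 0, catid2name, 0.5,
                          bbox_results=None, mask_results=None,
                          segm_results=segm_results, lmk_results=None)
image.save('solov2_segm_vis.jpg')  # placeholder output path
```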
......@@ -138,7 +138,7 @@ def main():
    # parse dataset category
    if cfg.metric == 'COCO':
-        from ppdet.utils.coco_eval import bbox2out, mask2out, get_category_info
+        from ppdet.utils.coco_eval import bbox2out, mask2out, segm2out, get_category_info
    if cfg.metric == 'OID':
        from ppdet.utils.oid_eval import bbox2out, get_category_info
    if cfg.metric == "VOC":
......@@ -187,13 +187,15 @@ def main():
        bbox_results = None
        mask_results = None
        segm_results = None
        lmk_results = None
        if 'bbox' in res:
            bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized)
        if 'mask' in res:
            mask_results = mask2out([res], clsid2catid,
                                    model.mask_head.resolution)
        if 'segm' in res:
            segm_results = segm2out([res], clsid2catid)
        if 'landmark' in res:
            lmk_results = lmk2out([res], is_bbox_normalized)
......@@ -213,7 +215,7 @@ def main():
            image = visualize_results(image,
                                      int(im_id), catid2name,
                                      FLAGS.draw_threshold, bbox_results,
-                                      mask_results, lmk_results)
+                                      mask_results, segm_results, lmk_results)
            # use VisualDL to log image with bbox
            if FLAGS.use_vdl:
......
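The context above mentions logging the visualized image with VisualDL when `--use_vdl` is set. A hedged illustration of what that could look like with VisualDL 2.x's `LogWriter.add_image` follows; the logdir, tag, and stand-in image are placeholders, not code from `tools/infer.py`.

```python
# Hedged illustration of VisualDL image logging (VisualDL 2.x API assumed);
# logdir, tag, and the stand-in image below are placeholders.
import numpy as np
from PIL import Image
from visualdl import LogWriter

image = Image.new('RGB', (64, 64))                  # stand-in for the visualized PIL image
vdl_writer = LogWriter(logdir='vdl_log_dir/image')  # placeholder logdir

vdl_writer.add_image(
    tag='infer/visualized_image',                   # placeholder tag
    img=np.array(image),                            # HWC uint8 array
    step=0)
```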