未验证 提交 dd1e05d7 编写于 作者: G Guanghua Yu 提交者: GitHub

add solov2 model (#1412)

* add solov2 model

* fix train batch size

* add solov2_r101_vd_fpn_3x

* fix batch size and update modelzoo

* refactor code of solov2

* fix deploy/python
上级 41966e49
# SOLOv2 for instance segmentation
## Introduction
SOLOv2 (Segmenting Objects by Locations) is a fast instance segmentation framework with strong performance. We reproduced the model of the paper, and improved and optimized the accuracy and speed of the SOLOv2.
** Highlights: **
- Performance: `Light-R50-VD-DCN-FPN` model reached 38.6 FPS on single Tesla V100, and mask ap on the COCO-val dataset reached 38.8, which increased inference speed by 24%, mAP increased by 2.4 percentage points.
- Training Time: The training time of the model of `solov2_r50_fpn_1x` on Tesla v100 with 8 GPU is only 10 hours.
## Model Zoo
| Backbone | Multi-scale training | Lr schd | Inf time (V100) | Mask AP | Download | Configs |
| :---------------------: | :-------------------: | :-----: | :------------: | :-----: | :---------: | :------------------------: |
| R50-FPN | False | 1x | 45.7ms | 35.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/solov2_r50_fpn_1x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/solov2/solov2_r50_fpn_1x.yml) |
| R50-FPN | True | 3x | 45.7ms | 37.9 | [model](https://paddlemodels.bj.bcebos.com/object_detection/solov2_r50_fpn_3x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/solov2/solov2_r50_fpn_3x.yml) |
| R101-VD-FPN | True | 3x | - | 42.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/solov2_r101_vd_fpn_3x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/solov2/solov2_r101_vd_fpn_3x.yml) |
## Enhanced model
| Backbone | Input size | Lr schd | Inf time (V100) | Mask AP | Download | Configs |
| :---------------------: | :-------------------: | :-----: | :------------: | :-----: | :---------: | :------------------------: |
| Light-R50-VD-DCN-FPN | 512 | 3x | 25.9ms | 38.8 | [model](https://paddlemodels.bj.bcebos.com/object_detection/solov2_light_r50_vd_fpn_dcn_512_3x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/solov2/solov2_light_r50_vd_fpn_dcn_512_3x.yml) |
## Citations
```
@article{wang2020solov2,
title={SOLOv2: Dynamic, Faster and Stronger},
author={Wang, Xinlong and Zhang, Rufeng and Kong, Tao and Li, Lei and Shen, Chunhua},
journal={arXiv preprint arXiv:2003.10152},
year={2020}
}
```
TrainReader:
batch_size: 2
worker_num: 2
inputs_def:
fields: ['image', 'im_id', 'gt_segm']
dataset:
!COCODataSet
dataset_dir: dataset/coco
anno_path: annotations/instances_train2017.json
image_dir: train2017
sample_transforms:
- !DecodeImage
to_rgb: true
- !Poly2Mask {}
- !ColorDistort {}
- !RandomCrop
is_mask_crop: True
- !ResizeImage
target_size: [352, 384, 416, 448, 480, 512]
max_size: 852
interp: 1
use_cv2: true
resize_box: true
- !RandomFlipImage
prob: 0.5
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !Permute
to_bgr: false
channel_first: true
batch_transforms:
- !PadBatch
pad_to_stride: 32
- !Gt2Solov2Target
num_grids: [40, 36, 24, 16, 12]
scale_ranges: [[1, 64], [32, 128], [64, 256], [128, 512], [256, 2048]]
coord_sigma: 0.2
shuffle: True
EvalReader:
inputs_def:
fields: ['image', 'im_info', 'im_id']
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
sample_transforms:
- !DecodeImage
to_rgb: true
- !ResizeImage
interp: 1
max_size: 852
target_size: 512
use_cv2: true
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: false
# only support batch_size=1 when evaluation
batch_size: 1
shuffle: false
drop_last: false
drop_empty: false
worker_num: 2
TestReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'im_shape']
dataset:
!ImageFolder
anno_path: dataset/coco/annotations/instances_val2017.json
sample_transforms:
- !DecodeImage
to_rgb: true
- !ResizeImage
interp: 1
max_size: 852
target_size: 512
use_cv2: true
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: false
architecture: SOLOv2
use_gpu: true
max_iters: 270000
snapshot_iter: 30000
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar
metric: COCO
weights: output/solov2_r101_vd_fpn_3x/model_final
num_classes: 81
use_ema: true
ema_decay: 0.9998
SOLOv2:
backbone: ResNet
fpn: FPN
bbox_head: SOLOv2Head
mask_head: SOLOv2MaskHead
ResNet:
depth: 101
feature_maps: [2, 3, 4, 5]
freeze_at: 2
norm_type: bn
dcn_v2_stages: [3, 4, 5]
variant: d
FPN:
max_level: 6
min_level: 2
num_chan: 256
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
reverse_out: True
SOLOv2Head:
seg_feat_channels: 512
stacked_convs: 4
num_grids: [40, 36, 24, 16, 12]
kernel_out_channels: 256
solov2_loss: SOLOv2Loss
mask_nms: MaskMatrixNMS
dcn_v2_stages: [0, 1, 2, 3]
SOLOv2MaskHead:
in_channels: 128
out_channels: 256
start_level: 0
end_level: 3
use_dcn_in_tower: True
SOLOv2Loss:
ins_loss_weight: 3.0
focal_loss_gamma: 2.0
focal_loss_alpha: 0.25
MaskMatrixNMS:
pre_nms_top_n: 500
post_nms_top_n: 100
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [180000, 240000]
- !LinearWarmup
start_factor: 0.
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
_READER_: 'solov2_reader.yml'
TrainReader:
sample_transforms:
- !DecodeImage
to_rgb: true
- !Poly2Mask {}
- !ResizeImage
target_size: [640, 672, 704, 736, 768, 800]
max_size: 1333
interp: 1
use_cv2: true
resize_box: true
- !RandomFlipImage
prob: 0.5
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !Permute
to_bgr: false
channel_first: true
batch_transforms:
- !PadBatch
pad_to_stride: 32
- !Gt2Solov2Target
num_grids: [40, 36, 24, 16, 12]
scale_ranges: [[1, 96], [48, 192], [96, 384], [192, 768], [384, 2048]]
coord_sigma: 0.2
architecture: SOLOv2
use_gpu: true
max_iters: 90000
snapshot_iter: 10000
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO
weights: output/solov2_r50_fpn_1x/model_final
num_classes: 81
SOLOv2:
backbone: ResNet
fpn: FPN
bbox_head: SOLOv2Head
mask_head: SOLOv2MaskHead
ResNet:
depth: 50
feature_maps: [2, 3, 4, 5]
freeze_at: 2
norm_type: bn
FPN:
max_level: 6
min_level: 2
num_chan: 256
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
reverse_out: True
SOLOv2Head:
seg_feat_channels: 512
stacked_convs: 4
num_grids: [40, 36, 24, 16, 12]
kernel_out_channels: 256
solov2_loss: SOLOv2Loss
mask_nms: MaskMatrixNMS
SOLOv2MaskHead:
in_channels: 128
out_channels: 256
start_level: 0
end_level: 3
SOLOv2Loss:
ins_loss_weight: 3.0
focal_loss_gamma: 2.0
focal_loss_alpha: 0.25
MaskMatrixNMS:
pre_nms_top_n: 500
post_nms_top_n: 100
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [60000, 80000]
- !LinearWarmup
start_factor: 0.
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
_READER_: 'solov2_reader.yml'
architecture: SOLOv2
use_gpu: true
max_iters: 270000
snapshot_iter: 30000
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO
weights: output/solov2/solov2_r50_fpn_3x/model_final
num_classes: 81
SOLOv2:
backbone: ResNet
fpn: FPN
bbox_head: SOLOv2Head
mask_head: SOLOv2MaskHead
ResNet:
depth: 50
feature_maps: [2, 3, 4, 5]
freeze_at: 2
norm_type: bn
FPN:
max_level: 6
min_level: 2
num_chan: 256
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
reverse_out: True
SOLOv2Head:
seg_feat_channels: 512
stacked_convs: 4
num_grids: [40, 36, 24, 16, 12]
kernel_out_channels: 256
solov2_loss: SOLOv2Loss
mask_nms: MaskMatrixNMS
SOLOv2MaskHead:
in_channels: 128
out_channels: 256
start_level: 0
end_level: 3
SOLOv2Loss:
ins_loss_weight: 3.0
focal_loss_gamma: 2.0
focal_loss_alpha: 0.25
MaskMatrixNMS:
pre_nms_top_n: 500
post_nms_top_n: 100
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [180000, 240000]
- !LinearWarmup
start_factor: 0.
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
_READER_: 'solov2_reader.yml'
TrainReader:
sample_transforms:
- !DecodeImage
to_rgb: true
- !Poly2Mask {}
- !ResizeImage
target_size: [640, 672, 704, 736, 768, 800]
max_size: 1333
interp: 1
use_cv2: true
resize_box: true
- !RandomFlipImage
prob: 0.5
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !Permute
to_bgr: false
channel_first: true
batch_transforms:
- !PadBatch
pad_to_stride: 32
- !Gt2Solov2Target
num_grids: [40, 36, 24, 16, 12]
scale_ranges: [[1, 96], [48, 192], [96, 384], [192, 768], [384, 2048]]
coord_sigma: 0.2
TrainReader:
batch_size: 2
worker_num: 2
inputs_def:
fields: ['image', 'im_id', 'gt_segm']
dataset:
!COCODataSet
dataset_dir: dataset/coco
anno_path: annotations/instances_train2017.json
image_dir: train2017
sample_transforms:
- !DecodeImage
to_rgb: true
- !Poly2Mask {}
- !ResizeImage
target_size: 800
max_size: 1333
interp: 1
use_cv2: true
resize_box: true
- !RandomFlipImage
prob: 0.5
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !Permute
to_bgr: false
channel_first: true
batch_transforms:
- !PadBatch
pad_to_stride: 32
- !Gt2Solov2Target
num_grids: [40, 36, 24, 16, 12]
scale_ranges: [[1, 96], [48, 192], [96, 384], [192, 768], [384, 2048]]
coord_sigma: 0.2
shuffle: True
EvalReader:
inputs_def:
fields: ['image', 'im_info', 'im_id']
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
sample_transforms:
- !DecodeImage
to_rgb: true
- !ResizeImage
interp: 1
max_size: 1333
target_size: 800
use_cv2: true
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: false
# only support batch_size=1 when evaluation
batch_size: 1
shuffle: false
drop_last: false
drop_empty: false
worker_num: 2
TestReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'im_shape']
dataset:
!ImageFolder
anno_path: dataset/coco/annotations/instances_val2017.json
sample_transforms:
- !DecodeImage
to_rgb: true
- !ResizeImage
interp: 1
max_size: 1333
target_size: 800
use_cv2: true
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: false
此差异已折叠。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from PIL import Image
import cv2
import numpy as np
# Global dictionary
RESIZE_SCALE_SET = {
'RCNN',
'RetinaNet',
'FCOS',
'SOLOv2',
}
def decode_image(im_file, im_info):
"""read rgb image
Args:
im_file (str/np.ndarray): path of image/ np.ndarray read by cv2
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
if isinstance(im_file, str):
with open(im_file, 'rb') as f:
im_read = f.read()
data = np.frombuffer(im_read, dtype='uint8')
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im_info['origin_shape'] = im.shape[:2]
im_info['resize_shape'] = im.shape[:2]
else:
im = im_file
im_info['origin_shape'] = im.shape[:2]
im_info['resize_shape'] = im.shape[:2]
return im, im_info
class Resize(object):
"""resize image by target_size and max_size
Args:
arch (str): model type
target_size (int): the target size of image
max_size (int): the max size of image
use_cv2 (bool): whether us cv2
image_shape (list): input shape of model
interp (int): method of resize
"""
def __init__(self,
arch,
target_size,
max_size,
use_cv2=True,
image_shape=None,
interp=cv2.INTER_LINEAR,
resize_box=False):
self.target_size = target_size
self.max_size = max_size
self.image_shape = image_shape
self.arch = arch
self.use_cv2 = use_cv2
self.interp = interp
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im_channel = im.shape[2]
im_scale_x, im_scale_y = self.generate_scale(im)
if self.use_cv2:
im = cv2.resize(
im,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
else:
resize_w = int(im_scale_x * float(im.shape[1]))
resize_h = int(im_scale_y * float(im.shape[0]))
if self.max_size != 0:
raise TypeError(
'If you set max_size to cap the maximum size of image,'
'please set use_cv2 to True to resize the image.')
im = im.astype('uint8')
im = Image.fromarray(im)
im = im.resize((int(resize_w), int(resize_h)), self.interp)
im = np.array(im)
# padding im when image_shape fixed by infer_cfg.yml
if self.max_size != 0 and self.image_shape is not None:
padding_im = np.zeros(
(self.max_size, self.max_size, im_channel), dtype=np.float32)
im_h, im_w = im.shape[:2]
padding_im[:im_h, :im_w, :] = im
im = padding_im
im_info['scale'] = [im_scale_x, im_scale_y]
im_info['resize_shape'] = im.shape[:2]
return im, im_info
def generate_scale(self, im):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
im_c = im.shape[2]
if self.max_size != 0 and self.arch in RESIZE_SCALE_SET:
im_size_min = np.min(origin_shape[0:2])
im_size_max = np.max(origin_shape[0:2])
im_scale = float(self.target_size) / float(im_size_min)
if np.round(im_scale * im_size_max) > self.max_size:
im_scale = float(self.max_size) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
im_scale_x = float(self.target_size) / float(origin_shape[1])
im_scale_y = float(self.target_size) / float(origin_shape[0])
return im_scale_x, im_scale_y
class Normalize(object):
"""normalize image
Args:
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
is_channel_first (bool): if True: image shape is CHW, else: HWC
"""
def __init__(self, mean, std, is_scale=True, is_channel_first=False):
self.mean = mean
self.std = std
self.is_scale = is_scale
self.is_channel_first = is_channel_first
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
if self.is_channel_first:
mean = np.array(self.mean)[:, np.newaxis, np.newaxis]
std = np.array(self.std)[:, np.newaxis, np.newaxis]
else:
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
if self.is_scale:
im = im / 255.0
im -= mean
im /= std
return im, im_info
class Permute(object):
"""permute image
Args:
to_bgr (bool): whether convert RGB to BGR
channel_first (bool): whether convert HWC to CHW
"""
def __init__(self, to_bgr=False, channel_first=True):
self.to_bgr = to_bgr
self.channel_first = channel_first
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
if self.channel_first:
im = im.transpose((2, 0, 1)).copy()
if self.to_bgr:
im = im[[2, 1, 0], :, :]
return im, im_info
class PadStride(object):
""" padding image for model with FPN
Args:
stride (bool): model with FPN need image shape % stride == 0
"""
def __init__(self, stride=0):
self.coarsest_stride = stride
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
coarsest_stride = self.coarsest_stride
if coarsest_stride == 0:
return im
im_c, im_h, im_w = im.shape
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
im_info['pad_shape'] = padding_im.shape[1:]
return padding_im, im_info
def preprocess(im, preprocess_ops):
# process image by preprocess_ops
im_info = {
'scale': [1., 1.],
'origin_shape': None,
'resize_shape': None,
'pad_shape': None,
}
im, im_info = decode_image(im, im_info)
for operator in preprocess_ops:
im, im_info = operator(im, im_info)
im = np.array((im, )).astype('float32')
return im, im_info
......@@ -18,20 +18,22 @@ from __future__ import division
import cv2
import numpy as np
from PIL import Image, ImageDraw
from scipy import ndimage
def visualize_box_mask(im, results, labels, mask_resolution=14):
"""
def visualize_box_mask(im, results, labels, mask_resolution=14, threshold=0.5):
"""
Args:
im (str/np.ndarray): path of image/np.ndarray read by cv2
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
MaskRCNN's results include 'masks': np.ndarray:
shape:[N, class_num, mask_resolution, mask_resolution]
MaskRCNN's results include 'masks': np.ndarray:
shape:[N, class_num, mask_resolution, mask_resolution]
labels (list): labels:['class1', ..., 'classn']
mask_resolution (int): shape of a mask is:[mask_resolution, mask_resolution]
threshold (float): Threshold of score.
Returns:
im (PIL.Image.Image): visualized image
im (PIL.Image.Image): visualized image
"""
if isinstance(im, str):
im = Image.open(im).convert('RGB')
......@@ -46,15 +48,23 @@ def visualize_box_mask(im, results, labels, mask_resolution=14):
resolution=mask_resolution)
if 'boxes' in results:
im = draw_box(im, results['boxes'], labels)
if 'segm' in results:
im = draw_segm(
im,
results['segm'],
results['label'],
results['score'],
labels,
threshold=threshold)
return im
def get_color_map_list(num_classes):
"""
"""
Args:
num_classes (int): number of class
Returns:
color_map (list): RGB color list
color_map (list): RGB color list
"""
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
......@@ -71,9 +81,9 @@ def get_color_map_list(num_classes):
def expand_boxes(boxes, scale=0.0):
"""
"""
Args:
boxes (np.ndarray): shape:[N,4], N:number of box
boxes (np.ndarray): shape:[N,4], N:number of box,
matix element:[x_min, y_min, x_max, y_max]
scale (float): scale of boxes
Returns:
......@@ -94,17 +104,17 @@ def expand_boxes(boxes, scale=0.0):
def draw_mask(im, np_boxes, np_masks, labels, resolution=14, threshold=0.5):
"""
"""
Args:
im (PIL.Image.Image): PIL image
np_boxes (np.ndarray): shape:[N,6], N: number of box
np_boxes (np.ndarray): shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
np_masks (np.ndarray): shape:[N, class_num, resolution, resolution]
labels (list): labels:['class1', ..., 'classn']
resolution (int): shape of a mask is:[resolution, resolution]
threshold (float): threshold of mask
Returns:
im (PIL.Image.Image): visualized image
im (PIL.Image.Image): visualized image
"""
color_list = get_color_map_list(len(labels))
scale = (resolution + 2.0) / resolution
......@@ -149,14 +159,14 @@ def draw_mask(im, np_boxes, np_masks, labels, resolution=14, threshold=0.5):
def draw_box(im, np_boxes, labels):
"""
"""
Args:
im (PIL.Image.Image): PIL image
np_boxes (np.ndarray): shape:[N,6], N: number of box
np_boxes (np.ndarray): shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
labels (list): labels:['class1', ..., 'classn']
Returns:
im (PIL.Image.Image): visualized image
im (PIL.Image.Image): visualized image
"""
draw_thickness = min(im.size) // 320
draw = ImageDraw.Draw(im)
......@@ -186,3 +196,41 @@ def draw_box(im, np_boxes, labels):
[(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
return im
def draw_segm(im,
np_segms,
np_label,
np_score,
labels,
threshold=0.5,
alpha=0.7):
"""
Draw segmentation on image
"""
mask_color_id = 0
w_ratio = .4
color_list = get_color_map_list(len(labels))
im = np.array(im).astype('float32')
clsid2color = {}
np_segms = np_segms.astype(np.uint8)
for i in range(np_segms.shape[0]):
mask, score, clsid = np_segms[i], np_score[i], np_label[i] + 1
if score < threshold:
continue
if clsid not in clsid2color:
clsid2color[clsid] = color_list[clsid]
color_mask = clsid2color[clsid]
for c in range(3):
color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
idx = np.nonzero(mask)
color_mask = np.array(color_mask)
im[idx[0], idx[1], :] *= 1.0 - alpha
im[idx[0], idx[1], :] += alpha * color_mask
center_y, center_x = ndimage.measurements.center_of_mass(mask)
label_text = "{}".format(labels[clsid])
vis_pos = (max(int(center_x) - 10, 0), int(center_y))
cv2.putText(im, label_text, vis_pos, cv2.FONT_HERSHEY_COMPLEX, 0.3,
(255, 255, 255))
return Image.fromarray(im.astype('uint8'))
......@@ -24,6 +24,7 @@ except Exception:
import logging
import cv2
import numpy as np
from scipy import ndimage
from .operators import register_op, BaseOperator
from .op_helper import jaccard_overlap, gaussian2D
......@@ -37,6 +38,7 @@ __all__ = [
'Gt2YoloTarget',
'Gt2FCOSTarget',
'Gt2TTFTarget',
'Gt2Solov2Target',
]
......@@ -88,6 +90,13 @@ class PadBatch(BaseOperator):
(1, max_shape[1], max_shape[2]), dtype=np.float32)
padding_sem[:, :im_h, :im_w] = semantic
data['semantic'] = padding_sem
if 'gt_segm' in data.keys() and data['gt_segm'] is not None:
gt_segm = data['gt_segm']
padding_segm = np.zeros(
(gt_segm.shape[0], max_shape[1], max_shape[2]),
dtype=np.uint8)
padding_segm[:, :im_h, :im_w] = gt_segm
data['gt_segm'] = padding_segm
return samples
......@@ -590,3 +599,154 @@ class Gt2TTFTarget(BaseOperator):
heatmap[y - top:y + bottom, x - left:x + right] = np.maximum(
masked_heatmap, masked_gaussian)
return heatmap
@register_op
class Gt2Solov2Target(BaseOperator):
"""Assign mask target and labels in SOLOv2 network.
Args:
num_grids (list): The list of feature map grids size.
scale_ranges (list): The list of mask boundary range.
coord_sigma (float): The coefficient of coordinate area length.
sampling_ratio (float): The ratio of down sampling.
"""
def __init__(self,
num_grids=[40, 36, 24, 16, 12],
scale_ranges=[[1, 96], [48, 192], [96, 384], [192, 768],
[384, 2048]],
coord_sigma=0.2,
sampling_ratio=4.0):
super(Gt2Solov2Target, self).__init__()
self.num_grids = num_grids
self.scale_ranges = scale_ranges
self.coord_sigma = coord_sigma
self.sampling_ratio = sampling_ratio
def _scale_size(self, im, scale):
h, w = im.shape[:2]
new_size = (int(w * float(scale) + 0.5), int(h * float(scale) + 0.5))
resized_img = cv2.resize(
im, None, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
return resized_img
def __call__(self, samples, context=None):
sample_id = 0
for sample in samples:
gt_bboxes_raw = sample['gt_bbox']
gt_labels_raw = sample['gt_class']
im_c, im_h, im_w = sample['image'].shape[:]
gt_masks_raw = sample['gt_segm'].astype(np.uint8)
mask_feat_size = [
int(im_h / self.sampling_ratio), int(im_w / self.sampling_ratio)
]
gt_areas = np.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) *
(gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
ins_ind_label_list = []
idx = 0
for (lower_bound, upper_bound), num_grid \
in zip(self.scale_ranges, self.num_grids):
hit_indices = ((gt_areas >= lower_bound) &
(gt_areas <= upper_bound)).nonzero()[0]
num_ins = len(hit_indices)
ins_label = []
grid_order = []
cate_label = np.zeros([num_grid, num_grid], dtype=np.int64)
ins_ind_label = np.zeros([num_grid**2], dtype=np.bool)
if num_ins == 0:
ins_label = np.zeros(
[1, mask_feat_size[0], mask_feat_size[1]],
dtype=np.uint8)
ins_ind_label_list.append(ins_ind_label)
sample['cate_label{}'.format(idx)] = cate_label.flatten()
sample['ins_label{}'.format(idx)] = ins_label
sample['grid_order{}'.format(idx)] = np.asarray(
[sample_id * num_grid * num_grid + 0])
idx += 1
continue
gt_bboxes = gt_bboxes_raw[hit_indices]
gt_labels = gt_labels_raw[hit_indices]
gt_masks = gt_masks_raw[hit_indices, ...]
half_ws = 0.5 * (
gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.coord_sigma
half_hs = 0.5 * (
gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.coord_sigma
for seg_mask, gt_label, half_h, half_w in zip(
gt_masks, gt_labels, half_hs, half_ws):
if seg_mask.sum() == 0:
continue
# mass center
upsampled_size = (mask_feat_size[0] * 4,
mask_feat_size[1] * 4)
center_h, center_w = ndimage.measurements.center_of_mass(
seg_mask)
coord_w = int(
(center_w / upsampled_size[1]) // (1. / num_grid))
coord_h = int(
(center_h / upsampled_size[0]) // (1. / num_grid))
# left, top, right, down
top_box = max(0,
int(((center_h - half_h) / upsampled_size[0])
// (1. / num_grid)))
down_box = min(num_grid - 1,
int(((center_h + half_h) / upsampled_size[0])
// (1. / num_grid)))
left_box = max(0,
int(((center_w - half_w) / upsampled_size[1])
// (1. / num_grid)))
right_box = min(num_grid - 1,
int(((center_w + half_w) /
upsampled_size[1]) // (1. / num_grid)))
top = max(top_box, coord_h - 1)
down = min(down_box, coord_h + 1)
left = max(coord_w - 1, left_box)
right = min(right_box, coord_w + 1)
cate_label[top:(down + 1), left:(right + 1)] = gt_label
seg_mask = self._scale_size(
seg_mask, scale=1. / self.sampling_ratio)
for i in range(top, down + 1):
for j in range(left, right + 1):
label = int(i * num_grid + j)
cur_ins_label = np.zeros(
[mask_feat_size[0], mask_feat_size[1]],
dtype=np.uint8)
cur_ins_label[:seg_mask.shape[0], :seg_mask.shape[
1]] = seg_mask
ins_label.append(cur_ins_label)
ins_ind_label[label] = True
grid_order.append(
[sample_id * num_grid * num_grid + label])
if ins_label == []:
ins_label = np.zeros(
[1, mask_feat_size[0], mask_feat_size[1]],
dtype=np.uint8)
ins_ind_label_list.append(ins_ind_label)
sample['cate_label{}'.format(idx)] = cate_label.flatten()
sample['ins_label{}'.format(idx)] = ins_label
sample['grid_order{}'.format(idx)] = np.asarray(
[sample_id * num_grid * num_grid + 0])
else:
ins_label = np.stack(ins_label, axis=0)
ins_ind_label_list.append(ins_ind_label)
sample['cate_label{}'.format(idx)] = cate_label.flatten()
sample['ins_label{}'.format(idx)] = ins_label
sample['grid_order{}'.format(idx)] = np.asarray(grid_order)
assert len(grid_order) > 0
idx += 1
ins_ind_labels = np.concatenate([
ins_ind_labels_level_img
for ins_ind_labels_level_img in ins_ind_label_list
])
fg_num = np.sum(ins_ind_labels)
sample['fg_num'] = fg_num
sample_id += 1
return samples
......@@ -272,7 +272,8 @@ class ResizeImage(BaseOperator):
target_size=0,
max_size=0,
interp=cv2.INTER_LINEAR,
use_cv2=True):
use_cv2=True,
resize_box=False):
"""
Rescale image to the specified target size, and capped at max_size
if max_size != 0.
......@@ -285,11 +286,13 @@ class ResizeImage(BaseOperator):
interp (int): the interpolation method
use_cv2 (bool): use the cv2 interpolation method or use PIL
interpolation method
resize_box (bool): whether resize ground truth bbox annotations.
"""
super(ResizeImage, self).__init__()
self.max_size = int(max_size)
self.interp = int(interp)
self.use_cv2 = use_cv2
self.resize_box = resize_box
if not (isinstance(target_size, int) or isinstance(target_size, list)):
raise TypeError(
"Type of target_size is invalid. Must be Integer or List, now is {}".
......@@ -348,18 +351,6 @@ class ResizeImage(BaseOperator):
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
if 'semantic' in sample.keys() and sample['semantic'] is not None:
semantic = sample['semantic']
semantic = cv2.resize(
semantic.astype('float32'),
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
semantic = np.asarray(semantic).astype('int32')
semantic = np.expand_dims(semantic, 0)
sample['semantic'] = semantic
else:
if self.max_size != 0:
raise TypeError(
......@@ -370,6 +361,38 @@ class ResizeImage(BaseOperator):
im = im.resize((int(resize_w), int(resize_h)), self.interp)
im = np.array(im)
sample['image'] = im
sample['scale_factor'] = [im_scale_x, im_scale_y] * 2
if 'gt_bbox' in sample and self.resize_box and len(sample[
'gt_bbox']) > 0:
bboxes = sample['gt_bbox'] * sample['scale_factor']
bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, resize_w - 1)
bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, resize_h - 1)
sample['gt_bbox'] = bboxes
if 'semantic' in sample.keys() and sample['semantic'] is not None:
semantic = sample['semantic']
semantic = cv2.resize(
semantic.astype('float32'),
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
semantic = np.asarray(semantic).astype('int32')
semantic = np.expand_dims(semantic, 0)
sample['semantic'] = semantic
if 'gt_segm' in sample and len(sample['gt_segm']) > 0:
masks = [
cv2.resize(
gt_segm,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=cv2.INTER_NEAREST)
for gt_segm in sample['gt_segm']
]
sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
return sample
......@@ -473,7 +496,6 @@ class RandomFlipImage(BaseOperator):
if self.is_mask_flip and len(sample['gt_poly']) != 0:
sample['gt_poly'] = self.flip_segms(sample['gt_poly'],
height, width)
if 'gt_keypoint' in sample.keys():
sample['gt_keypoint'] = self.flip_keypoint(
sample['gt_keypoint'], width)
......@@ -482,6 +504,9 @@ class RandomFlipImage(BaseOperator):
'semantic'] is not None:
sample['semantic'] = sample['semantic'][:, ::-1]
if 'gt_segm' in sample.keys() and sample['gt_segm'] is not None:
sample['gt_segm'] = sample['gt_segm'][:, :, ::-1]
sample['flipped'] = True
sample['image'] = im
sample = samples if batch_input else samples[0]
......@@ -1953,6 +1978,12 @@ class RandomCrop(BaseOperator):
sample['gt_poly'] = valid_polys
else:
sample['gt_poly'] = crop_polys
if 'gt_segm' in sample:
sample['gt_segm'] = self._crop_segm(sample['gt_segm'],
crop_box)
sample['gt_segm'] = np.take(
sample['gt_segm'], valid_ids, axis=0)
sample['image'] = self._crop_image(sample['image'], crop_box)
sample['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0)
sample['gt_class'] = np.take(
......@@ -2000,6 +2031,10 @@ class RandomCrop(BaseOperator):
x1, y1, x2, y2 = crop
return img[y1:y2, x1:x2, :]
def _crop_segm(self, segm, crop):
x1, y1, x2, y2 = crop
return segm[:, y1:y2, x1:x2]
@register_op
class PadBox(BaseOperator):
......@@ -2555,3 +2590,41 @@ class DebugVisibleImage(BaseOperator):
save_path = os.path.join(self.output_dir, out_file_name)
image.save(save_path, quality=95)
return sample
@register_op
class Poly2Mask(BaseOperator):
"""
gt poly to mask annotations
"""
def __init__(self):
super(Poly2Mask, self).__init__()
import pycocotools.mask as maskUtils
self.maskutils = maskUtils
def _poly2mask(self, mask_ann, img_h, img_w):
if isinstance(mask_ann, list):
# polygon -- a single object might consist of multiple parts
# we merge all parts into one mask rle code
rles = self.maskutils.frPyObjects(mask_ann, img_h, img_w)
rle = self.maskutils.merge(rles)
elif isinstance(mask_ann['counts'], list):
# uncompressed RLE
rle = self.maskutils.frPyObjects(mask_ann, img_h, img_w)
else:
# rle
rle = mask_ann
mask = self.maskutils.decode(rle)
return mask
def __call__(self, sample, context=None):
assert 'gt_poly' in sample
im_h = sample['h']
im_w = sample['w']
masks = [
self._poly2mask(gt_poly, im_h, im_w)
for gt_poly in sample['gt_poly']
]
sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
return sample
......@@ -22,6 +22,7 @@ from . import roi_extractors
from . import roi_heads
from . import ops
from . import target_assigners
from . import mask_head
from .anchor_heads import *
from .architectures import *
......@@ -30,3 +31,4 @@ from .roi_extractors import *
from .roi_heads import *
from .ops import *
from .target_assigners import *
from .mask_head import *
......@@ -21,6 +21,7 @@ from . import fcos_head
from . import corner_head
from . import efficient_head
from . import ttf_head
from . import solov2_head
from .rpn_head import *
from .yolo_head import *
......@@ -29,3 +30,4 @@ from .fcos_head import *
from .corner_head import *
from .efficient_head import *
from .ttf_head import *
from .solov2_head import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from ppdet.modeling.ops import ConvNorm, DeformConvNorm, MaskMatrixNMS, DropBlock
from ppdet.core.workspace import register
from ppdet.utils.check import check_version
from six.moves import zip
import numpy as np
__all__ = ['SOLOv2Head']
@register
class SOLOv2Head(object):
"""
Head block for SOLOv2 network
Args:
num_classes (int): Number of output classes.
seg_feat_channels (int): Num_filters of kernel & categroy branch convolution operation.
stacked_convs (int): Times of convolution operation.
num_grids (list[int]): List of feature map grids size.
kernel_out_channels (int): Number of output channels in kernel branch.
dcn_v2_stages (list): Which stage use dcn v2 in tower.
segm_strides (list[int]): List of segmentation area stride.
solov2_loss (object): SOLOv2Loss instance.
score_threshold (float): Threshold of categroy score.
mask_nms (object): MaskMatrixNMS instance.
drop_block (bool): Whether use drop_block or not.
"""
__inject__ = ['solov2_loss', 'mask_nms']
__shared__ = ['num_classes']
def __init__(self,
num_classes=80,
seg_feat_channels=256,
stacked_convs=4,
num_grids=[40, 36, 24, 16, 12],
kernel_out_channels=256,
dcn_v2_stages=[],
segm_strides=[8, 8, 16, 32, 32],
solov2_loss=None,
score_threshold=0.1,
mask_threshold=0.5,
mask_nms=MaskMatrixNMS(
update_threshold=0.05,
pre_nms_top_n=500,
post_nms_top_n=100,
kernel='gaussian',
sigma=2.0).__dict__,
drop_block=False):
check_version('2.0.0')
self.num_classes = num_classes
self.seg_num_grids = num_grids
self.cate_out_channels = self.num_classes - 1
self.seg_feat_channels = seg_feat_channels
self.stacked_convs = stacked_convs
self.kernel_out_channels = kernel_out_channels
self.dcn_v2_stages = dcn_v2_stages
self.segm_strides = segm_strides
self.solov2_loss = solov2_loss
self.mask_nms = mask_nms
self.score_threshold = score_threshold
self.mask_threshold = mask_threshold
self.drop_block = drop_block
self.conv_type = [ConvNorm, DeformConvNorm]
if isinstance(mask_nms, dict):
self.mask_nms = MaskMatrixNMS(**mask_nms)
def _conv_pred(self, conv_feat, num_filters, is_test, name, name_feat=None):
for i in range(self.stacked_convs):
if i in self.dcn_v2_stages:
conv_func = self.conv_type[1]
else:
conv_func = self.conv_type[0]
conv_feat = conv_func(
input=conv_feat,
num_filters=self.seg_feat_channels,
filter_size=3,
stride=1,
norm_type='gn',
norm_groups=32,
freeze_norm=False,
act='relu',
initializer=fluid.initializer.NormalInitializer(scale=0.01),
norm_name='{}.{}.gn'.format(name, i),
name='{}.{}'.format(name, i))
if name_feat == 'bbox_head.solo_cate':
bias_init = float(-np.log((1 - 0.01) / 0.01))
bias_attr = ParamAttr(
name="{}.bias".format(name_feat),
initializer=fluid.initializer.Constant(value=bias_init))
else:
bias_attr = ParamAttr(name="{}.bias".format(name_feat))
if self.drop_block:
conv_feat = DropBlock(
conv_feat, block_size=3, keep_prob=0.9, is_test=is_test)
conv_feat = fluid.layers.conv2d(
input=conv_feat,
num_filters=num_filters,
filter_size=3,
stride=1,
padding=1,
param_attr=ParamAttr(
name="{}.weight".format(name_feat),
initializer=fluid.initializer.NormalInitializer(scale=0.01)),
bias_attr=bias_attr,
name=name + '_feat_')
return conv_feat
def _points_nms(self, heat, kernel=2):
hmax = fluid.layers.pool2d(
input=heat, pool_size=kernel, pool_type='max', pool_padding=1)
keep = fluid.layers.cast((hmax[:, :, :-1, :-1] == heat), 'float32')
return heat * keep
def _split_feats(self, feats):
return (paddle.nn.functional.interpolate(
feats[0],
scale_factor=0.5,
align_corners=False,
align_mode=0,
mode='bilinear'), feats[1], feats[2], feats[3],
paddle.nn.functional.interpolate(
feats[4],
size=fluid.layers.shape(feats[3])[-2:],
mode='bilinear',
align_corners=False,
align_mode=0))
def get_outputs(self, input, is_eval=False):
"""
Get SOLOv2 head output
Args:
input (list): List of Variables, output of backbone or neck stages
is_eval (bool): whether in train or test mode
Returns:
cate_pred_list (list): Variables of each category branch layer
kernel_pred_list (list): Variables of each kernel branch layer
"""
feats = self._split_feats(input)
cate_pred_list = []
kernel_pred_list = []
for idx in range(len(self.seg_num_grids)):
cate_pred, kernel_pred = self._get_output_single(
feats[idx], idx, is_eval=is_eval)
cate_pred_list.append(cate_pred)
kernel_pred_list.append(kernel_pred)
return cate_pred_list, kernel_pred_list
def _get_output_single(self, input, idx, is_eval=False):
ins_kernel_feat = input
# CoordConv
x_range = paddle.linspace(
-1, 1, fluid.layers.shape(ins_kernel_feat)[-1], dtype='float32')
y_range = paddle.linspace(
-1, 1, fluid.layers.shape(ins_kernel_feat)[-2], dtype='float32')
y, x = paddle.tensor.meshgrid([y_range, x_range])
x = fluid.layers.unsqueeze(x, [0, 1])
y = fluid.layers.unsqueeze(y, [0, 1])
y = fluid.layers.expand(
y, expand_times=[fluid.layers.shape(ins_kernel_feat)[0], 1, 1, 1])
x = fluid.layers.expand(
x, expand_times=[fluid.layers.shape(ins_kernel_feat)[0], 1, 1, 1])
coord_feat = fluid.layers.concat([x, y], axis=1)
ins_kernel_feat = fluid.layers.concat(
[ins_kernel_feat, coord_feat], axis=1)
# kernel branch
kernel_feat = ins_kernel_feat
seg_num_grid = self.seg_num_grids[idx]
kernel_feat = paddle.nn.functional.interpolate(
kernel_feat,
size=[seg_num_grid, seg_num_grid],
mode='bilinear',
align_corners=False,
align_mode=0)
cate_feat = kernel_feat[:, :-2, :, :]
kernel_pred = self._conv_pred(
kernel_feat,
self.kernel_out_channels,
is_eval,
name='bbox_head.kernel_convs',
name_feat='bbox_head.solo_kernel')
# cate branch
cate_pred = self._conv_pred(
cate_feat,
self.cate_out_channels,
is_eval,
name='bbox_head.cate_convs',
name_feat='bbox_head.solo_cate')
if is_eval:
cate_pred = self._points_nms(
fluid.layers.sigmoid(cate_pred), kernel=2)
cate_pred = fluid.layers.transpose(cate_pred, [0, 2, 3, 1])
return cate_pred, kernel_pred
def get_loss(self, cate_preds, kernel_preds, ins_pred, ins_labels,
cate_labels, grid_order_list, fg_num):
"""
Get loss of network of SOLOv2.
Args:
cate_preds (list): Variable list of categroy branch output.
kernel_preds (list): Variable list of kernel branch output.
ins_pred (list): Variable list of instance branch output.
ins_labels (list): List of instance labels pre batch.
cate_labels (list): List of categroy labels pre batch.
grid_order_list (list): List of index in pre grid.
fg_num (int): Number of positive samples in a mini-batch.
Returns:
loss_ins (Variable): The instance loss Variable of SOLOv2 network.
loss_cate (Variable): The category loss Variable of SOLOv2 network.
"""
new_kernel_preds = []
pad_length_list = []
for kernel_preds_level, grid_orders_level in zip(kernel_preds,
grid_order_list):
reshape_pred = fluid.layers.reshape(
kernel_preds_level,
shape=(fluid.layers.shape(kernel_preds_level)[0],
fluid.layers.shape(kernel_preds_level)[1], -1))
reshape_pred = fluid.layers.transpose(reshape_pred, [0, 2, 1])
reshape_pred = fluid.layers.reshape(
reshape_pred, shape=(-1, fluid.layers.shape(reshape_pred)[2]))
gathered_pred = fluid.layers.gather(
reshape_pred, index=grid_orders_level)
gathered_pred = fluid.layers.lod_reset(gathered_pred,
grid_orders_level)
pad_value = fluid.layers.assign(input=np.array(
[0.0], dtype=np.float32))
pad_pred, pad_length = fluid.layers.sequence_pad(
gathered_pred, pad_value=pad_value)
new_kernel_preds.append(pad_pred)
pad_length_list.append(pad_length)
# generate masks
ins_pred_list = []
for kernel_pred, pad_length in zip(new_kernel_preds, pad_length_list):
cur_ins_pred = ins_pred
cur_ins_pred = fluid.layers.reshape(
cur_ins_pred,
shape=(fluid.layers.shape(cur_ins_pred)[0],
fluid.layers.shape(cur_ins_pred)[1], -1))
ins_pred_conv = paddle.matmul(kernel_pred, cur_ins_pred)
cur_ins_pred = fluid.layers.reshape(
ins_pred_conv,
shape=(fluid.layers.shape(ins_pred_conv)[0],
fluid.layers.shape(ins_pred_conv)[1],
fluid.layers.shape(ins_pred)[-2],
fluid.layers.shape(ins_pred)[-1]))
cur_ins_pred = fluid.layers.sequence_unpad(cur_ins_pred, pad_length)
ins_pred_list.append(cur_ins_pred)
num_ins = fluid.layers.reduce_sum(fg_num)
cate_preds = [
fluid.layers.reshape(
fluid.layers.transpose(cate_pred, [0, 2, 3, 1]),
shape=(-1, self.cate_out_channels)) for cate_pred in cate_preds
]
flatten_cate_preds = fluid.layers.concat(cate_preds)
new_cate_labels = []
cate_labels = fluid.layers.concat(cate_labels)
cate_labels = fluid.layers.unsqueeze(cate_labels, 1)
loss_ins, loss_cate = self.solov2_loss(
ins_pred_list, ins_labels, flatten_cate_preds, cate_labels, num_ins)
return {'loss_ins': loss_ins, 'loss_cate': loss_cate}
def get_prediction(self, cate_preds, kernel_preds, seg_pred, im_info):
"""
Get prediction result of SOLOv2 network
Args:
cate_preds (list): List of Variables, output of categroy branch.
kernel_preds (list): List of Variables, output of kernel branch.
seg_pred (list): List of Variables, output of mask head stages.
im_info(Variables): [h, w, scale] for input images.
Returns:
seg_masks (Variable): The prediction segmentation.
cate_labels (Variable): The prediction categroy label of each segmentation.
seg_masks (Variable): The prediction score of each segmentation.
"""
num_levels = len(cate_preds)
featmap_size = fluid.layers.shape(seg_pred)[-2:]
seg_masks_list = []
cate_labels_list = []
cate_scores_list = []
cate_preds = [cate_pred * 1.0 for cate_pred in cate_preds]
kernel_preds = [kernel_pred * 1.0 for kernel_pred in kernel_preds]
# Currently only supports batch size == 1
for idx in range(1):
cate_pred_list = [
fluid.layers.reshape(
cate_preds[i][idx], shape=(-1, self.cate_out_channels))
for i in range(num_levels)
]
seg_pred_list = seg_pred
kernel_pred_list = [
fluid.layers.reshape(
fluid.layers.transpose(kernel_preds[i][idx], [1, 2, 0]),
shape=(-1, self.kernel_out_channels))
for i in range(num_levels)
]
cate_pred_list = fluid.layers.concat(cate_pred_list, axis=0)
kernel_pred_list = fluid.layers.concat(kernel_pred_list, axis=0)
seg_masks, cate_labels, cate_scores = self.get_seg_single(
cate_pred_list, seg_pred_list, kernel_pred_list, featmap_size,
im_info[idx])
return {
"segm": seg_masks,
'cate_label': cate_labels,
'cate_score': cate_scores
}
def get_seg_single(self, cate_preds, seg_preds, kernel_preds, featmap_size,
im_info):
im_scale = im_info[2]
h = fluid.layers.cast(im_info[0], 'int32')
w = fluid.layers.cast(im_info[1], 'int32')
upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)
inds = fluid.layers.where(cate_preds > self.score_threshold)
cate_preds = fluid.layers.reshape(cate_preds, shape=[-1])
# Prevent empty and increase fake data
ind_a = fluid.layers.cast(fluid.layers.shape(kernel_preds)[0], 'int64')
ind_b = fluid.layers.zeros(shape=[1], dtype='int64')
inds_end = fluid.layers.unsqueeze(
fluid.layers.concat([ind_a, ind_b]), 0)
inds = fluid.layers.concat([inds, inds_end])
kernel_preds_end = fluid.layers.ones(
shape=[1, self.kernel_out_channels], dtype='float32')
kernel_preds = fluid.layers.concat([kernel_preds, kernel_preds_end])
cate_preds = fluid.layers.concat(
[cate_preds, fluid.layers.zeros(
shape=[1], dtype='float32')])
# cate_labels & kernel_preds
cate_labels = inds[:, 1]
kernel_preds = fluid.layers.gather(kernel_preds, index=inds[:, 0])
cate_score_idx = fluid.layers.elementwise_add(inds[:, 0] * 80,
cate_labels)
cate_scores = fluid.layers.gather(cate_preds, index=cate_score_idx)
size_trans = np.power(self.seg_num_grids, 2)
strides = []
for _ind in range(len(self.segm_strides)):
strides.append(
fluid.layers.fill_constant(
shape=[int(size_trans[_ind])],
dtype="int32",
value=self.segm_strides[_ind]))
strides = fluid.layers.concat(strides)
strides = fluid.layers.gather(strides, index=inds[:, 0])
# mask encoding.
kernel_preds = fluid.layers.unsqueeze(kernel_preds, [2, 3])
seg_preds = paddle.nn.functional.conv2d(seg_preds, kernel_preds)
seg_preds = fluid.layers.sigmoid(fluid.layers.squeeze(seg_preds, [0]))
seg_masks = seg_preds > self.mask_threshold
seg_masks = fluid.layers.cast(seg_masks, 'float32')
sum_masks = fluid.layers.reduce_sum(seg_masks, dim=[1, 2])
keep = fluid.layers.where(sum_masks > strides)
keep = fluid.layers.squeeze(keep, axes=[1])
# Prevent empty and increase fake data
keep_other = fluid.layers.concat([
keep, fluid.layers.cast(
fluid.layers.shape(sum_masks)[0] - 1, 'int64')
])
keep_scores = fluid.layers.concat([
keep, fluid.layers.cast(fluid.layers.shape(sum_masks)[0], 'int64')
])
cate_scores_end = fluid.layers.zeros(shape=[1], dtype='float32')
cate_scores = fluid.layers.concat([cate_scores, cate_scores_end])
seg_masks = fluid.layers.gather(seg_masks, index=keep_other)
seg_preds = fluid.layers.gather(seg_preds, index=keep_other)
sum_masks = fluid.layers.gather(sum_masks, index=keep_other)
cate_labels = fluid.layers.gather(cate_labels, index=keep_other)
cate_scores = fluid.layers.gather(cate_scores, index=keep_scores)
# mask scoring.
seg_mul = fluid.layers.cast(seg_preds * seg_masks, 'float32')
seg_scores = fluid.layers.reduce_sum(seg_mul, dim=[1, 2]) / sum_masks
cate_scores *= seg_scores
# Matrix NMS
seg_preds, cate_scores, cate_labels = self.mask_nms(
seg_preds, seg_masks, cate_labels, cate_scores, sum_masks=sum_masks)
ori_shape = im_info[:2] / im_scale + 0.5
ori_shape = fluid.layers.cast(ori_shape, 'int32')
seg_preds = paddle.nn.functional.interpolate(
fluid.layers.unsqueeze(seg_preds, 0),
size=upsampled_size_out,
mode='bilinear',
align_corners=False,
align_mode=0)[:, :, :h, :w]
seg_masks = fluid.layers.squeeze(
paddle.nn.functional.interpolate(
seg_preds,
size=ori_shape[:2],
mode='bilinear',
align_corners=False,
align_mode=0),
axes=[0])
# TODO: convert uint8
seg_masks = fluid.layers.cast(seg_masks > self.mask_threshold, 'int32')
return seg_masks, cate_labels, cate_scores
......@@ -29,6 +29,7 @@ from . import fcos
from . import cornernet_squeeze
from . import ttfnet
from . import htc
from . import solov2
from .faster_rcnn import *
from .mask_rcnn import *
......@@ -45,3 +46,4 @@ from .fcos import *
from .cornernet_squeeze import *
from .ttfnet import *
from .htc import *
from .solov2 import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
from paddle import fluid
from ppdet.experimental import mixed_precision_global_state
from ppdet.core.workspace import register
__all__ = ['SOLOv2']
@register
class SOLOv2(object):
"""
SOLOv2 network, see https://arxiv.org/abs/2003.10152
Args:
backbone (object): an backbone instance
fpn (object): feature pyramid network instance
bbox_head (object): an `SOLOv2Head` instance
mask_head (object): an `SOLOv2MaskHead` instance
"""
__category__ = 'architecture'
__inject__ = ['backbone', 'fpn', 'bbox_head', 'mask_head']
def __init__(self,
backbone,
fpn=None,
bbox_head='SOLOv2Head',
mask_head='SOLOv2MaskHead'):
super(SOLOv2, self).__init__()
self.backbone = backbone
self.fpn = fpn
self.bbox_head = bbox_head
self.mask_head = mask_head
def build(self, feed_vars, mode='train'):
im = feed_vars['image']
mixed_precision_enabled = mixed_precision_global_state() is not None
# cast inputs to FP16
if mixed_precision_enabled:
im = fluid.layers.cast(im, 'float16')
body_feats = self.backbone(im)
if self.fpn is not None:
body_feats, spatial_scale = self.fpn.get_output(body_feats)
if isinstance(body_feats, OrderedDict):
body_feat_names = list(body_feats.keys())
body_feats = [body_feats[name] for name in body_feat_names]
# cast features back to FP32
if mixed_precision_enabled:
body_feats = [fluid.layers.cast(v, 'float32') for v in body_feats]
mask_feat_pred = self.mask_head.get_output(body_feats)
if mode == 'train':
ins_labels = []
cate_labels = []
grid_orders = []
fg_num = feed_vars['fg_num']
for i in range(self.num_level):
ins_label = 'ins_label{}'.format(i)
if ins_label in feed_vars:
ins_labels.append(feed_vars[ins_label])
cate_label = 'cate_label{}'.format(i)
if cate_label in feed_vars:
cate_labels.append(feed_vars[cate_label])
grid_order = 'grid_order{}'.format(i)
if grid_order in feed_vars:
grid_orders.append(feed_vars[grid_order])
cate_preds, kernel_preds = self.bbox_head.get_outputs(body_feats)
losses = self.bbox_head.get_loss(cate_preds, kernel_preds,
mask_feat_pred, ins_labels,
cate_labels, grid_orders, fg_num)
total_loss = fluid.layers.sum(list(losses.values()))
losses.update({'loss': total_loss})
return losses
else:
im_info = feed_vars['im_info']
outs = self.bbox_head.get_outputs(body_feats, is_eval=True)
seg_inputs = outs + (mask_feat_pred, im_info)
return self.bbox_head.get_prediction(*seg_inputs)
def _inputs_def(self, image_shape, fields):
im_shape = [None] + image_shape
# yapf: disable
inputs_def = {
'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0},
'im_info': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0},
'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
}
if 'gt_segm' in fields:
for i in range(self.num_level):
targets_def = {
'ins_label%d' % i: {'shape': [None, None, None], 'dtype': 'int32', 'lod_level': 1},
'cate_label%d' % i: {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
'grid_order%d' % i: {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
}
inputs_def.update(targets_def)
targets_def = {
'fg_num': {'shape': [None], 'dtype': 'int32', 'lod_level': 0},
}
# yapf: enable
inputs_def.update(targets_def)
return inputs_def
def build_inputs(
self,
image_shape=[3, None, None],
fields=['image', 'im_id', 'gt_segm'], # for train
num_level=5,
use_dataloader=True,
iterable=False):
self.num_level = num_level
inputs_def = self._inputs_def(image_shape, fields)
if 'gt_segm' in fields:
fields.remove('gt_segm')
fields.extend(['fg_num'])
for i in range(num_level):
fields.extend([
'ins_label%d' % i, 'cate_label%d' % i, 'grid_order%d' % i
])
feed_vars = OrderedDict([(key, fluid.data(
name=key,
shape=inputs_def[key]['shape'],
dtype=inputs_def[key]['dtype'],
lod_level=inputs_def[key]['lod_level'])) for key in fields])
loader = fluid.io.DataLoader.from_generator(
feed_list=list(feed_vars.values()),
capacity=16,
use_double_buffer=True,
iterable=iterable) if use_dataloader else None
return feed_vars, loader
def train(self, feed_vars):
return self.build(feed_vars, mode='train')
def eval(self, feed_vars):
return self.build(feed_vars, mode='test')
def test(self, feed_vars, exclude_nms=False):
assert not exclude_nms, "exclude_nms for {} is not support currently".format(
self.__class__.__name__)
return self.build(feed_vars, mode='test')
......@@ -24,6 +24,7 @@ from . import fcos_loss
from . import diou_loss_yolo
from . import iou_aware_loss
from . import ssd_with_lmk_loss
from . import solov2_loss
from .iou_aware_loss import *
from .yolo_loss import *
......@@ -34,4 +35,5 @@ from .iou_loss import *
from .balanced_l1_loss import *
from .fcos_loss import *
from .diou_loss_yolo import *
from .ssd_with_lmk_loss import *
\ No newline at end of file
from .ssd_with_lmk_loss import *
from .solov2_loss import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import fluid
from ppdet.core.workspace import register, serializable
__all__ = ['SOLOv2Loss']
@register
@serializable
class SOLOv2Loss(object):
"""
SOLOv2Loss
Args:
ins_loss_weight (float): Weight of instance loss.
focal_loss_gamma (float): Gamma parameter for focal loss.
focal_loss_alpha (float): Alpha parameter for focal loss.
"""
def __init__(self,
ins_loss_weight=3.0,
focal_loss_gamma=2.0,
focal_loss_alpha=0.25):
self.ins_loss_weight = ins_loss_weight
self.focal_loss_gamma = focal_loss_gamma
self.focal_loss_alpha = focal_loss_alpha
def _dice_loss(self, input, target):
input = fluid.layers.reshape(
input, shape=(fluid.layers.shape(input)[0], -1))
target = fluid.layers.reshape(
target, shape=(fluid.layers.shape(target)[0], -1))
target = fluid.layers.cast(target, 'float32')
a = fluid.layers.reduce_sum(input * target, dim=1)
b = fluid.layers.reduce_sum(input * input, dim=1) + 0.001
c = fluid.layers.reduce_sum(target * target, dim=1) + 0.001
d = (2 * a) / (b + c)
return 1 - d
def __call__(self, ins_pred_list, ins_label_list, cate_preds, cate_labels,
num_ins):
"""
Get loss of network of SOLOv2.
Args:
ins_pred_list (list): Variable list of instance branch output.
ins_label_list (list): List of instance labels pre batch.
cate_preds (list): Concat Variable list of categroy branch output.
cate_labels (list): Concat list of categroy labels pre batch.
num_ins (int): Number of positive samples in a mini-batch.
Returns:
loss_ins (Variable): The instance loss Variable of SOLOv2 network.
loss_cate (Variable): The category loss Variable of SOLOv2 network.
"""
# Ues dice_loss to calculate instance loss
loss_ins = []
total_weights = fluid.layers.zeros(shape=[1], dtype='float32')
for input, target in zip(ins_pred_list, ins_label_list):
weights = fluid.layers.cast(
fluid.layers.reduce_sum(
target, dim=[1, 2]) > 0, 'float32')
input = fluid.layers.sigmoid(input)
dice_out = fluid.layers.elementwise_mul(
self._dice_loss(input, target), weights)
total_weights += fluid.layers.reduce_sum(weights)
loss_ins.append(dice_out)
loss_ins = fluid.layers.reduce_sum(fluid.layers.concat(
loss_ins)) / total_weights
loss_ins = loss_ins * self.ins_loss_weight
# Ues sigmoid_focal_loss to calculate category loss
loss_cate = fluid.layers.sigmoid_focal_loss(
x=cate_preds,
label=cate_labels,
fg_num=num_ins + 1,
gamma=self.focal_loss_gamma,
alpha=self.focal_loss_alpha)
loss_cate = fluid.layers.reduce_sum(loss_cate)
return loss_ins, loss_cate
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from . import solo_mask_head
from .solo_mask_head import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import fluid
from ppdet.core.workspace import register
from ppdet.modeling.ops import ConvNorm, DeformConvNorm
__all__ = ['SOLOv2MaskHead']
@register
class SOLOv2MaskHead(object):
"""
MaskHead of SOLOv2
Args:
in_channels (int): The channel number of input variable.
out_channels (int): The channel number of output variable.
start_level (int): The position where the input starts.
end_level (int): The position where the input ends.
use_dcn_in_tower: Whether to use dcn in tower or not.
"""
def __init__(self,
in_channels=128,
out_channels=128,
start_level=0,
end_level=3,
use_dcn_in_tower=False):
super(SOLOv2MaskHead, self).__init__()
assert start_level >= 0 and end_level >= start_level
self.out_channels = out_channels
self.start_level = start_level
self.end_level = end_level
self.in_channels = in_channels
self.use_dcn_in_tower = use_dcn_in_tower
self.conv_type = [ConvNorm, DeformConvNorm]
def _convs_levels(self, conv_feat, level, name=None):
conv_func = self.conv_type[0]
if self.use_dcn_in_tower:
conv_func = self.conv_type[1]
if level == 0:
return conv_func(
input=conv_feat,
num_filters=self.in_channels,
filter_size=3,
stride=1,
norm_type='gn',
norm_groups=32,
freeze_norm=False,
act='relu',
initializer=fluid.initializer.NormalInitializer(scale=0.01),
norm_name=name + '.conv' + str(level) + '.gn',
name=name + '.conv' + str(level))
for j in range(level):
conv_feat = conv_func(
input=conv_feat,
num_filters=self.in_channels,
filter_size=3,
stride=1,
norm_type='gn',
norm_groups=32,
freeze_norm=False,
act='relu',
initializer=fluid.initializer.NormalInitializer(scale=0.01),
norm_name=name + '.conv' + str(j) + '.gn',
name=name + '.conv' + str(j))
conv_feat = fluid.layers.resize_bilinear(
conv_feat,
scale=2,
name='upsample' + str(level) + str(j),
align_corners=False,
align_mode=0)
return conv_feat
def _conv_pred(self, conv_feat):
conv_func = self.conv_type[0]
if self.use_dcn_in_tower:
conv_func = self.conv_type[1]
conv_feat = conv_func(
input=conv_feat,
num_filters=self.out_channels,
filter_size=1,
stride=1,
norm_type='gn',
norm_groups=32,
freeze_norm=False,
act='relu',
initializer=fluid.initializer.NormalInitializer(scale=0.01),
norm_name='mask_feat_head.conv_pred.0.gn',
name='mask_feat_head.conv_pred.0')
return conv_feat
def get_output(self, inputs):
"""
Get SOLOv2MaskHead output.
Args:
inputs(list[Variable]): feature map from each necks with shape of [N, C, H, W]
Returns:
ins_pred(Variable): Output of SOLOv2MaskHead head
"""
range_level = self.end_level - self.start_level + 1
feature_add_all_level = self._convs_levels(
inputs[0], 0, name='mask_feat_head.convs_all_levels.0')
for i in range(1, range_level):
input_p = inputs[i]
if i == (range_level - 1):
input_feat = input_p
x_range = paddle.linspace(
-1, 1, fluid.layers.shape(input_feat)[-1], dtype='float32')
y_range = paddle.linspace(
-1, 1, fluid.layers.shape(input_feat)[-2], dtype='float32')
y, x = paddle.tensor.meshgrid([y_range, x_range])
x = fluid.layers.unsqueeze(x, [0, 1])
y = fluid.layers.unsqueeze(y, [0, 1])
y = fluid.layers.expand(
y,
expand_times=[fluid.layers.shape(input_feat)[0], 1, 1, 1])
x = fluid.layers.expand(
x,
expand_times=[fluid.layers.shape(input_feat)[0], 1, 1, 1])
coord_feat = fluid.layers.concat([x, y], axis=1)
input_p = fluid.layers.concat([input_p, coord_feat], axis=1)
feature_add_all_level = fluid.layers.elementwise_add(
feature_add_all_level,
self._convs_levels(
input_p,
i,
name='mask_feat_head.convs_all_levels.{}'.format(i)))
ins_pred = self._conv_pred(feature_add_all_level)
return ins_pred
......@@ -17,6 +17,7 @@ from numbers import Integral
import math
import six
import paddle
from paddle import fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.initializer import NumpyArrayInitializer
......@@ -1263,27 +1264,27 @@ class LibraBBoxAssigner(object):
rois = create_tmp_var(
fluid.default_main_program(),
name=None, #'rois',
name=None,
dtype='float32',
shape=[-1, 4], )
bbox_inside_weights = create_tmp_var(
fluid.default_main_program(),
name=None, #'bbox_inside_weights',
name=None,
dtype='float32',
shape=[-1, 8 if self.is_cls_agnostic else self.class_nums * 4], )
bbox_outside_weights = create_tmp_var(
fluid.default_main_program(),
name=None, #'bbox_outside_weights',
name=None,
dtype='float32',
shape=[-1, 8 if self.is_cls_agnostic else self.class_nums * 4], )
bbox_targets = create_tmp_var(
fluid.default_main_program(),
name=None, #'bbox_targets',
name=None,
dtype='float32',
shape=[-1, 8 if self.is_cls_agnostic else self.class_nums * 4], )
labels_int32 = create_tmp_var(
fluid.default_main_program(),
name=None, #'labels_int32',
name=None,
dtype='int32',
shape=[-1, 1], )
......@@ -1565,3 +1566,138 @@ class RetinaOutputDecoder(object):
self.nms_top_k = pre_nms_top_n
self.keep_top_k = detections_per_im
self.nms_eta = nms_eta
@register
@serializable
class MaskMatrixNMS(object):
"""
Matrix NMS for multi-class masks.
Args:
update_threshold (float): Updated threshold of categroy score in second time.
pre_nms_top_n (int): Number of total instance to be kept per image before NMS
post_nms_top_n (int): Number of total instance to be kept per image after NMS.
kernel (str): 'linear' or 'gaussian'.
sigma (float): std in gaussian method.
Input:
seg_preds (Variable): shape (n, h, w), segmentation feature maps
seg_masks (Variable): shape (n, h, w), segmentation feature maps
cate_labels (Variable): shape (n), mask labels in descending order
cate_scores (Variable): shape (n), mask scores in descending order
sum_masks (Variable): a float tensor of the sum of seg_masks
Returns:
Variable: cate_scores, tensors of shape (n)
"""
def __init__(self,
update_threshold=0.05,
pre_nms_top_n=500,
post_nms_top_n=100,
kernel='gaussian',
sigma=2.0):
super(MaskMatrixNMS, self).__init__()
self.update_threshold = update_threshold
self.pre_nms_top_n = pre_nms_top_n
self.post_nms_top_n = post_nms_top_n
self.kernel = kernel
self.sigma = sigma
def _sort_score(self, scores, top_num):
self.case_scores = scores
def fn_1():
return fluid.layers.topk(self.case_scores, top_num)
def fn_2():
return fluid.layers.argsort(self.case_scores, descending=True)
sort_inds = fluid.layers.case(
pred_fn_pairs=[(fluid.layers.shape(scores)[0] > top_num, fn_1)],
default=fn_2)
return sort_inds
def __call__(self,
seg_preds,
seg_masks,
cate_labels,
cate_scores,
sum_masks=None):
# sort and keep top nms_pre
sort_inds = self._sort_score(cate_scores, self.pre_nms_top_n)
seg_masks = fluid.layers.gather(seg_masks, index=sort_inds[1])
seg_preds = fluid.layers.gather(seg_preds, index=sort_inds[1])
sum_masks = fluid.layers.gather(sum_masks, index=sort_inds[1])
cate_scores = sort_inds[0]
cate_labels = fluid.layers.gather(cate_labels, index=sort_inds[1])
seg_masks = paddle.flatten(seg_masks, start_axis=1, stop_axis=-1)
# inter.
inter_matrix = paddle.mm(seg_masks,
fluid.layers.transpose(seg_masks, [1, 0]))
n_samples = fluid.layers.shape(cate_labels)
# union.
sum_masks_x = fluid.layers.reshape(
fluid.layers.expand(
sum_masks, expand_times=[n_samples]),
shape=[n_samples, n_samples])
# iou.
iou_matrix = (inter_matrix / (sum_masks_x + fluid.layers.transpose(
sum_masks_x, [1, 0]) - inter_matrix))
iou_matrix = paddle.triu(iou_matrix, diagonal=1)
# label_specific matrix.
cate_labels_x = fluid.layers.reshape(
fluid.layers.expand(
cate_labels, expand_times=[n_samples]),
shape=[n_samples, n_samples])
label_matrix = fluid.layers.cast(
(cate_labels_x == fluid.layers.transpose(cate_labels_x, [1, 0])),
'float32')
label_matrix = paddle.triu(label_matrix, diagonal=1)
# IoU compensation
compensate_iou = paddle.max((iou_matrix * label_matrix), axis=0)
compensate_iou = fluid.layers.reshape(
fluid.layers.expand(
compensate_iou, expand_times=[n_samples]),
shape=[n_samples, n_samples])
compensate_iou = fluid.layers.transpose(compensate_iou, [1, 0])
# IoU decay
decay_iou = iou_matrix * label_matrix
# matrix nms
if self.kernel == 'gaussian':
decay_matrix = fluid.layers.exp(-1 * self.sigma * (decay_iou**2))
compensate_matrix = fluid.layers.exp(-1 * self.sigma *
(compensate_iou**2))
decay_coefficient = paddle.min(decay_matrix / compensate_matrix,
axis=0)
elif self.kernel == 'linear':
decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
decay_coefficient = paddle.min(decay_matrix, axis=0)
else:
raise NotImplementedError
# update the score.
cate_scores = cate_scores * decay_coefficient
keep = fluid.layers.where(cate_scores >= self.update_threshold)
keep = fluid.layers.squeeze(keep, axes=[1])
# Prevent empty and increase fake data
keep = fluid.layers.concat([
keep, fluid.layers.cast(
fluid.layers.shape(cate_scores)[0] - 1, 'int64')
])
seg_preds = fluid.layers.gather(seg_preds, index=keep)
cate_scores = fluid.layers.gather(cate_scores, index=keep)
cate_labels = fluid.layers.gather(cate_labels, index=keep)
# sort and keep top_k
sort_inds = self._sort_score(cate_scores, self.post_nms_top_n)
seg_preds = fluid.layers.gather(seg_preds, index=sort_inds[1])
cate_scores = sort_inds[0]
cate_labels = fluid.layers.gather(cate_labels, index=sort_inds[1])
return seg_preds, cate_scores, cate_labels
......@@ -111,6 +111,10 @@ def mask_eval(results,
resolution,
thresh_binarize=0.5,
save_only=False):
"""
Format the output of mask and get mask ap by coco api evaluation.
It will be used in Mask-RCNN.
"""
assert 'mask' in results[0]
assert outfile.endswith('.json')
from pycocotools.coco import COCO
......@@ -164,6 +168,52 @@ def mask_eval(results,
cocoapi_eval(outfile, 'segm', coco_gt=coco_gt)
def segm_eval(results, anno_file, outfile, save_only=False):
"""
Format the output of segmentation, category_id and score in mask.josn, and
get mask ap by coco api evaluation. It will be used in instance segmentation
networks, such as: SOLOv2.
"""
assert 'segm' in results[0]
assert outfile.endswith('.json')
from pycocotools.coco import COCO
coco_gt = COCO(anno_file)
clsid2catid = {i: v for i, v in enumerate(coco_gt.getCatIds())}
segm_results = []
for t in results:
im_id = int(t['im_id'][0][0])
segs = t['segm']
for mask in segs:
catid = int(clsid2catid[mask[0]])
masks = mask[1]
mask_score = masks[1]
segm = masks[0]
segm['counts'] = segm['counts'].decode('utf8')
coco_res = {
'image_id': im_id,
'category_id': catid,
'segmentation': segm,
'score': mask_score
}
segm_results.append(coco_res)
if len(segm_results) == 0:
logger.warning("The number of valid mask detected is zero.\n \
Please use reasonable model and check input data.")
return
with open(outfile, 'w') as f:
json.dump(segm_results, f)
if save_only:
logger.info('The mask result is saved to {} and do not '
'evaluate the mAP.'.format(outfile))
return
map_stats = cocoapi_eval(outfile, 'segm', coco_gt=coco_gt)
return map_stats
def cocoapi_eval(jsonfile,
style,
coco_gt=None,
......@@ -374,6 +424,43 @@ def mask2out(results, clsid2catid, resolution, thresh_binarize=0.5):
return segm_res
def segm2out(results, clsid2catid, thresh_binarize=0.5):
import pycocotools.mask as mask_util
segm_res = []
# for each batch
for t in results:
segms = t['segm'][0]
clsid_labels = t['cate_label'][0]
clsid_scores = t['cate_score'][0]
lengths = segms.shape[0]
im_id = int(t['im_id'][0][0])
im_shape = t['im_shape'][0][0]
if lengths == 0 or segms is None:
continue
# for each sample
for i in range(lengths - 1):
im_h = int(im_shape[0])
im_w = int(im_shape[1])
clsid = int(clsid_labels[i])
catid = clsid2catid[clsid]
score = clsid_scores[i]
mask = segms[i]
segm = mask_util.encode(
np.array(
mask[:, :, np.newaxis], order='F'))[0]
segm['counts'] = segm['counts'].decode('utf8')
coco_res = {
'image_id': im_id,
'category_id': catid,
'segmentation': segm,
'score': score
}
segm_res.append(coco_res)
return segm_res
def expand_boxes(boxes, scale):
"""
Expand an array of boxes by a given scale.
......
......@@ -94,6 +94,25 @@ def clean_res(result, keep_name_list):
return clean_result
def get_masks(result):
import pycocotools.mask as mask_util
if result is None:
return {}
seg_pred = result['segm'][0].astype(np.uint8)
cate_label = result['cate_label'][0].astype(np.int)
cate_score = result['cate_score'][0].astype(np.float)
num_ins = seg_pred.shape[0]
masks = []
for idx in range(num_ins - 1):
cur_mask = seg_pred[idx, ...]
rle = mask_util.encode(
np.array(
cur_mask[:, :, np.newaxis], order='F'))[0]
rst = (rle, cate_score[idx])
masks.append([cate_label[idx], rst])
return masks
def eval_run(exe,
compile_program,
loader,
......@@ -163,11 +182,13 @@ def eval_run(exe,
corner_post_process(res, post_config, cfg.num_classes)
if 'TTFNet' in cfg.architecture:
res['bbox'][1].append([len(res['bbox'][0])])
if 'segm' in res:
res['segm'] = get_masks(res)
results.append(res)
if iter_id % 100 == 0:
logger.info('Test iter {}'.format(iter_id))
iter_id += 1
if len(res['bbox'][1]) == 0:
if 'bbox' not in res or len(res['bbox'][1]) == 0:
has_bbox = False
images_num += len(res['bbox'][1][0]) if has_bbox else 1
except (StopIteration, fluid.core.EOFException):
......@@ -198,7 +219,7 @@ def eval_results(results,
"""Evaluation for evaluation program results"""
box_ap_stats = []
if metric == 'COCO':
from ppdet.utils.coco_eval import proposal_eval, bbox_eval, mask_eval
from ppdet.utils.coco_eval import proposal_eval, bbox_eval, mask_eval, segm_eval
anno_file = dataset.get_anno()
with_background = dataset.with_background
if 'proposal' in results[0]:
......@@ -225,6 +246,14 @@ def eval_results(results,
output = os.path.join(output_directory, 'mask.json')
mask_eval(
results, anno_file, output, resolution, save_only=save_only)
if 'segm' in results[0]:
output = 'segm.json'
if output_directory:
output = os.path.join(output_directory, output)
mask_ap_stats = segm_eval(
results, anno_file, output, save_only=save_only)
if len(box_ap_stats) == 0:
box_ap_stats = mask_ap_stats
else:
if 'accum_map' in results[-1]:
res = np.mean(results[-1]['accum_map'][0])
......
......@@ -133,9 +133,6 @@ def main():
extra_keys)
sub_eval_prog = sub_eval_prog.clone(True)
#if 'weights' in cfg:
# checkpoint.load_params(exe, sub_eval_prog, cfg.weights)
# load model
exe.run(startup_prog)
if 'weights' in cfg:
......@@ -147,7 +144,6 @@ def main():
results = eval_run(exe, compile_program, loader, keys, values, cls, cfg,
sub_eval_prog, sub_keys, sub_values, resolution)
#print(cfg['EvalReader']['dataset'].__dict__)
# evaluation
# if map_type not set, use default 11point, only use in VOC eval
map_type = cfg.map_type if 'map_type' in cfg else '11point'
......
......@@ -46,11 +46,13 @@ TRT_MIN_SUBGRAPH = {
'Face': 3,
'TTFNet': 3,
'FCOS': 3,
'SOLOv2': 3,
}
RESIZE_SCALE_SET = {
'RCNN',
'RetinaNet',
'FCOS',
'SOLOv2',
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册