未验证 提交 a633cd0e 编写于 作者: G Guanghua Yu 提交者: GitHub

add solov2 dygraph model (#2041)

* add solov2 dygraph model, test=dygraph
上级 a917efca
......@@ -26,6 +26,7 @@ __pycache__/
/lib64/
/output/
/inference_model/
/dygraph/output_inference/
/parts/
/sdist/
/var/
......
......@@ -10,6 +10,7 @@
- Cascade RCNN
- YOLOv3
- SSD
- SOLOv2
扩展特性:
......
# SOLOv2 for instance segmentation
## Introduction
SOLOv2 (Segmenting Objects by Locations) is a fast instance segmentation framework with strong performance. We reproduced the model of the paper, and improved and optimized the accuracy and speed of the SOLOv2.
**Highlights:**
- Training Time: The training time of the model of `solov2_r50_fpn_1x` on Tesla v100 with 8 GPU is only 10 hours.
## Model Zoo
| Detector | Backbone | Multi-scale training | Lr schd | Mask AP<sup>val</sup> | V100 FP32(FPS) | GPU | Download | Configs |
| :-------: | :---------------------: | :-------------------: | :-----: | :--------------------: | :-------------: | :-----: | :---------: | :------------------------: |
| YOLACT++ | R50-FPN | False | 80w iter | 34.1 (test-dev) | 33.5 | Xp | - | - |
| CenterMask | R50-FPN | True | 2x | 36.4 | 13.9 | Xp | - | - |
| CenterMask | V2-99-FPN | True | 3x | 40.2 | 8.9 | Xp | - | - |
| PolarMask | R50-FPN | True | 2x | 30.5 | 9.4 | V100 | - | - |
| BlendMask | R50-FPN | True | 3x | 37.8 | 13.5 | V100 | - | - |
| SOLOv2 (Paper) | R50-FPN | False | 1x | 34.8 | 18.5 | V100 | - | - |
| SOLOv2 (Paper) | X101-DCN-FPN | True | 3x | 42.4 | 5.9 | V100 | - | - |
| SOLOv2 | R50-FPN | False | 1x | 35.5 | 21.9 | V100 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/solov2_r50_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/solov2/solov2_r50_fpn_1x_coco.yml) |
| SOLOv2 | R50-FPN | True | 3x | 37.9 | 21.9 | V100 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/solov2_r50_3x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/solov2/solov2_r50_fpn_3x_coco.yml) |
**Notes:**
- SOLOv2 is trained on COCO train2017 dataset and evaluated on val2017 results of `mAP(IoU=0.5:0.95)`.
## Citations
```
@article{wang2020solov2,
title={SOLOv2: Dynamic, Faster and Stronger},
author={Wang, Xinlong and Zhang, Rufeng and Kong, Tao and Li, Lei and Shen, Chunhua},
journal={arXiv preprint arXiv:2003.10152},
year={2020}
}
```
epoch: 12
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [8, 11]
- !LinearWarmup
start_factor: 0.
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
architecture: SOLOv2
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
load_static_weights: True
SOLOv2:
backbone: ResNet
neck: FPN
solov2_head: SOLOv2Head
mask_head: SOLOv2MaskHead
ResNet:
# index 0 stands for res2
depth: 50
norm_type: bn
freeze_at: 0
return_idx: [0,1,2,3]
num_stages: 4
FPN:
in_channels: [256, 512, 1024, 2048]
out_channel: 256
min_level: 0
max_level: 4
spatial_scale: [0.25, 0.125, 0.0625, 0.03125]
SOLOv2Head:
seg_feat_channels: 512
stacked_convs: 4
num_grids: [40, 36, 24, 16, 12]
kernel_out_channels: 256
solov2_loss: SOLOv2Loss
mask_nms: MaskMatrixNMS
SOLOv2MaskHead:
in_channels: 256
mid_channels: 128
out_channels: 256
start_level: 0
end_level: 3
SOLOv2Loss:
ins_loss_weight: 3.0
focal_loss_gamma: 2.0
focal_loss_alpha: 0.25
MaskMatrixNMS:
pre_nms_top_n: 500
post_nms_top_n: 100
worker_num: 2
TrainReader:
sample_transforms:
- DecodeOp: {}
- Poly2Mask: {}
- ResizeOp: {interp: 1, target_size: [800, 1333], keep_ratio: True}
- RandomFlipOp: {}
- NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- PermuteOp: {}
batch_transforms:
- PadBatchOp: {pad_to_stride: 32}
- Gt2Solov2TargetOp: {num_grids: [40, 36, 24, 16, 12],
scale_ranges: [[1, 96], [48, 192], [96, 384], [192, 768], [384, 2048]],
coord_sigma: 0.2}
batch_size: 2
shuffle: true
drop_last: true
EvalReader:
sample_transforms:
- DecodeOp: {}
- NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- ResizeOp: {interp: 1, target_size: [800, 1333], keep_ratio: True}
- PermuteOp: {}
batch_transforms:
- PadBatchOp: {pad_to_stride: 32}
batch_size: 1
shuffle: false
drop_last: false
drop_empty: false
TestReader:
sample_transforms:
- DecodeOp: {}
- NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- ResizeOp: {interp: 1, target_size: [800, 1333], keep_ratio: True}
- PermuteOp: {}
batch_transforms:
- PadBatchOp: {pad_to_stride: 32}
batch_size: 1
shuffle: false
drop_last: false
_BASE_: [
'../_base_/datasets/coco_instance.yml',
'../_base_/runtime.yml',
'_base_/solov2_r50_fpn.yml',
'_base_/optimizer_1x.yml',
'_base_/solov2_reader.yml',
]
weights: output/solov2_r50_fpn_1x_coco/model_final
_BASE_: [
'../_base_/datasets/coco_instance.yml',
'../_base_/runtime.yml',
'_base_/solov2_r50_fpn.yml',
'_base_/optimizer_1x.yml',
'_base_/solov2_reader.yml',
]
weights: output/solov2_r50_fpn_3x_coco/model_final
epoch: 36
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [24, 33]
- !LinearWarmup
start_factor: 0.
steps: 1000
TrainReader:
sample_transforms:
- DecodeOp: {}
- Poly2Mask: {}
- RandomResizeOp: {interp: 1,
target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]],
keep_ratio: True}
- RandomFlipOp: {}
- NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- PermuteOp: {}
batch_transforms:
- PadBatchOp: {pad_to_stride: 32}
- Gt2Solov2TargetOp: {num_grids: [40, 36, 24, 16, 12],
scale_ranges: [[1, 96], [48, 192], [96, 384], [192, 768], [384, 2048]],
coord_sigma: 0.2}
batch_size: 2
shuffle: true
drop_last: true
......@@ -33,6 +33,7 @@ SUPPORT_MODELS = {
'YOLO',
'RCNN',
'SSD',
'SOLOv2',
}
......@@ -152,6 +153,83 @@ class Detector(object):
return results
class DetectorSOLOv2(Detector):
"""
Args:
config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
use_gpu (bool): whether use gpu
run_mode (str): mode of running(fluid/trt_fp32/trt_fp16)
threshold (float): threshold to reserve the result for output.
"""
def __init__(self,
pred_config,
model_dir,
use_gpu=False,
run_mode='fluid',
threshold=0.5):
self.pred_config = pred_config
self.predictor = load_predictor(
model_dir,
run_mode=run_mode,
min_subgraph_size=self.pred_config.min_subgraph_size,
use_gpu=use_gpu)
def predict(self,
image,
threshold=0.5,
warmup=0,
repeats=1,
run_benchmark=False):
'''
Args:
image (str/np.ndarray): path of image/ np.ndarray read by cv2
threshold (float): threshold of predicted box' score
Returns:
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
MaskRCNN's results include 'masks': np.ndarray:
shape:[N, class_num, mask_resolution, mask_resolution]
'''
inputs = self.preprocess(image)
np_label, np_score, np_segms = None, None, None
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
for i in range(warmup):
self.predictor.run()
output_names = self.predictor.get_output_names()
np_label = self.predictor.get_output_handle(output_names[
0]).copy_to_cpu()
np_score = self.predictor.get_output_handle(output_names[
1]).copy_to_cpu()
np_segms = self.predictor.get_output_handle(output_names[
2]).copy_to_cpu()
t1 = time.time()
for i in range(repeats):
self.predictor.run()
output_names = self.predictor.get_output_names()
np_label = self.predictor.get_output_handle(output_names[
0]).copy_to_cpu()
np_score = self.predictor.get_output_handle(output_names[
1]).copy_to_cpu()
np_segms = self.predictor.get_output_handle(output_names[
2]).copy_to_cpu()
t2 = time.time()
ms = (t2 - t1) * 1000.0 / repeats
print("Inference: {} ms per batch image".format(ms))
# do not perform postprocess in benchmark mode
results = []
if not run_benchmark:
return dict(segm=np_segms, label=np_label, score=np_score)
return results
def create_inputs(im, im_info):
"""generate input for different model type
Args:
......@@ -362,6 +440,12 @@ def main():
FLAGS.model_dir,
use_gpu=FLAGS.use_gpu,
run_mode=FLAGS.run_mode)
if pred_config.arch == 'SOLOv2':
detector = DetectorSOLOv2(
pred_config,
FLAGS.model_dir,
use_gpu=FLAGS.use_gpu,
run_mode=FLAGS.run_mode)
# predict from image
if FLAGS.image_file != '':
predict_image(detector)
......
......@@ -79,3 +79,7 @@ Paddle提供基于ImageNet的骨架网络预训练模型。所有预训练模型
| VGG | SSD | 8 | 240e | ---- | 78.2 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ssd_vgg16_300_240e_voc.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/ssd_vgg16_300_240e_voc.yml) |
**注意:** SSD使用4GPU训练,训练240个epoch
### SOLOv2
请参考[solov2](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/solov2/)
......@@ -119,6 +119,7 @@ class ImageFolder(DetDataset):
sample_num, use_default_label)
self._imid2path = {}
self.roidbs = None
self.sample_num = sample_num
def parse_dataset(self, with_background=True):
if not self.roidbs:
......@@ -144,7 +145,7 @@ class ImageFolder(DetDataset):
for image in images:
assert image != '' and os.path.isfile(image), \
"Image {} not found".format(image)
if self.sample_num and self.sample_num > 0 and ct >= self.sample_num:
if self.sample_num > 0 and ct >= self.sample_num:
break
rec = {'im_id': np.array([ct]), 'im_file': image}
self._imid2path[ct] = image
......
......@@ -635,6 +635,7 @@ class Gt2Solov2TargetOp(BaseOperator):
def __call__(self, samples, context=None):
sample_id = 0
max_ins_num = [0] * len(self.num_grids)
for sample in samples:
gt_bboxes_raw = sample['gt_bbox']
gt_labels_raw = sample['gt_class']
......@@ -667,7 +668,7 @@ class Gt2Solov2TargetOp(BaseOperator):
sample['cate_label{}'.format(idx)] = cate_label.flatten()
sample['ins_label{}'.format(idx)] = ins_label
sample['grid_order{}'.format(idx)] = np.asarray(
[sample_id * num_grid * num_grid + 0])
[sample_id * num_grid * num_grid + 0], dtype=np.int32)
idx += 1
continue
gt_bboxes = gt_bboxes_raw[hit_indices]
......@@ -725,8 +726,8 @@ class Gt2Solov2TargetOp(BaseOperator):
1]] = seg_mask
ins_label.append(cur_ins_label)
ins_ind_label[label] = True
grid_order.append(
[sample_id * num_grid * num_grid + label])
grid_order.append(sample_id * num_grid * num_grid +
label)
if ins_label == []:
ins_label = np.zeros(
[1, mask_feat_size[0], mask_feat_size[1]],
......@@ -735,14 +736,18 @@ class Gt2Solov2TargetOp(BaseOperator):
sample['cate_label{}'.format(idx)] = cate_label.flatten()
sample['ins_label{}'.format(idx)] = ins_label
sample['grid_order{}'.format(idx)] = np.asarray(
[sample_id * num_grid * num_grid + 0])
[sample_id * num_grid * num_grid + 0], dtype=np.int32)
else:
ins_label = np.stack(ins_label, axis=0)
ins_ind_label_list.append(ins_ind_label)
sample['cate_label{}'.format(idx)] = cate_label.flatten()
sample['ins_label{}'.format(idx)] = ins_label
sample['grid_order{}'.format(idx)] = np.asarray(grid_order)
sample['grid_order{}'.format(idx)] = np.asarray(
grid_order, dtype=np.int32)
assert len(grid_order) > 0
max_ins_num[idx] = max(
max_ins_num[idx],
sample['ins_label{}'.format(idx)].shape[0])
idx += 1
ins_ind_labels = np.concatenate([
ins_ind_labels_level_img
......@@ -752,4 +757,28 @@ class Gt2Solov2TargetOp(BaseOperator):
sample['fg_num'] = fg_num
sample_id += 1
sample.pop('is_crowd')
sample.pop('gt_class')
sample.pop('gt_bbox')
sample.pop('gt_poly')
sample.pop('gt_segm')
# padding batch
for data in samples:
for idx in range(len(self.num_grids)):
gt_ins_data = np.zeros(
[
max_ins_num[idx],
data['ins_label{}'.format(idx)].shape[1],
data['ins_label{}'.format(idx)].shape[2]
],
dtype=np.uint8)
gt_ins_data[0:data['ins_label{}'.format(idx)].shape[
0], :, :] = data['ins_label{}'.format(idx)]
gt_grid_order = np.zeros([max_ins_num[idx]], dtype=np.int32)
gt_grid_order[0:data['grid_order{}'.format(idx)].shape[
0]] = data['grid_order{}'.format(idx)]
data['ins_label{}'.format(idx)] = gt_ins_data
data['grid_order{}'.format(idx)] = gt_grid_order
return samples
......@@ -568,7 +568,7 @@ class RandomFlipOp(BaseOperator):
if 'semantic' in sample and sample['semantic']:
sample['semantic'] = sample['semantic'][:, ::-1]
if 'gt_segm' in sample and sample['gt_segm']:
if 'gt_segm' in sample and sample['gt_segm'].any():
sample['gt_segm'] = sample['gt_segm'][:, :, ::-1]
sample['flipped'] = True
......
......@@ -250,10 +250,10 @@ class Trainer(object):
# forward
self.model.eval()
outs = self.model(data)
for key, value in outs.items():
outs[key] = value.numpy()
for key in ['im_shape', 'scale_factor', 'im_id']:
outs[key] = data[key]
for key, value in outs.items():
outs[key] = value.numpy()
# FIXME: for more elegent coding
if 'mask' in outs and 'bbox' in outs:
......@@ -275,7 +275,9 @@ class Trainer(object):
if 'bbox' in batch_res else None
mask_res = batch_res['mask'][start:end] \
if 'mask' in batch_res else None
image = visualize_results(image, bbox_res, mask_res,
segm_res = batch_res['segm'][start:end] \
if 'segm' in batch_res else None
image = visualize_results(image, bbox_res, mask_res, segm_res,
int(outs['im_id']), catid2name,
draw_threshold)
......
......@@ -18,7 +18,7 @@ from __future__ import print_function
import os
from ppdet.py_op.post_process import get_det_res, get_seg_res
from ppdet.py_op.post_process import get_det_res, get_seg_res, get_solov2_segm_res
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
......@@ -51,6 +51,9 @@ def get_infer_results(outs, catid):
infer_res['mask'] = get_seg_res(outs['mask'], outs['bbox_num'], im_id,
catid)
if 'segm' in outs:
infer_res['segm'] = get_solov2_segm_res(outs, im_id, catid)
return infer_res
......
......@@ -62,7 +62,7 @@ class COCOMetric(Metric):
def reset(self):
# only bbox and mask evaluation support currently
self.results = {'bbox': [], 'mask': []}
self.results = {'bbox': [], 'mask': [], 'segm': []}
self.eval_results = {}
def update(self, inputs, outputs):
......@@ -87,6 +87,8 @@ class COCOMetric(Metric):
'bbox'] if 'bbox' in infer_results else []
self.results['mask'] += infer_results[
'mask'] if 'mask' in infer_results else []
self.results['segm'] += infer_results[
'segm'] if 'segm' in infer_results else []
def accumulate(self):
if len(self.results['bbox']) > 0:
......@@ -109,6 +111,16 @@ class COCOMetric(Metric):
self.eval_results['mask'] = seg_stats
sys.stdout.flush()
if len(self.results['segm']) > 0:
with open("segm.json", 'w') as f:
json.dump(self.results['segm'], f)
logger.info('The segm result is saved to segm.json.')
seg_stats = cocoapi_eval(
'segm.json', 'segm', anno_file=self.anno_file)
self.eval_results['mask'] = seg_stats
sys.stdout.flush()
def log(self):
pass
......
......@@ -11,6 +11,7 @@ from . import mask_rcnn
from . import yolo
from . import cascade_rcnn
from . import ssd
from . import solov2
from .meta_arch import *
from .faster_rcnn import *
......@@ -18,3 +19,4 @@ from .mask_rcnn import *
from .yolo import *
from .cascade_rcnn import *
from .ssd import *
from .solov2 import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from ppdet.core.workspace import register
from .meta_arch import BaseArch
__all__ = ['SOLOv2']
@register
class SOLOv2(BaseArch):
"""
SOLOv2 network, see https://arxiv.org/abs/2003.10152
Args:
backbone (object): an backbone instance
solov2_head (object): an `SOLOv2Head` instance
mask_head (object): an `SOLOv2MaskHead` instance
neck (object): neck of network, such as feature pyramid network instance
"""
__category__ = 'architecture'
__inject__ = ['backbone', 'neck', 'solov2_head', 'mask_head']
def __init__(self, backbone, solov2_head, mask_head, neck=None):
super(SOLOv2, self).__init__()
self.backbone = backbone
self.neck = neck
self.solov2_head = solov2_head
self.mask_head = mask_head
def model_arch(self):
body_feats = self.backbone(self.inputs)
if self.neck is not None:
body_feats, spatial_scale = self.neck(body_feats)
self.seg_pred = self.mask_head(body_feats)
self.cate_pred_list, self.kernel_pred_list = self.solov2_head(
body_feats)
def get_loss(self, ):
loss = {}
# get gt_ins_labels, gt_cate_labels, etc.
gt_ins_labels, gt_cate_labels, gt_grid_orders = [], [], []
fg_num = self.inputs['fg_num']
for i in range(len(self.solov2_head.seg_num_grids)):
ins_label = 'ins_label{}'.format(i)
if ins_label in self.inputs:
gt_ins_labels.append(self.inputs[ins_label])
cate_label = 'cate_label{}'.format(i)
if cate_label in self.inputs:
gt_cate_labels.append(self.inputs[cate_label])
grid_order = 'grid_order{}'.format(i)
if grid_order in self.inputs:
gt_grid_orders.append(self.inputs[grid_order])
loss_solov2 = self.solov2_head.get_loss(
self.cate_pred_list, self.kernel_pred_list, self.seg_pred,
gt_ins_labels, gt_cate_labels, gt_grid_orders, fg_num)
loss.update(loss_solov2)
total_loss = paddle.add_n(list(loss.values()))
loss.update({'loss': total_loss})
return loss
def get_pred(self):
seg_masks, cate_labels, cate_scores, bbox_num = self.solov2_head.get_prediction(
self.cate_pred_list, self.kernel_pred_list, self.seg_pred,
self.inputs['im_shape'], self.inputs['scale_factor'])
outs = {
"segm": seg_masks,
"bbox_num": bbox_num,
'cate_label': cate_labels,
'cate_score': cate_scores
}
return outs
......@@ -37,8 +37,7 @@ class SSD(BaseArch):
self.anchors)
return {"loss": loss}
def get_pred(self, return_numpy=True):
output = {}
def get_pred(self):
bbox, bbox_num = self.post_process(self.ssd_head_outs, self.anchors,
self.inputs['im_shape'],
self.inputs['scale_factor'])
......
......@@ -18,6 +18,7 @@ from . import mask_head
from . import yolo_head
from . import roi_extractor
from . import ssd_head
from . import solov2_head
from .rpn_head import *
from .bbox_head import *
......@@ -25,3 +26,4 @@ from .mask_head import *
from .yolo_head import *
from .roi_extractor import *
from .ssd_head import *
from .solov2_head import *
此差异已折叠。
......@@ -18,12 +18,18 @@ import numpy as np
from numbers import Integral
import paddle
import paddle.nn as nn
from paddle import ParamAttr
from paddle import to_tensor
from paddle.nn import Conv2D, BatchNorm2D, GroupNorm
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, Constant
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, serializable
from ppdet.py_op.target import generate_rpn_anchor_target, generate_proposal_target, generate_mask_target
from ppdet.py_op.post_process import bbox_post_process
from . import ops
import paddle.nn.functional as F
def _to_list(l):
......@@ -32,6 +38,58 @@ def _to_list(l):
return [l]
class ConvNormLayer(nn.Layer):
def __init__(self,
ch_in,
ch_out,
filter_size,
stride,
norm_type='bn',
norm_groups=32,
use_dcn=False,
norm_name=None,
name=None):
super(ConvNormLayer, self).__init__()
assert norm_type in ['bn', 'sync_bn', 'gn']
self.conv = Conv2D(
in_channels=ch_in,
out_channels=ch_out,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=1,
weight_attr=ParamAttr(
name=name + "_weight",
initializer=Normal(
mean=0., std=0.01),
learning_rate=1.),
bias_attr=False)
param_attr = ParamAttr(
name=norm_name + "_scale",
learning_rate=1.,
regularizer=L2Decay(0.))
bias_attr = ParamAttr(
name=norm_name + "_offset",
learning_rate=1.,
regularizer=L2Decay(0.))
if norm_type in ['bn', 'sync_bn']:
self.norm = BatchNorm2D(
ch_out, weight_attr=param_attr, bias_attr=bias_attr)
elif norm_type == 'gn':
self.norm = GroupNorm(
num_groups=norm_groups,
num_channels=ch_out,
weight_attr=param_attr,
bias_attr=bias_attr)
def forward(self, inputs):
out = self.conv(inputs)
out = self.norm(out)
return out
@register
@serializable
class AnchorGeneratorRPN(object):
......@@ -651,3 +709,119 @@ class AnchorGrid(object):
self._anchor_vars = anchor_vars
return self._anchor_vars
@register
@serializable
class MaskMatrixNMS(object):
"""
Matrix NMS for multi-class masks.
Args:
update_threshold (float): Updated threshold of categroy score in second time.
pre_nms_top_n (int): Number of total instance to be kept per image before NMS
post_nms_top_n (int): Number of total instance to be kept per image after NMS.
kernel (str): 'linear' or 'gaussian'.
sigma (float): std in gaussian method.
Input:
seg_preds (Variable): shape (n, h, w), segmentation feature maps
seg_masks (Variable): shape (n, h, w), segmentation feature maps
cate_labels (Variable): shape (n), mask labels in descending order
cate_scores (Variable): shape (n), mask scores in descending order
sum_masks (Variable): a float tensor of the sum of seg_masks
Returns:
Variable: cate_scores, tensors of shape (n)
"""
def __init__(self,
update_threshold=0.05,
pre_nms_top_n=500,
post_nms_top_n=100,
kernel='gaussian',
sigma=2.0):
super(MaskMatrixNMS, self).__init__()
self.update_threshold = update_threshold
self.pre_nms_top_n = pre_nms_top_n
self.post_nms_top_n = post_nms_top_n
self.kernel = kernel
self.sigma = sigma
def _sort_score(self, scores, top_num):
if paddle.shape(scores)[0] > top_num:
return paddle.topk(scores, top_num)[1]
else:
return paddle.argsort(scores, descending=True)
def __call__(self,
seg_preds,
seg_masks,
cate_labels,
cate_scores,
sum_masks=None):
# sort and keep top nms_pre
sort_inds = self._sort_score(cate_scores, self.pre_nms_top_n)
seg_masks = paddle.gather(seg_masks, index=sort_inds)
seg_preds = paddle.gather(seg_preds, index=sort_inds)
sum_masks = paddle.gather(sum_masks, index=sort_inds)
cate_scores = paddle.gather(cate_scores, index=sort_inds)
cate_labels = paddle.gather(cate_labels, index=sort_inds)
seg_masks = paddle.flatten(seg_masks, start_axis=1, stop_axis=-1)
# inter.
inter_matrix = paddle.mm(seg_masks, paddle.transpose(seg_masks, [1, 0]))
n_samples = paddle.shape(cate_labels)
# union.
sum_masks_x = paddle.expand(sum_masks, shape=[n_samples, n_samples])
# iou.
iou_matrix = (inter_matrix / (
sum_masks_x + paddle.transpose(sum_masks_x, [1, 0]) - inter_matrix))
iou_matrix = paddle.triu(iou_matrix, diagonal=1)
# label_specific matrix.
cate_labels_x = paddle.expand(cate_labels, shape=[n_samples, n_samples])
label_matrix = paddle.cast(
(cate_labels_x == paddle.transpose(cate_labels_x, [1, 0])),
'float32')
label_matrix = paddle.triu(label_matrix, diagonal=1)
# IoU compensation
compensate_iou = paddle.max((iou_matrix * label_matrix), axis=0)
compensate_iou = paddle.expand(
compensate_iou, shape=[n_samples, n_samples])
compensate_iou = paddle.transpose(compensate_iou, [1, 0])
# IoU decay
decay_iou = iou_matrix * label_matrix
# matrix nms
if self.kernel == 'gaussian':
decay_matrix = paddle.exp(-1 * self.sigma * (decay_iou**2))
compensate_matrix = paddle.exp(-1 * self.sigma *
(compensate_iou**2))
decay_coefficient = paddle.min(decay_matrix / compensate_matrix,
axis=0)
elif self.kernel == 'linear':
decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
decay_coefficient = paddle.min(decay_matrix, axis=0)
else:
raise NotImplementedError
# update the score.
cate_scores = cate_scores * decay_coefficient
y = paddle.zeros(shape=paddle.shape(cate_scores), dtype='float32')
keep = paddle.where(cate_scores >= self.update_threshold, cate_scores,
y)
keep = paddle.nonzero(keep)
keep = paddle.squeeze(keep, axis=[1])
# Prevent empty and increase fake data
keep = paddle.concat(
[keep, paddle.cast(paddle.shape(cate_scores)[0] - 1, 'int64')])
seg_preds = paddle.gather(seg_preds, index=keep)
cate_scores = paddle.gather(cate_scores, index=keep)
cate_labels = paddle.gather(cate_labels, index=keep)
# sort and keep top_k
sort_inds = self._sort_score(cate_scores, self.post_nms_top_n)
seg_preds = paddle.gather(seg_preds, index=sort_inds)
cate_scores = paddle.gather(cate_scores, index=sort_inds)
cate_labels = paddle.gather(cate_labels, index=sort_inds)
return seg_preds, cate_scores, cate_labels
......@@ -16,8 +16,10 @@ from . import yolo_loss
from . import iou_aware_loss
from . import iou_loss
from . import ssd_loss
from . import solov2_loss
from .yolo_loss import *
from .iou_aware_loss import *
from .iou_loss import *
from .ssd_loss import *
from .solov2_loss import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable
__all__ = ['SOLOv2Loss']
@register
@serializable
class SOLOv2Loss(object):
"""
SOLOv2Loss
Args:
ins_loss_weight (float): Weight of instance loss.
focal_loss_gamma (float): Gamma parameter for focal loss.
focal_loss_alpha (float): Alpha parameter for focal loss.
"""
def __init__(self,
ins_loss_weight=3.0,
focal_loss_gamma=2.0,
focal_loss_alpha=0.25):
self.ins_loss_weight = ins_loss_weight
self.focal_loss_gamma = focal_loss_gamma
self.focal_loss_alpha = focal_loss_alpha
def _dice_loss(self, input, target):
input = paddle.reshape(input, shape=(paddle.shape(input)[0], -1))
target = paddle.reshape(target, shape=(paddle.shape(target)[0], -1))
a = paddle.sum(input * target, axis=1)
b = paddle.sum(input * input, axis=1) + 0.001
c = paddle.sum(target * target, axis=1) + 0.001
d = (2 * a) / (b + c)
return 1 - d
def __call__(self, ins_pred_list, ins_label_list, cate_preds, cate_labels,
num_ins):
"""
Get loss of network of SOLOv2.
Args:
ins_pred_list (list): Variable list of instance branch output.
ins_label_list (list): List of instance labels pre batch.
cate_preds (list): Concat Variable list of categroy branch output.
cate_labels (list): Concat list of categroy labels pre batch.
num_ins (int): Number of positive samples in a mini-batch.
Returns:
loss_ins (Variable): The instance loss Variable of SOLOv2 network.
loss_cate (Variable): The category loss Variable of SOLOv2 network.
"""
#1. Ues dice_loss to calculate instance loss
loss_ins = []
total_weights = paddle.zeros(shape=[1], dtype='float32')
for input, target in zip(ins_pred_list, ins_label_list):
if input is None:
continue
target = paddle.cast(target, 'float32')
target = paddle.reshape(
target,
shape=[-1, paddle.shape(input)[-2], paddle.shape(input)[-1]])
weights = paddle.cast(
paddle.sum(target, axis=[1, 2]) > 0, 'float32')
input = F.sigmoid(input)
dice_out = paddle.multiply(self._dice_loss(input, target), weights)
total_weights += paddle.sum(weights)
loss_ins.append(dice_out)
loss_ins = paddle.sum(paddle.concat(loss_ins)) / total_weights
loss_ins = loss_ins * self.ins_loss_weight
#2. Ues sigmoid_focal_loss to calculate category loss
# expand onehot labels
num_classes = cate_preds.shape[-1]
cate_labels_bin = F.one_hot(cate_labels, num_classes=num_classes + 1)
cate_labels_bin = cate_labels_bin[:, 1:]
loss_cate = F.sigmoid_focal_loss(
cate_preds,
label=cate_labels_bin,
normalizer=num_ins + 1.,
gamma=self.focal_loss_gamma,
alpha=self.focal_loss_alpha)
return loss_ins, loss_cate
......@@ -34,6 +34,9 @@ class FPN(Layer):
spatial_scale=[0.25, 0.125, 0.0625, 0.03125]):
super(FPN, self).__init__()
self.min_level = min_level
self.max_level = max_level
self.spatial_scale = spatial_scale
self.lateral_convs = []
self.fpn_convs = []
fan = out_channel * 3 * 3
......@@ -70,10 +73,6 @@ class FPN(Layer):
learning_rate=2., regularizer=L2Decay(0.))))
self.fpn_convs.append(fpn_conv)
self.min_level = min_level
self.max_level = max_level
self.spatial_scale = spatial_scale
def forward(self, body_feats):
laterals = []
for lvl in range(self.min_level, self.max_level):
......
......@@ -184,3 +184,32 @@ def get_seg_res(masks, mask_nums, image_id, num_id_to_cat_id_map):
}
seg_res.append(sg_res)
return seg_res
def get_solov2_segm_res(results, image_id, num_id_to_cat_id_map):
import pycocotools.mask as mask_util
segm_res = []
# for each batch
segms = results['segm'].astype(np.uint8)
clsid_labels = results['cate_label']
clsid_scores = results['cate_score']
lengths = segms.shape[0]
im_id = int(image_id[0][0])
if lengths == 0 or segms is None:
return None
# for each sample
for i in range(lengths - 1):
clsid = int(clsid_labels[i]) + 1
catid = num_id_to_cat_id_map[clsid]
score = float(clsid_scores[i])
mask = segms[i]
segm = mask_util.encode(np.array(mask[:, :, np.newaxis], order='F'))[0]
segm['counts'] = segm['counts'].decode('utf8')
coco_res = {
'image_id': im_id,
'category_id': catid,
'segmentation': segm,
'score': score
}
segm_res.append(coco_res)
return segm_res
......@@ -92,8 +92,8 @@ def load_weight(model, weight, optimizer=None):
param_state_dict = paddle.load(pdparam_path)
model.set_dict(param_state_dict)
last_epoch = 0
if optimizer is not None and os.path.exists(path + '.pdopt'):
last_epoch = 0
optim_state_dict = paddle.load(path + '.pdopt')
# to slove resume bug, will it be fixed in paddle 2.0
for key in optimizer.state_dict().keys():
......@@ -102,8 +102,8 @@ def load_weight(model, weight, optimizer=None):
if 'last_epoch' in optim_state_dict:
last_epoch = optim_state_dict.pop('last_epoch')
optimizer.set_state_dict(optim_state_dict)
return last_epoch
return
return last_epoch
def load_pretrain_weight(model,
......
......@@ -19,6 +19,7 @@ from __future__ import unicode_literals
import numpy as np
from PIL import Image, ImageDraw
import cv2
from .colormap import colormap
......@@ -28,6 +29,7 @@ __all__ = ['visualize_results']
def visualize_results(image,
bbox_res,
mask_res,
segm_res,
im_id,
catid2name,
threshold=0.5):
......@@ -38,6 +40,8 @@ def visualize_results(image,
image = draw_bbox(image, im_id, catid2name, bbox_res, threshold)
if mask_res is not None:
image = draw_mask(image, im_id, mask_res, threshold)
if segm_res is not None:
image = draw_segm(image, im_id, catid2name, segm_res, threshold)
return image
......@@ -106,3 +110,64 @@ def draw_bbox(image, im_id, catid2name, bboxes, threshold):
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
return image
def draw_segm(image,
im_id,
catid2name,
segms,
threshold,
alpha=0.7,
draw_box=True):
"""
Draw segmentation on image
"""
mask_color_id = 0
w_ratio = .4
color_list = colormap(rgb=True)
img_array = np.array(image).astype('float32')
for dt in np.array(segms):
if im_id != dt['image_id']:
continue
segm, score, catid = dt['segmentation'], dt['score'], dt['category_id']
if score < threshold:
continue
import pycocotools.mask as mask_util
mask = mask_util.decode(segm) * 255
color_mask = color_list[mask_color_id % len(color_list), 0:3]
mask_color_id += 1
for c in range(3):
color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
idx = np.nonzero(mask)
img_array[idx[0], idx[1], :] *= 1.0 - alpha
img_array[idx[0], idx[1], :] += alpha * color_mask
if not draw_box:
center_y, center_x = ndimage.measurements.center_of_mass(mask)
label_text = "{}".format(catid2name[catid])
vis_pos = (max(int(center_x) - 10, 0), int(center_y))
cv2.putText(img_array, label_text, vis_pos,
cv2.FONT_HERSHEY_COMPLEX, 0.3, (255, 255, 255))
else:
mask = mask_util.decode(segm) * 255
sum_x = np.sum(mask, axis=0)
x = np.where(sum_x > 0.5)[0]
sum_y = np.sum(mask, axis=1)
y = np.where(sum_y > 0.5)[0]
x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1]
cv2.rectangle(img_array, (x0, y0), (x1, y1),
tuple(color_mask.astype('int32').tolist()), 1)
bbox_text = '%s %.2f' % (catid2name[catid], score)
t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
cv2.rectangle(img_array, (x0, y0), (x0 + t_size[0],
y0 - t_size[1] - 3),
tuple(color_mask.astype('int32').tolist()), -1)
cv2.putText(
img_array,
bbox_text, (x0, y0 - 2),
cv2.FONT_HERSHEY_SIMPLEX,
0.3, (0, 0, 0),
1,
lineType=cv2.LINE_AA)
return Image.fromarray(img_array.astype('uint8'))
......@@ -96,9 +96,8 @@ def list_modules(**kwargs):
print("")
max_len = max([len(mod.name) for mod in modules])
for mod in modules:
print(
color_tty.green(mod.name.ljust(max_len)),
mod.doc.split('\n')[0])
print(color_tty.green(mod.name.ljust(max_len)),
mod.doc.split('\n')[0])
print("")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册