diff --git a/.gitignore b/.gitignore index d91c22d1624ea8f988ec19750f161c2c174fec5a..3c2169494ce2366c2ff2765191c30854ee520812 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,7 @@ __pycache__/ /lib64/ /output/ /inference_model/ +/dygraph/output_inference/ /parts/ /sdist/ /var/ diff --git a/dygraph/README.md b/dygraph/README.md index c09a6a3616288aeec52bd28702a974fec8d9bf8f..6ed35adf4e09ef80eff415d67eda43bb28ee2c0b 100644 --- a/dygraph/README.md +++ b/dygraph/README.md @@ -10,6 +10,7 @@ - Cascade RCNN - YOLOv3 - SSD +- SOLOv2 扩展特性: diff --git a/dygraph/configs/solov2/README.md b/dygraph/configs/solov2/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1c2506fb4f8f9e8b7f20466ab3d679c0ebce5de9 --- /dev/null +++ b/dygraph/configs/solov2/README.md @@ -0,0 +1,37 @@ +# SOLOv2 for instance segmentation + +## Introduction + +SOLOv2 (Segmenting Objects by Locations) is a fast instance segmentation framework with strong performance. We reproduced the model of the paper, and improved and optimized the accuracy and speed of the SOLOv2. + +**Highlights:** + +- Training Time: The training time of the model of `solov2_r50_fpn_1x` on Tesla v100 with 8 GPU is only 10 hours. + +## Model Zoo + +| Detector | Backbone | Multi-scale training | Lr schd | Mask APval | V100 FP32(FPS) | GPU | Download | Configs | +| :-------: | :---------------------: | :-------------------: | :-----: | :--------------------: | :-------------: | :-----: | :---------: | :------------------------: | +| YOLACT++ | R50-FPN | False | 80w iter | 34.1 (test-dev) | 33.5 | Xp | - | - | +| CenterMask | R50-FPN | True | 2x | 36.4 | 13.9 | Xp | - | - | +| CenterMask | V2-99-FPN | True | 3x | 40.2 | 8.9 | Xp | - | - | +| PolarMask | R50-FPN | True | 2x | 30.5 | 9.4 | V100 | - | - | +| BlendMask | R50-FPN | True | 3x | 37.8 | 13.5 | V100 | - | - | +| SOLOv2 (Paper) | R50-FPN | False | 1x | 34.8 | 18.5 | V100 | - | - | +| SOLOv2 (Paper) | X101-DCN-FPN | True | 3x | 42.4 | 5.9 | V100 | - | - | +| SOLOv2 | R50-FPN | False | 1x | 35.5 | 21.9 | V100 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/solov2_r50_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/solov2/solov2_r50_fpn_1x_coco.yml) | +| SOLOv2 | R50-FPN | True | 3x | 37.9 | 21.9 | V100 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/solov2_r50_3x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/solov2/solov2_r50_fpn_3x_coco.yml) | + +**Notes:** + +- SOLOv2 is trained on COCO train2017 dataset and evaluated on val2017 results of `mAP(IoU=0.5:0.95)`. + +## Citations +``` +@article{wang2020solov2, + title={SOLOv2: Dynamic, Faster and Stronger}, + author={Wang, Xinlong and Zhang, Rufeng and Kong, Tao and Li, Lei and Shen, Chunhua}, + journal={arXiv preprint arXiv:2003.10152}, + year={2020} +} +``` diff --git a/dygraph/configs/solov2/_base_/optimizer_1x.yml b/dygraph/configs/solov2/_base_/optimizer_1x.yml new file mode 100644 index 0000000000000000000000000000000000000000..d034482d1e007c4e07fc9b1323b86e04588710bb --- /dev/null +++ b/dygraph/configs/solov2/_base_/optimizer_1x.yml @@ -0,0 +1,19 @@ +epoch: 12 + +LearningRate: + base_lr: 0.01 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [8, 11] + - !LinearWarmup + start_factor: 0. + steps: 1000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0001 + type: L2 diff --git a/dygraph/configs/solov2/_base_/solov2_r50_fpn.yml b/dygraph/configs/solov2/_base_/solov2_r50_fpn.yml new file mode 100644 index 0000000000000000000000000000000000000000..e67a6e1b64ff9f58a33ccf5d10005cf3dbb75825 --- /dev/null +++ b/dygraph/configs/solov2/_base_/solov2_r50_fpn.yml @@ -0,0 +1,48 @@ +architecture: SOLOv2 +pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar +load_static_weights: True + +SOLOv2: + backbone: ResNet + neck: FPN + solov2_head: SOLOv2Head + mask_head: SOLOv2MaskHead + +ResNet: + # index 0 stands for res2 + depth: 50 + norm_type: bn + freeze_at: 0 + return_idx: [0,1,2,3] + num_stages: 4 + +FPN: + in_channels: [256, 512, 1024, 2048] + out_channel: 256 + min_level: 0 + max_level: 4 + spatial_scale: [0.25, 0.125, 0.0625, 0.03125] + +SOLOv2Head: + seg_feat_channels: 512 + stacked_convs: 4 + num_grids: [40, 36, 24, 16, 12] + kernel_out_channels: 256 + solov2_loss: SOLOv2Loss + mask_nms: MaskMatrixNMS + +SOLOv2MaskHead: + in_channels: 256 + mid_channels: 128 + out_channels: 256 + start_level: 0 + end_level: 3 + +SOLOv2Loss: + ins_loss_weight: 3.0 + focal_loss_gamma: 2.0 + focal_loss_alpha: 0.25 + +MaskMatrixNMS: + pre_nms_top_n: 500 + post_nms_top_n: 100 diff --git a/dygraph/configs/solov2/_base_/solov2_reader.yml b/dygraph/configs/solov2/_base_/solov2_reader.yml new file mode 100644 index 0000000000000000000000000000000000000000..d3a5de110a1da5da11d092eaf4b479f444f6cd13 --- /dev/null +++ b/dygraph/configs/solov2/_base_/solov2_reader.yml @@ -0,0 +1,44 @@ +worker_num: 2 +TrainReader: + sample_transforms: + - DecodeOp: {} + - Poly2Mask: {} + - ResizeOp: {interp: 1, target_size: [800, 1333], keep_ratio: True} + - RandomFlipOp: {} + - NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - PermuteOp: {} + batch_transforms: + - PadBatchOp: {pad_to_stride: 32} + - Gt2Solov2TargetOp: {num_grids: [40, 36, 24, 16, 12], + scale_ranges: [[1, 96], [48, 192], [96, 384], [192, 768], [384, 2048]], + coord_sigma: 0.2} + batch_size: 2 + shuffle: true + drop_last: true + + +EvalReader: + sample_transforms: + - DecodeOp: {} + - NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - ResizeOp: {interp: 1, target_size: [800, 1333], keep_ratio: True} + - PermuteOp: {} + batch_transforms: + - PadBatchOp: {pad_to_stride: 32} + batch_size: 1 + shuffle: false + drop_last: false + drop_empty: false + + +TestReader: + sample_transforms: + - DecodeOp: {} + - NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - ResizeOp: {interp: 1, target_size: [800, 1333], keep_ratio: True} + - PermuteOp: {} + batch_transforms: + - PadBatchOp: {pad_to_stride: 32} + batch_size: 1 + shuffle: false + drop_last: false diff --git a/dygraph/configs/solov2/solov2_r50_fpn_1x_coco.yml b/dygraph/configs/solov2/solov2_r50_fpn_1x_coco.yml new file mode 100644 index 0000000000000000000000000000000000000000..f2c15c10b6a36cbb25514b8c01925ab253e0e3df --- /dev/null +++ b/dygraph/configs/solov2/solov2_r50_fpn_1x_coco.yml @@ -0,0 +1,8 @@ +_BASE_: [ + '../_base_/datasets/coco_instance.yml', + '../_base_/runtime.yml', + '_base_/solov2_r50_fpn.yml', + '_base_/optimizer_1x.yml', + '_base_/solov2_reader.yml', +] +weights: output/solov2_r50_fpn_1x_coco/model_final diff --git a/dygraph/configs/solov2/solov2_r50_fpn_3x_coco.yml b/dygraph/configs/solov2/solov2_r50_fpn_3x_coco.yml new file mode 100644 index 0000000000000000000000000000000000000000..f0ac682ac8b9728291d4dbeb5ac3d25d44e48e8e --- /dev/null +++ b/dygraph/configs/solov2/solov2_r50_fpn_3x_coco.yml @@ -0,0 +1,38 @@ +_BASE_: [ + '../_base_/datasets/coco_instance.yml', + '../_base_/runtime.yml', + '_base_/solov2_r50_fpn.yml', + '_base_/optimizer_1x.yml', + '_base_/solov2_reader.yml', +] +weights: output/solov2_r50_fpn_3x_coco/model_final +epoch: 36 + +LearningRate: + base_lr: 0.01 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [24, 33] + - !LinearWarmup + start_factor: 0. + steps: 1000 + +TrainReader: + sample_transforms: + - DecodeOp: {} + - Poly2Mask: {} + - RandomResizeOp: {interp: 1, + target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], + keep_ratio: True} + - RandomFlipOp: {} + - NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - PermuteOp: {} + batch_transforms: + - PadBatchOp: {pad_to_stride: 32} + - Gt2Solov2TargetOp: {num_grids: [40, 36, 24, 16, 12], + scale_ranges: [[1, 96], [48, 192], [96, 384], [192, 768], [384, 2048]], + coord_sigma: 0.2} + batch_size: 2 + shuffle: true + drop_last: true diff --git a/dygraph/deploy/python/infer.py b/dygraph/deploy/python/infer.py index abde5fb5a2b12599d1bacb747c4e9939b9376e49..bc8a3158e9b9f9e7140a64bf0d59e14ac1e73ffc 100644 --- a/dygraph/deploy/python/infer.py +++ b/dygraph/deploy/python/infer.py @@ -33,6 +33,7 @@ SUPPORT_MODELS = { 'YOLO', 'RCNN', 'SSD', + 'SOLOv2', } @@ -152,6 +153,83 @@ class Detector(object): return results +class DetectorSOLOv2(Detector): + """ + Args: + config (object): config of model, defined by `Config(model_dir)` + model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml + use_gpu (bool): whether use gpu + run_mode (str): mode of running(fluid/trt_fp32/trt_fp16) + threshold (float): threshold to reserve the result for output. + """ + + def __init__(self, + pred_config, + model_dir, + use_gpu=False, + run_mode='fluid', + threshold=0.5): + self.pred_config = pred_config + self.predictor = load_predictor( + model_dir, + run_mode=run_mode, + min_subgraph_size=self.pred_config.min_subgraph_size, + use_gpu=use_gpu) + + def predict(self, + image, + threshold=0.5, + warmup=0, + repeats=1, + run_benchmark=False): + ''' + Args: + image (str/np.ndarray): path of image/ np.ndarray read by cv2 + threshold (float): threshold of predicted box' score + Returns: + results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box, + matix element:[class, score, x_min, y_min, x_max, y_max] + MaskRCNN's results include 'masks': np.ndarray: + shape:[N, class_num, mask_resolution, mask_resolution] + ''' + inputs = self.preprocess(image) + np_label, np_score, np_segms = None, None, None + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle(input_names[i]) + input_tensor.copy_from_cpu(inputs[input_names[i]]) + + for i in range(warmup): + self.predictor.run() + output_names = self.predictor.get_output_names() + np_label = self.predictor.get_output_handle(output_names[ + 0]).copy_to_cpu() + np_score = self.predictor.get_output_handle(output_names[ + 1]).copy_to_cpu() + np_segms = self.predictor.get_output_handle(output_names[ + 2]).copy_to_cpu() + + t1 = time.time() + for i in range(repeats): + self.predictor.run() + output_names = self.predictor.get_output_names() + np_label = self.predictor.get_output_handle(output_names[ + 0]).copy_to_cpu() + np_score = self.predictor.get_output_handle(output_names[ + 1]).copy_to_cpu() + np_segms = self.predictor.get_output_handle(output_names[ + 2]).copy_to_cpu() + t2 = time.time() + ms = (t2 - t1) * 1000.0 / repeats + print("Inference: {} ms per batch image".format(ms)) + + # do not perform postprocess in benchmark mode + results = [] + if not run_benchmark: + return dict(segm=np_segms, label=np_label, score=np_score) + return results + + def create_inputs(im, im_info): """generate input for different model type Args: @@ -362,6 +440,12 @@ def main(): FLAGS.model_dir, use_gpu=FLAGS.use_gpu, run_mode=FLAGS.run_mode) + if pred_config.arch == 'SOLOv2': + detector = DetectorSOLOv2( + pred_config, + FLAGS.model_dir, + use_gpu=FLAGS.use_gpu, + run_mode=FLAGS.run_mode) # predict from image if FLAGS.image_file != '': predict_image(detector) diff --git a/dygraph/docs/MODEL_ZOO_cn.md b/dygraph/docs/MODEL_ZOO_cn.md index 6b3373e5f1bbb86da30b8c7d6f45be44deb54d87..a38859d58f842dcd7399db7490b7111ed748203c 100644 --- a/dygraph/docs/MODEL_ZOO_cn.md +++ b/dygraph/docs/MODEL_ZOO_cn.md @@ -79,3 +79,7 @@ Paddle提供基于ImageNet的骨架网络预训练模型。所有预训练模型 | VGG | SSD | 8 | 240e | ---- | 78.2 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ssd_vgg16_300_240e_voc.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/ssd_vgg16_300_240e_voc.yml) | **注意:** SSD使用4GPU训练,训练240个epoch + +### SOLOv2 + +请参考[solov2](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/solov2/) diff --git a/dygraph/ppdet/data/source/dataset.py b/dygraph/ppdet/data/source/dataset.py index 7bf5856e19a4a55d5cbace2bbd5ff074e1c8b62f..69e38ef6d1b2c210d8ea054a8b5de63127c77a25 100644 --- a/dygraph/ppdet/data/source/dataset.py +++ b/dygraph/ppdet/data/source/dataset.py @@ -119,6 +119,7 @@ class ImageFolder(DetDataset): sample_num, use_default_label) self._imid2path = {} self.roidbs = None + self.sample_num = sample_num def parse_dataset(self, with_background=True): if not self.roidbs: @@ -144,7 +145,7 @@ class ImageFolder(DetDataset): for image in images: assert image != '' and os.path.isfile(image), \ "Image {} not found".format(image) - if self.sample_num and self.sample_num > 0 and ct >= self.sample_num: + if self.sample_num > 0 and ct >= self.sample_num: break rec = {'im_id': np.array([ct]), 'im_file': image} self._imid2path[ct] = image diff --git a/dygraph/ppdet/data/transform/batch_operator.py b/dygraph/ppdet/data/transform/batch_operator.py index c93a034074cb29135fb7b5c3c5900bdea9b85fc9..f11512f663a3ceb1810d4ec1453830c8b9f23da3 100644 --- a/dygraph/ppdet/data/transform/batch_operator.py +++ b/dygraph/ppdet/data/transform/batch_operator.py @@ -635,6 +635,7 @@ class Gt2Solov2TargetOp(BaseOperator): def __call__(self, samples, context=None): sample_id = 0 + max_ins_num = [0] * len(self.num_grids) for sample in samples: gt_bboxes_raw = sample['gt_bbox'] gt_labels_raw = sample['gt_class'] @@ -667,7 +668,7 @@ class Gt2Solov2TargetOp(BaseOperator): sample['cate_label{}'.format(idx)] = cate_label.flatten() sample['ins_label{}'.format(idx)] = ins_label sample['grid_order{}'.format(idx)] = np.asarray( - [sample_id * num_grid * num_grid + 0]) + [sample_id * num_grid * num_grid + 0], dtype=np.int32) idx += 1 continue gt_bboxes = gt_bboxes_raw[hit_indices] @@ -725,8 +726,8 @@ class Gt2Solov2TargetOp(BaseOperator): 1]] = seg_mask ins_label.append(cur_ins_label) ins_ind_label[label] = True - grid_order.append( - [sample_id * num_grid * num_grid + label]) + grid_order.append(sample_id * num_grid * num_grid + + label) if ins_label == []: ins_label = np.zeros( [1, mask_feat_size[0], mask_feat_size[1]], @@ -735,14 +736,18 @@ class Gt2Solov2TargetOp(BaseOperator): sample['cate_label{}'.format(idx)] = cate_label.flatten() sample['ins_label{}'.format(idx)] = ins_label sample['grid_order{}'.format(idx)] = np.asarray( - [sample_id * num_grid * num_grid + 0]) + [sample_id * num_grid * num_grid + 0], dtype=np.int32) else: ins_label = np.stack(ins_label, axis=0) ins_ind_label_list.append(ins_ind_label) sample['cate_label{}'.format(idx)] = cate_label.flatten() sample['ins_label{}'.format(idx)] = ins_label - sample['grid_order{}'.format(idx)] = np.asarray(grid_order) + sample['grid_order{}'.format(idx)] = np.asarray( + grid_order, dtype=np.int32) assert len(grid_order) > 0 + max_ins_num[idx] = max( + max_ins_num[idx], + sample['ins_label{}'.format(idx)].shape[0]) idx += 1 ins_ind_labels = np.concatenate([ ins_ind_labels_level_img @@ -752,4 +757,28 @@ class Gt2Solov2TargetOp(BaseOperator): sample['fg_num'] = fg_num sample_id += 1 + sample.pop('is_crowd') + sample.pop('gt_class') + sample.pop('gt_bbox') + sample.pop('gt_poly') + sample.pop('gt_segm') + + # padding batch + for data in samples: + for idx in range(len(self.num_grids)): + gt_ins_data = np.zeros( + [ + max_ins_num[idx], + data['ins_label{}'.format(idx)].shape[1], + data['ins_label{}'.format(idx)].shape[2] + ], + dtype=np.uint8) + gt_ins_data[0:data['ins_label{}'.format(idx)].shape[ + 0], :, :] = data['ins_label{}'.format(idx)] + gt_grid_order = np.zeros([max_ins_num[idx]], dtype=np.int32) + gt_grid_order[0:data['grid_order{}'.format(idx)].shape[ + 0]] = data['grid_order{}'.format(idx)] + data['ins_label{}'.format(idx)] = gt_ins_data + data['grid_order{}'.format(idx)] = gt_grid_order + return samples diff --git a/dygraph/ppdet/data/transform/operator.py b/dygraph/ppdet/data/transform/operator.py index e263ddfd4a546fc6c80e99d0431a278e757cd415..37c28524ebf003503863ca09b986143bce514a53 100644 --- a/dygraph/ppdet/data/transform/operator.py +++ b/dygraph/ppdet/data/transform/operator.py @@ -568,7 +568,7 @@ class RandomFlipOp(BaseOperator): if 'semantic' in sample and sample['semantic']: sample['semantic'] = sample['semantic'][:, ::-1] - if 'gt_segm' in sample and sample['gt_segm']: + if 'gt_segm' in sample and sample['gt_segm'].any(): sample['gt_segm'] = sample['gt_segm'][:, :, ::-1] sample['flipped'] = True diff --git a/dygraph/ppdet/engine/trainer.py b/dygraph/ppdet/engine/trainer.py index 4aa4f09dd72df8b39a3f08cb38881206d990eacd..5516edf9c9fa8833a63153c63bfc219a670b5377 100644 --- a/dygraph/ppdet/engine/trainer.py +++ b/dygraph/ppdet/engine/trainer.py @@ -250,10 +250,10 @@ class Trainer(object): # forward self.model.eval() outs = self.model(data) - for key, value in outs.items(): - outs[key] = value.numpy() for key in ['im_shape', 'scale_factor', 'im_id']: outs[key] = data[key] + for key, value in outs.items(): + outs[key] = value.numpy() # FIXME: for more elegent coding if 'mask' in outs and 'bbox' in outs: @@ -275,7 +275,9 @@ class Trainer(object): if 'bbox' in batch_res else None mask_res = batch_res['mask'][start:end] \ if 'mask' in batch_res else None - image = visualize_results(image, bbox_res, mask_res, + segm_res = batch_res['segm'][start:end] \ + if 'segm' in batch_res else None + image = visualize_results(image, bbox_res, mask_res, segm_res, int(outs['im_id']), catid2name, draw_threshold) diff --git a/dygraph/ppdet/metrics/coco_utils.py b/dygraph/ppdet/metrics/coco_utils.py index ba76d5e3af2cdd3c70f1dc3e4ae7af7c229c0b88..c25641d756253b7aa74154602615c77715936ee7 100644 --- a/dygraph/ppdet/metrics/coco_utils.py +++ b/dygraph/ppdet/metrics/coco_utils.py @@ -18,7 +18,7 @@ from __future__ import print_function import os -from ppdet.py_op.post_process import get_det_res, get_seg_res +from ppdet.py_op.post_process import get_det_res, get_seg_res, get_solov2_segm_res from ppdet.utils.logger import setup_logger logger = setup_logger(__name__) @@ -51,6 +51,9 @@ def get_infer_results(outs, catid): infer_res['mask'] = get_seg_res(outs['mask'], outs['bbox_num'], im_id, catid) + if 'segm' in outs: + infer_res['segm'] = get_solov2_segm_res(outs, im_id, catid) + return infer_res diff --git a/dygraph/ppdet/metrics/metrics.py b/dygraph/ppdet/metrics/metrics.py index 8e09b50ee0ae17d7e4ac6b139495494ac95a0a35..fde7e19b860390eddc0cc1650e6044e51f812335 100644 --- a/dygraph/ppdet/metrics/metrics.py +++ b/dygraph/ppdet/metrics/metrics.py @@ -62,7 +62,7 @@ class COCOMetric(Metric): def reset(self): # only bbox and mask evaluation support currently - self.results = {'bbox': [], 'mask': []} + self.results = {'bbox': [], 'mask': [], 'segm': []} self.eval_results = {} def update(self, inputs, outputs): @@ -87,6 +87,8 @@ class COCOMetric(Metric): 'bbox'] if 'bbox' in infer_results else [] self.results['mask'] += infer_results[ 'mask'] if 'mask' in infer_results else [] + self.results['segm'] += infer_results[ + 'segm'] if 'segm' in infer_results else [] def accumulate(self): if len(self.results['bbox']) > 0: @@ -109,6 +111,16 @@ class COCOMetric(Metric): self.eval_results['mask'] = seg_stats sys.stdout.flush() + if len(self.results['segm']) > 0: + with open("segm.json", 'w') as f: + json.dump(self.results['segm'], f) + logger.info('The segm result is saved to segm.json.') + + seg_stats = cocoapi_eval( + 'segm.json', 'segm', anno_file=self.anno_file) + self.eval_results['mask'] = seg_stats + sys.stdout.flush() + def log(self): pass diff --git a/dygraph/ppdet/modeling/architectures/__init__.py b/dygraph/ppdet/modeling/architectures/__init__.py index c2203ab98b991486755decc4ef283bd6c16c2893..7f605680f74c69b3092d9ca7a82144284303417a 100644 --- a/dygraph/ppdet/modeling/architectures/__init__.py +++ b/dygraph/ppdet/modeling/architectures/__init__.py @@ -11,6 +11,7 @@ from . import mask_rcnn from . import yolo from . import cascade_rcnn from . import ssd +from . import solov2 from .meta_arch import * from .faster_rcnn import * @@ -18,3 +19,4 @@ from .mask_rcnn import * from .yolo import * from .cascade_rcnn import * from .ssd import * +from .solov2 import * diff --git a/dygraph/ppdet/modeling/architectures/solov2.py b/dygraph/ppdet/modeling/architectures/solov2.py new file mode 100644 index 0000000000000000000000000000000000000000..d7a34c0d223fb9d8b8b3b85400b80b33e336730a --- /dev/null +++ b/dygraph/ppdet/modeling/architectures/solov2.py @@ -0,0 +1,94 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle + +from ppdet.core.workspace import register +from .meta_arch import BaseArch + +__all__ = ['SOLOv2'] + + +@register +class SOLOv2(BaseArch): + """ + SOLOv2 network, see https://arxiv.org/abs/2003.10152 + + Args: + backbone (object): an backbone instance + solov2_head (object): an `SOLOv2Head` instance + mask_head (object): an `SOLOv2MaskHead` instance + neck (object): neck of network, such as feature pyramid network instance + """ + + __category__ = 'architecture' + __inject__ = ['backbone', 'neck', 'solov2_head', 'mask_head'] + + def __init__(self, backbone, solov2_head, mask_head, neck=None): + super(SOLOv2, self).__init__() + self.backbone = backbone + self.neck = neck + self.solov2_head = solov2_head + self.mask_head = mask_head + + def model_arch(self): + body_feats = self.backbone(self.inputs) + + if self.neck is not None: + body_feats, spatial_scale = self.neck(body_feats) + + self.seg_pred = self.mask_head(body_feats) + + self.cate_pred_list, self.kernel_pred_list = self.solov2_head( + body_feats) + + def get_loss(self, ): + loss = {} + # get gt_ins_labels, gt_cate_labels, etc. + gt_ins_labels, gt_cate_labels, gt_grid_orders = [], [], [] + fg_num = self.inputs['fg_num'] + for i in range(len(self.solov2_head.seg_num_grids)): + ins_label = 'ins_label{}'.format(i) + if ins_label in self.inputs: + gt_ins_labels.append(self.inputs[ins_label]) + cate_label = 'cate_label{}'.format(i) + if cate_label in self.inputs: + gt_cate_labels.append(self.inputs[cate_label]) + grid_order = 'grid_order{}'.format(i) + if grid_order in self.inputs: + gt_grid_orders.append(self.inputs[grid_order]) + + loss_solov2 = self.solov2_head.get_loss( + self.cate_pred_list, self.kernel_pred_list, self.seg_pred, + gt_ins_labels, gt_cate_labels, gt_grid_orders, fg_num) + loss.update(loss_solov2) + total_loss = paddle.add_n(list(loss.values())) + loss.update({'loss': total_loss}) + return loss + + def get_pred(self): + seg_masks, cate_labels, cate_scores, bbox_num = self.solov2_head.get_prediction( + self.cate_pred_list, self.kernel_pred_list, self.seg_pred, + self.inputs['im_shape'], self.inputs['scale_factor']) + outs = { + "segm": seg_masks, + "bbox_num": bbox_num, + 'cate_label': cate_labels, + 'cate_score': cate_scores + } + return outs diff --git a/dygraph/ppdet/modeling/architectures/ssd.py b/dygraph/ppdet/modeling/architectures/ssd.py index d11aa88da7ed46e01c247b51deb0ab7fe99bfe9f..92386db5151df35c7a9a05b1a11011e363eebeb1 100644 --- a/dygraph/ppdet/modeling/architectures/ssd.py +++ b/dygraph/ppdet/modeling/architectures/ssd.py @@ -37,8 +37,7 @@ class SSD(BaseArch): self.anchors) return {"loss": loss} - def get_pred(self, return_numpy=True): - output = {} + def get_pred(self): bbox, bbox_num = self.post_process(self.ssd_head_outs, self.anchors, self.inputs['im_shape'], self.inputs['scale_factor']) diff --git a/dygraph/ppdet/modeling/heads/__init__.py b/dygraph/ppdet/modeling/heads/__init__.py index a0fa75a5ab49e2480db92fe416f50a104298baf5..ebf4148eee1cda7032004242d41fc6103d6fdcfd 100644 --- a/dygraph/ppdet/modeling/heads/__init__.py +++ b/dygraph/ppdet/modeling/heads/__init__.py @@ -18,6 +18,7 @@ from . import mask_head from . import yolo_head from . import roi_extractor from . import ssd_head +from . import solov2_head from .rpn_head import * from .bbox_head import * @@ -25,3 +26,4 @@ from .mask_head import * from .yolo_head import * from .roi_extractor import * from .ssd_head import * +from .solov2_head import * diff --git a/dygraph/ppdet/modeling/heads/solov2_head.py b/dygraph/ppdet/modeling/heads/solov2_head.py new file mode 100644 index 0000000000000000000000000000000000000000..f43271a860ee7418e595f24a62140f7db03728ad --- /dev/null +++ b/dygraph/ppdet/modeling/heads/solov2_head.py @@ -0,0 +1,551 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn.initializer import Normal, Constant + +from ppdet.modeling.layers import ConvNormLayer, MaskMatrixNMS +from ppdet.core.workspace import register + +from six.moves import zip +import numpy as np + +__all__ = ['SOLOv2Head'] + + +@register +class SOLOv2MaskHead(nn.Layer): + """ + MaskHead of SOLOv2 + + Args: + in_channels (int): The channel number of input Tensor. + out_channels (int): The channel number of output Tensor. + start_level (int): The position where the input starts. + end_level (int): The position where the input ends. + use_dcn_in_tower (bool): Whether to use dcn in tower or not. + """ + + def __init__(self, + in_channels=256, + mid_channels=128, + out_channels=256, + start_level=0, + end_level=3, + use_dcn_in_tower=False): + super(SOLOv2MaskHead, self).__init__() + assert start_level >= 0 and end_level >= start_level + self.in_channels = in_channels + self.out_channels = out_channels + self.mid_channels = mid_channels + self.use_dcn_in_tower = use_dcn_in_tower + self.range_level = end_level - start_level + 1 + # TODO: add DeformConvNorm + conv_type = [ConvNormLayer] + self.conv_func = conv_type[0] + if self.use_dcn_in_tower: + self.conv_func = conv_type[1] + self.convs_all_levels = [] + for i in range(start_level, end_level + 1): + conv_feat_name = 'mask_feat_head.convs_all_levels.{}'.format(i) + conv_pre_feat = nn.Sequential() + if i == start_level: + conv_pre_feat.add_sublayer( + conv_feat_name + '.conv' + str(i), + self.conv_func( + ch_in=self.in_channels, + ch_out=self.mid_channels, + filter_size=3, + stride=1, + norm_type='gn', + norm_name=conv_feat_name + '.conv' + str(i) + '.gn', + name=conv_feat_name + '.conv' + str(i))) + self.add_sublayer('conv_pre_feat' + str(i), conv_pre_feat) + self.convs_all_levels.append(conv_pre_feat) + else: + for j in range(i): + ch_in = 0 + if j == 0: + ch_in = self.in_channels + 2 if i == end_level else self.in_channels + else: + ch_in = self.mid_channels + conv_pre_feat.add_sublayer( + conv_feat_name + '.conv' + str(j), + self.conv_func( + ch_in=ch_in, + ch_out=self.mid_channels, + filter_size=3, + stride=1, + norm_type='gn', + norm_name=conv_feat_name + '.conv' + str(j) + '.gn', + name=conv_feat_name + '.conv' + str(j))) + conv_pre_feat.add_sublayer( + conv_feat_name + '.conv' + str(j) + 'act', nn.ReLU()) + conv_pre_feat.add_sublayer( + 'upsample' + str(i) + str(j), + nn.Upsample( + scale_factor=2, mode='bilinear')) + self.add_sublayer('conv_pre_feat' + str(i), conv_pre_feat) + self.convs_all_levels.append(conv_pre_feat) + + conv_pred_name = 'mask_feat_head.conv_pred.0' + self.conv_pred = self.add_sublayer( + conv_pred_name, + self.conv_func( + ch_in=self.mid_channels, + ch_out=self.out_channels, + filter_size=1, + stride=1, + norm_type='gn', + norm_name=conv_pred_name + '.gn', + name=conv_pred_name)) + + def forward(self, inputs): + """ + Get SOLOv2MaskHead output. + + Args: + inputs(list[Tensor]): feature map from each necks with shape of [N, C, H, W] + Returns: + ins_pred(Tensor): Output of SOLOv2MaskHead head + """ + feat_all_level = F.relu(self.convs_all_levels[0](inputs[0])) + for i in range(1, self.range_level): + input_p = inputs[i] + if i == (self.range_level - 1): + input_feat = input_p + x_range = paddle.linspace( + -1, 1, paddle.shape(input_feat)[-1], dtype='float32') + y_range = paddle.linspace( + -1, 1, paddle.shape(input_feat)[-2], dtype='float32') + y, x = paddle.meshgrid([y_range, x_range]) + x = paddle.unsqueeze(x, [0, 1]) + y = paddle.unsqueeze(y, [0, 1]) + y = paddle.expand( + y, shape=[paddle.shape(input_feat)[0], 1, -1, -1]) + x = paddle.expand( + x, shape=[paddle.shape(input_feat)[0], 1, -1, -1]) + coord_feat = paddle.concat([x, y], axis=1) + input_p = paddle.concat([input_p, coord_feat], axis=1) + feat_all_level = paddle.add(feat_all_level, + self.convs_all_levels[i](input_p)) + ins_pred = F.relu(self.conv_pred(feat_all_level)) + + return ins_pred + + +@register +class SOLOv2Head(nn.Layer): + """ + Head block for SOLOv2 network + + Args: + num_classes (int): Number of output classes. + in_channels (int): Number of input channels. + seg_feat_channels (int): Num_filters of kernel & categroy branch convolution operation. + stacked_convs (int): Times of convolution operation. + num_grids (list[int]): List of feature map grids size. + kernel_out_channels (int): Number of output channels in kernel branch. + dcn_v2_stages (list): Which stage use dcn v2 in tower. It is between [0, stacked_convs). + segm_strides (list[int]): List of segmentation area stride. + solov2_loss (object): SOLOv2Loss instance. + score_threshold (float): Threshold of categroy score. + mask_nms (object): MaskMatrixNMS instance. + """ + __inject__ = ['solov2_loss', 'mask_nms'] + __shared__ = ['num_classes'] + + def __init__(self, + num_classes=80, + in_channels=256, + seg_feat_channels=256, + stacked_convs=4, + num_grids=[40, 36, 24, 16, 12], + kernel_out_channels=256, + dcn_v2_stages=[], + segm_strides=[8, 8, 16, 32, 32], + solov2_loss=None, + score_threshold=0.1, + mask_threshold=0.5, + mask_nms=None): + super(SOLOv2Head, self).__init__() + self.num_classes = num_classes + self.in_channels = in_channels + self.seg_num_grids = num_grids + self.cate_out_channels = self.num_classes - 1 + self.seg_feat_channels = seg_feat_channels + self.stacked_convs = stacked_convs + self.kernel_out_channels = kernel_out_channels + self.dcn_v2_stages = dcn_v2_stages + self.segm_strides = segm_strides + self.solov2_loss = solov2_loss + self.mask_nms = mask_nms + self.score_threshold = score_threshold + self.mask_threshold = mask_threshold + + conv_type = [ConvNormLayer] + self.conv_func = conv_type[0] + self.kernel_pred_convs = [] + self.cate_pred_convs = [] + for i in range(self.stacked_convs): + if i in self.dcn_v2_stages: + self.conv_func = conv_type[1] + ch_in = self.in_channels + 2 if i == 0 else self.seg_feat_channels + kernel_conv = self.add_sublayer( + 'bbox_head.kernel_convs.' + str(i), + self.conv_func( + ch_in=ch_in, + ch_out=self.seg_feat_channels, + filter_size=3, + stride=1, + norm_type='gn', + norm_name='bbox_head.kernel_convs.{}.gn'.format(i), + name='bbox_head.kernel_convs.{}'.format(i))) + self.kernel_pred_convs.append(kernel_conv) + ch_in = self.in_channels if i == 0 else self.seg_feat_channels + cate_conv = self.add_sublayer( + 'bbox_head.cate_convs.' + str(i), + self.conv_func( + ch_in=ch_in, + ch_out=self.seg_feat_channels, + filter_size=3, + stride=1, + norm_type='gn', + norm_name='bbox_head.cate_convs.{}.gn'.format(i), + name='bbox_head.cate_convs.{}'.format(i))) + self.cate_pred_convs.append(cate_conv) + + self.solo_kernel = self.add_sublayer( + 'bbox_head.solo_kernel', + nn.Conv2D( + self.seg_feat_channels, + self.kernel_out_channels, + kernel_size=3, + stride=1, + padding=1, + weight_attr=ParamAttr( + name="bbox_head.solo_kernel.weight", + initializer=Normal( + mean=0., std=0.01)), + bias_attr=ParamAttr(name="bbox_head.solo_kernel.bias"))) + self.solo_cate = self.add_sublayer( + 'bbox_head.solo_cate', + nn.Conv2D( + self.seg_feat_channels, + self.cate_out_channels, + kernel_size=3, + stride=1, + padding=1, + weight_attr=ParamAttr( + name="bbox_head.solo_cate.weight", + initializer=Normal( + mean=0., std=0.01)), + bias_attr=ParamAttr( + name="bbox_head.solo_cate.bias", + initializer=Constant( + value=float(-np.log((1 - 0.01) / 0.01)))))) + + def _points_nms(self, heat, kernel_size=2): + hmax = F.max_pool2d(heat, kernel_size=kernel_size, stride=1, padding=1) + keep = paddle.cast((hmax[:, :, :-1, :-1] == heat), 'float32') + return heat * keep + + def _split_feats(self, feats): + return (F.interpolate( + feats[0], + scale_factor=0.5, + align_corners=False, + align_mode=0, + mode='bilinear'), feats[1], feats[2], feats[3], F.interpolate( + feats[4], + size=paddle.shape(feats[3])[-2:], + mode='bilinear', + align_corners=False, + align_mode=0)) + + def forward(self, input): + """ + Get SOLOv2 head output + + Args: + input (list): List of Tensors, output of backbone or neck stages + Returns: + cate_pred_list (list): Tensors of each category branch layer + kernel_pred_list (list): Tensors of each kernel branch layer + """ + feats = self._split_feats(input) + cate_pred_list = [] + kernel_pred_list = [] + for idx in range(len(self.seg_num_grids)): + cate_pred, kernel_pred = self._get_output_single(feats[idx], idx) + cate_pred_list.append(cate_pred) + kernel_pred_list.append(kernel_pred) + + return cate_pred_list, kernel_pred_list + + def _get_output_single(self, input, idx): + ins_kernel_feat = input + # CoordConv + x_range = paddle.linspace( + -1, 1, paddle.shape(ins_kernel_feat)[-1], dtype='float32') + y_range = paddle.linspace( + -1, 1, paddle.shape(ins_kernel_feat)[-2], dtype='float32') + y, x = paddle.meshgrid([y_range, x_range]) + x = paddle.unsqueeze(x, [0, 1]) + y = paddle.unsqueeze(y, [0, 1]) + y = paddle.expand( + y, shape=[paddle.shape(ins_kernel_feat)[0], 1, -1, -1]) + x = paddle.expand( + x, shape=[paddle.shape(ins_kernel_feat)[0], 1, -1, -1]) + coord_feat = paddle.concat([x, y], axis=1) + ins_kernel_feat = paddle.concat([ins_kernel_feat, coord_feat], axis=1) + + # kernel branch + kernel_feat = ins_kernel_feat + seg_num_grid = self.seg_num_grids[idx] + kernel_feat = F.interpolate( + kernel_feat, + size=[seg_num_grid, seg_num_grid], + mode='bilinear', + align_corners=False, + align_mode=0) + cate_feat = kernel_feat[:, :-2, :, :] + + for kernel_layer in self.kernel_pred_convs: + kernel_feat = F.relu(kernel_layer(kernel_feat)) + kernel_pred = self.solo_kernel(kernel_feat) + # cate branch + for cate_layer in self.cate_pred_convs: + cate_feat = F.relu(cate_layer(cate_feat)) + cate_pred = self.solo_cate(cate_feat) + + if not self.training: + cate_pred = self._points_nms(F.sigmoid(cate_pred), kernel_size=2) + cate_pred = paddle.transpose(cate_pred, [0, 2, 3, 1]) + return cate_pred, kernel_pred + + def get_loss(self, cate_preds, kernel_preds, ins_pred, ins_labels, + cate_labels, grid_order_list, fg_num): + """ + Get loss of network of SOLOv2. + + Args: + cate_preds (list): Tensor list of categroy branch output. + kernel_preds (list): Tensor list of kernel branch output. + ins_pred (list): Tensor list of instance branch output. + ins_labels (list): List of instance labels pre batch. + cate_labels (list): List of categroy labels pre batch. + grid_order_list (list): List of index in pre grid. + fg_num (int): Number of positive samples in a mini-batch. + Returns: + loss_ins (Tensor): The instance loss Tensor of SOLOv2 network. + loss_cate (Tensor): The category loss Tensor of SOLOv2 network. + """ + batch_size = paddle.shape(grid_order_list[0])[0] + ins_pred_list = [] + for kernel_preds_level, grid_orders_level in zip(kernel_preds, + grid_order_list): + if grid_orders_level.shape[1] == 0: + ins_pred_list.append(None) + continue + grid_orders_level = paddle.reshape(grid_orders_level, [-1]) + reshape_pred = paddle.reshape( + kernel_preds_level, + shape=(paddle.shape(kernel_preds_level)[0], + paddle.shape(kernel_preds_level)[1], -1)) + reshape_pred = paddle.transpose(reshape_pred, [0, 2, 1]) + reshape_pred = paddle.reshape( + reshape_pred, shape=(-1, paddle.shape(reshape_pred)[2])) + gathered_pred = paddle.gather(reshape_pred, index=grid_orders_level) + gathered_pred = paddle.reshape( + gathered_pred, + shape=[batch_size, -1, paddle.shape(gathered_pred)[1]]) + cur_ins_pred = ins_pred + cur_ins_pred = paddle.reshape( + cur_ins_pred, + shape=(paddle.shape(cur_ins_pred)[0], + paddle.shape(cur_ins_pred)[1], -1)) + ins_pred_conv = paddle.matmul(gathered_pred, cur_ins_pred) + cur_ins_pred = paddle.reshape( + ins_pred_conv, + shape=(-1, paddle.shape(ins_pred)[-2], + paddle.shape(ins_pred)[-1])) + ins_pred_list.append(cur_ins_pred) + + num_ins = paddle.sum(fg_num) + cate_preds = [ + paddle.reshape( + paddle.transpose(cate_pred, [0, 2, 3, 1]), + shape=(-1, self.cate_out_channels)) for cate_pred in cate_preds + ] + flatten_cate_preds = paddle.concat(cate_preds) + new_cate_labels = [] + for cate_label in cate_labels: + new_cate_labels.append(paddle.reshape(cate_label, shape=[-1])) + cate_labels = paddle.concat(new_cate_labels) + + loss_ins, loss_cate = self.solov2_loss( + ins_pred_list, ins_labels, flatten_cate_preds, cate_labels, num_ins) + + return {'loss_ins': loss_ins, 'loss_cate': loss_cate} + + def get_prediction(self, cate_preds, kernel_preds, seg_pred, im_shape, + scale_factor): + """ + Get prediction result of SOLOv2 network + + Args: + cate_preds (list): List of Variables, output of categroy branch. + kernel_preds (list): List of Variables, output of kernel branch. + seg_pred (list): List of Variables, output of mask head stages. + im_shape (Variables): [h, w] for input images. + scale_factor (Variables): [scale, scale] for input images. + Returns: + seg_masks (Tensor): The prediction segmentation. + cate_labels (Tensor): The prediction categroy label of each segmentation. + seg_masks (Tensor): The prediction score of each segmentation. + """ + num_levels = len(cate_preds) + featmap_size = paddle.shape(seg_pred)[-2:] + seg_masks_list = [] + cate_labels_list = [] + cate_scores_list = [] + cate_preds = [cate_pred * 1.0 for cate_pred in cate_preds] + kernel_preds = [kernel_pred * 1.0 for kernel_pred in kernel_preds] + # Currently only supports batch size == 1 + for idx in range(1): + cate_pred_list = [ + paddle.reshape( + cate_preds[i][idx], shape=(-1, self.cate_out_channels)) + for i in range(num_levels) + ] + seg_pred_list = seg_pred + kernel_pred_list = [ + paddle.reshape( + paddle.transpose(kernel_preds[i][idx], [1, 2, 0]), + shape=(-1, self.kernel_out_channels)) + for i in range(num_levels) + ] + cate_pred_list = paddle.concat(cate_pred_list, axis=0) + kernel_pred_list = paddle.concat(kernel_pred_list, axis=0) + + seg_masks, cate_labels, cate_scores = self.get_seg_single( + cate_pred_list, seg_pred_list, kernel_pred_list, featmap_size, + im_shape[idx], scale_factor[idx][0]) + bbox_num = paddle.shape(cate_labels)[0] + return seg_masks, cate_labels, cate_scores, bbox_num + + def get_seg_single(self, cate_preds, seg_preds, kernel_preds, featmap_size, + im_shape, scale_factor): + h = paddle.cast(im_shape[0], 'int32')[0] + w = paddle.cast(im_shape[1], 'int32')[0] + upsampled_size_out = [featmap_size[0] * 4, featmap_size[1] * 4] + + y = paddle.zeros(shape=paddle.shape(cate_preds), dtype='float32') + inds = paddle.where(cate_preds > self.score_threshold, cate_preds, y) + inds = paddle.nonzero(inds) + if paddle.shape(inds)[0] == 0: + out = paddle.full(shape=[1], fill_value=-1) + return out, out, out + cate_preds = paddle.reshape(cate_preds, shape=[-1]) + # Prevent empty and increase fake data + ind_a = paddle.cast(paddle.shape(kernel_preds)[0], 'int64') + ind_b = paddle.zeros(shape=[1], dtype='int64') + inds_end = paddle.unsqueeze(paddle.concat([ind_a, ind_b]), 0) + inds = paddle.concat([inds, inds_end]) + kernel_preds_end = paddle.ones( + shape=[1, self.kernel_out_channels], dtype='float32') + kernel_preds = paddle.concat([kernel_preds, kernel_preds_end]) + cate_preds = paddle.concat( + [cate_preds, paddle.zeros( + shape=[1], dtype='float32')]) + + # cate_labels & kernel_preds + cate_labels = inds[:, 1] + kernel_preds = paddle.gather(kernel_preds, index=inds[:, 0]) + cate_score_idx = paddle.add(inds[:, 0] * 80, cate_labels) + cate_scores = paddle.gather(cate_preds, index=cate_score_idx) + + size_trans = np.power(self.seg_num_grids, 2) + strides = [] + for _ind in range(len(self.segm_strides)): + strides.append( + paddle.full( + shape=[int(size_trans[_ind])], + fill_value=self.segm_strides[_ind], + dtype="int32")) + strides = paddle.concat(strides) + strides = paddle.gather(strides, index=inds[:, 0]) + + # mask encoding. + kernel_preds = paddle.unsqueeze(kernel_preds, [2, 3]) + seg_preds = F.conv2d(seg_preds, kernel_preds) + seg_preds = F.sigmoid(paddle.squeeze(seg_preds, [0])) + seg_masks = seg_preds > self.mask_threshold + seg_masks = paddle.cast(seg_masks, 'float32') + sum_masks = paddle.sum(seg_masks, axis=[1, 2]) + + y = paddle.zeros(shape=paddle.shape(sum_masks), dtype='float32') + keep = paddle.where(sum_masks > strides, sum_masks, y) + keep = paddle.nonzero(keep) + keep = paddle.squeeze(keep, axis=[1]) + # Prevent empty and increase fake data + keep_other = paddle.concat( + [keep, paddle.cast(paddle.shape(sum_masks)[0] - 1, 'int64')]) + keep_scores = paddle.concat( + [keep, paddle.cast(paddle.shape(sum_masks)[0], 'int64')]) + cate_scores_end = paddle.zeros(shape=[1], dtype='float32') + cate_scores = paddle.concat([cate_scores, cate_scores_end]) + + seg_masks = paddle.gather(seg_masks, index=keep_other) + seg_preds = paddle.gather(seg_preds, index=keep_other) + sum_masks = paddle.gather(sum_masks, index=keep_other) + cate_labels = paddle.gather(cate_labels, index=keep_other) + cate_scores = paddle.gather(cate_scores, index=keep_scores) + + # mask scoring. + seg_mul = paddle.cast(seg_preds * seg_masks, 'float32') + seg_scores = paddle.sum(seg_mul, axis=[1, 2]) / sum_masks + cate_scores *= seg_scores + # Matrix NMS + seg_preds, cate_scores, cate_labels = self.mask_nms( + seg_preds, seg_masks, cate_labels, cate_scores, sum_masks=sum_masks) + ori_shape = im_shape[:2] / scale_factor + 0.5 + ori_shape = paddle.cast(ori_shape, 'int32') + seg_preds = F.interpolate( + paddle.unsqueeze(seg_preds, 0), + size=upsampled_size_out, + mode='bilinear', + align_corners=False, + align_mode=0) + seg_preds = paddle.slice( + seg_preds, axes=[2, 3], starts=[0, 0], ends=[h, w]) + seg_masks = paddle.squeeze( + F.interpolate( + seg_preds, + size=ori_shape[:2], + mode='bilinear', + align_corners=False, + align_mode=0), + axis=[0]) + # TODO: support bool type + seg_masks = paddle.cast(seg_masks > self.mask_threshold, 'int32') + return seg_masks, cate_labels, cate_scores diff --git a/dygraph/ppdet/modeling/layers.py b/dygraph/ppdet/modeling/layers.py index 1cec15f002d0a09017aa5dfc9e99b7b765e12410..e6564f0d695c2bb71fee43bf456cdb7ba760223e 100644 --- a/dygraph/ppdet/modeling/layers.py +++ b/dygraph/ppdet/modeling/layers.py @@ -18,12 +18,18 @@ import numpy as np from numbers import Integral import paddle +import paddle.nn as nn +from paddle import ParamAttr from paddle import to_tensor +from paddle.nn import Conv2D, BatchNorm2D, GroupNorm +import paddle.nn.functional as F +from paddle.nn.initializer import Normal, Constant +from paddle.regularizer import L2Decay + from ppdet.core.workspace import register, serializable from ppdet.py_op.target import generate_rpn_anchor_target, generate_proposal_target, generate_mask_target from ppdet.py_op.post_process import bbox_post_process from . import ops -import paddle.nn.functional as F def _to_list(l): @@ -32,6 +38,58 @@ def _to_list(l): return [l] +class ConvNormLayer(nn.Layer): + def __init__(self, + ch_in, + ch_out, + filter_size, + stride, + norm_type='bn', + norm_groups=32, + use_dcn=False, + norm_name=None, + name=None): + super(ConvNormLayer, self).__init__() + assert norm_type in ['bn', 'sync_bn', 'gn'] + + self.conv = Conv2D( + in_channels=ch_in, + out_channels=ch_out, + kernel_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=1, + weight_attr=ParamAttr( + name=name + "_weight", + initializer=Normal( + mean=0., std=0.01), + learning_rate=1.), + bias_attr=False) + + param_attr = ParamAttr( + name=norm_name + "_scale", + learning_rate=1., + regularizer=L2Decay(0.)) + bias_attr = ParamAttr( + name=norm_name + "_offset", + learning_rate=1., + regularizer=L2Decay(0.)) + if norm_type in ['bn', 'sync_bn']: + self.norm = BatchNorm2D( + ch_out, weight_attr=param_attr, bias_attr=bias_attr) + elif norm_type == 'gn': + self.norm = GroupNorm( + num_groups=norm_groups, + num_channels=ch_out, + weight_attr=param_attr, + bias_attr=bias_attr) + + def forward(self, inputs): + out = self.conv(inputs) + out = self.norm(out) + return out + + @register @serializable class AnchorGeneratorRPN(object): @@ -651,3 +709,119 @@ class AnchorGrid(object): self._anchor_vars = anchor_vars return self._anchor_vars + + +@register +@serializable +class MaskMatrixNMS(object): + """ + Matrix NMS for multi-class masks. + Args: + update_threshold (float): Updated threshold of categroy score in second time. + pre_nms_top_n (int): Number of total instance to be kept per image before NMS + post_nms_top_n (int): Number of total instance to be kept per image after NMS. + kernel (str): 'linear' or 'gaussian'. + sigma (float): std in gaussian method. + Input: + seg_preds (Variable): shape (n, h, w), segmentation feature maps + seg_masks (Variable): shape (n, h, w), segmentation feature maps + cate_labels (Variable): shape (n), mask labels in descending order + cate_scores (Variable): shape (n), mask scores in descending order + sum_masks (Variable): a float tensor of the sum of seg_masks + Returns: + Variable: cate_scores, tensors of shape (n) + """ + + def __init__(self, + update_threshold=0.05, + pre_nms_top_n=500, + post_nms_top_n=100, + kernel='gaussian', + sigma=2.0): + super(MaskMatrixNMS, self).__init__() + self.update_threshold = update_threshold + self.pre_nms_top_n = pre_nms_top_n + self.post_nms_top_n = post_nms_top_n + self.kernel = kernel + self.sigma = sigma + + def _sort_score(self, scores, top_num): + if paddle.shape(scores)[0] > top_num: + return paddle.topk(scores, top_num)[1] + else: + return paddle.argsort(scores, descending=True) + + def __call__(self, + seg_preds, + seg_masks, + cate_labels, + cate_scores, + sum_masks=None): + # sort and keep top nms_pre + sort_inds = self._sort_score(cate_scores, self.pre_nms_top_n) + seg_masks = paddle.gather(seg_masks, index=sort_inds) + seg_preds = paddle.gather(seg_preds, index=sort_inds) + sum_masks = paddle.gather(sum_masks, index=sort_inds) + cate_scores = paddle.gather(cate_scores, index=sort_inds) + cate_labels = paddle.gather(cate_labels, index=sort_inds) + + seg_masks = paddle.flatten(seg_masks, start_axis=1, stop_axis=-1) + # inter. + inter_matrix = paddle.mm(seg_masks, paddle.transpose(seg_masks, [1, 0])) + n_samples = paddle.shape(cate_labels) + # union. + sum_masks_x = paddle.expand(sum_masks, shape=[n_samples, n_samples]) + # iou. + iou_matrix = (inter_matrix / ( + sum_masks_x + paddle.transpose(sum_masks_x, [1, 0]) - inter_matrix)) + iou_matrix = paddle.triu(iou_matrix, diagonal=1) + # label_specific matrix. + cate_labels_x = paddle.expand(cate_labels, shape=[n_samples, n_samples]) + label_matrix = paddle.cast( + (cate_labels_x == paddle.transpose(cate_labels_x, [1, 0])), + 'float32') + label_matrix = paddle.triu(label_matrix, diagonal=1) + + # IoU compensation + compensate_iou = paddle.max((iou_matrix * label_matrix), axis=0) + compensate_iou = paddle.expand( + compensate_iou, shape=[n_samples, n_samples]) + compensate_iou = paddle.transpose(compensate_iou, [1, 0]) + + # IoU decay + decay_iou = iou_matrix * label_matrix + + # matrix nms + if self.kernel == 'gaussian': + decay_matrix = paddle.exp(-1 * self.sigma * (decay_iou**2)) + compensate_matrix = paddle.exp(-1 * self.sigma * + (compensate_iou**2)) + decay_coefficient = paddle.min(decay_matrix / compensate_matrix, + axis=0) + elif self.kernel == 'linear': + decay_matrix = (1 - decay_iou) / (1 - compensate_iou) + decay_coefficient = paddle.min(decay_matrix, axis=0) + else: + raise NotImplementedError + + # update the score. + cate_scores = cate_scores * decay_coefficient + y = paddle.zeros(shape=paddle.shape(cate_scores), dtype='float32') + keep = paddle.where(cate_scores >= self.update_threshold, cate_scores, + y) + keep = paddle.nonzero(keep) + keep = paddle.squeeze(keep, axis=[1]) + # Prevent empty and increase fake data + keep = paddle.concat( + [keep, paddle.cast(paddle.shape(cate_scores)[0] - 1, 'int64')]) + + seg_preds = paddle.gather(seg_preds, index=keep) + cate_scores = paddle.gather(cate_scores, index=keep) + cate_labels = paddle.gather(cate_labels, index=keep) + + # sort and keep top_k + sort_inds = self._sort_score(cate_scores, self.post_nms_top_n) + seg_preds = paddle.gather(seg_preds, index=sort_inds) + cate_scores = paddle.gather(cate_scores, index=sort_inds) + cate_labels = paddle.gather(cate_labels, index=sort_inds) + return seg_preds, cate_scores, cate_labels diff --git a/dygraph/ppdet/modeling/losses/__init__.py b/dygraph/ppdet/modeling/losses/__init__.py index b47ab3499c8c9ecf47eb85ceedd9d9a3afebe0fb..94998975736ecce56140feae68ced94f26c724e8 100644 --- a/dygraph/ppdet/modeling/losses/__init__.py +++ b/dygraph/ppdet/modeling/losses/__init__.py @@ -16,8 +16,10 @@ from . import yolo_loss from . import iou_aware_loss from . import iou_loss from . import ssd_loss +from . import solov2_loss from .yolo_loss import * from .iou_aware_loss import * from .iou_loss import * from .ssd_loss import * +from .solov2_loss import * diff --git a/dygraph/ppdet/modeling/losses/solov2_loss.py b/dygraph/ppdet/modeling/losses/solov2_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..ef97a7707b159cbf9fc2acba42dab58de43721b9 --- /dev/null +++ b/dygraph/ppdet/modeling/losses/solov2_loss.py @@ -0,0 +1,101 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn.functional as F +from ppdet.core.workspace import register, serializable + +__all__ = ['SOLOv2Loss'] + + +@register +@serializable +class SOLOv2Loss(object): + """ + SOLOv2Loss + Args: + ins_loss_weight (float): Weight of instance loss. + focal_loss_gamma (float): Gamma parameter for focal loss. + focal_loss_alpha (float): Alpha parameter for focal loss. + """ + + def __init__(self, + ins_loss_weight=3.0, + focal_loss_gamma=2.0, + focal_loss_alpha=0.25): + self.ins_loss_weight = ins_loss_weight + self.focal_loss_gamma = focal_loss_gamma + self.focal_loss_alpha = focal_loss_alpha + + def _dice_loss(self, input, target): + input = paddle.reshape(input, shape=(paddle.shape(input)[0], -1)) + target = paddle.reshape(target, shape=(paddle.shape(target)[0], -1)) + a = paddle.sum(input * target, axis=1) + b = paddle.sum(input * input, axis=1) + 0.001 + c = paddle.sum(target * target, axis=1) + 0.001 + d = (2 * a) / (b + c) + return 1 - d + + def __call__(self, ins_pred_list, ins_label_list, cate_preds, cate_labels, + num_ins): + """ + Get loss of network of SOLOv2. + Args: + ins_pred_list (list): Variable list of instance branch output. + ins_label_list (list): List of instance labels pre batch. + cate_preds (list): Concat Variable list of categroy branch output. + cate_labels (list): Concat list of categroy labels pre batch. + num_ins (int): Number of positive samples in a mini-batch. + Returns: + loss_ins (Variable): The instance loss Variable of SOLOv2 network. + loss_cate (Variable): The category loss Variable of SOLOv2 network. + """ + + #1. Ues dice_loss to calculate instance loss + loss_ins = [] + total_weights = paddle.zeros(shape=[1], dtype='float32') + for input, target in zip(ins_pred_list, ins_label_list): + if input is None: + continue + target = paddle.cast(target, 'float32') + target = paddle.reshape( + target, + shape=[-1, paddle.shape(input)[-2], paddle.shape(input)[-1]]) + weights = paddle.cast( + paddle.sum(target, axis=[1, 2]) > 0, 'float32') + input = F.sigmoid(input) + dice_out = paddle.multiply(self._dice_loss(input, target), weights) + total_weights += paddle.sum(weights) + loss_ins.append(dice_out) + loss_ins = paddle.sum(paddle.concat(loss_ins)) / total_weights + loss_ins = loss_ins * self.ins_loss_weight + + #2. Ues sigmoid_focal_loss to calculate category loss + # expand onehot labels + num_classes = cate_preds.shape[-1] + cate_labels_bin = F.one_hot(cate_labels, num_classes=num_classes + 1) + cate_labels_bin = cate_labels_bin[:, 1:] + + loss_cate = F.sigmoid_focal_loss( + cate_preds, + label=cate_labels_bin, + normalizer=num_ins + 1., + gamma=self.focal_loss_gamma, + alpha=self.focal_loss_alpha) + + return loss_ins, loss_cate diff --git a/dygraph/ppdet/modeling/necks/fpn.py b/dygraph/ppdet/modeling/necks/fpn.py index 5565bfdddf664a4e399b60d6043d3899bb9e4f19..780f46013112a99d8f9e757bb5a320b36b84e66e 100644 --- a/dygraph/ppdet/modeling/necks/fpn.py +++ b/dygraph/ppdet/modeling/necks/fpn.py @@ -34,6 +34,9 @@ class FPN(Layer): spatial_scale=[0.25, 0.125, 0.0625, 0.03125]): super(FPN, self).__init__() + self.min_level = min_level + self.max_level = max_level + self.spatial_scale = spatial_scale self.lateral_convs = [] self.fpn_convs = [] fan = out_channel * 3 * 3 @@ -70,10 +73,6 @@ class FPN(Layer): learning_rate=2., regularizer=L2Decay(0.)))) self.fpn_convs.append(fpn_conv) - self.min_level = min_level - self.max_level = max_level - self.spatial_scale = spatial_scale - def forward(self, body_feats): laterals = [] for lvl in range(self.min_level, self.max_level): diff --git a/dygraph/ppdet/py_op/post_process.py b/dygraph/ppdet/py_op/post_process.py index 6216f38cdcda4c714b9ab48b1dcc4e72f895d9d1..2392c2425ce6a7d8993eff9f98f3213c9aaf01e5 100755 --- a/dygraph/ppdet/py_op/post_process.py +++ b/dygraph/ppdet/py_op/post_process.py @@ -184,3 +184,32 @@ def get_seg_res(masks, mask_nums, image_id, num_id_to_cat_id_map): } seg_res.append(sg_res) return seg_res + + +def get_solov2_segm_res(results, image_id, num_id_to_cat_id_map): + import pycocotools.mask as mask_util + segm_res = [] + # for each batch + segms = results['segm'].astype(np.uint8) + clsid_labels = results['cate_label'] + clsid_scores = results['cate_score'] + lengths = segms.shape[0] + im_id = int(image_id[0][0]) + if lengths == 0 or segms is None: + return None + # for each sample + for i in range(lengths - 1): + clsid = int(clsid_labels[i]) + 1 + catid = num_id_to_cat_id_map[clsid] + score = float(clsid_scores[i]) + mask = segms[i] + segm = mask_util.encode(np.array(mask[:, :, np.newaxis], order='F'))[0] + segm['counts'] = segm['counts'].decode('utf8') + coco_res = { + 'image_id': im_id, + 'category_id': catid, + 'segmentation': segm, + 'score': score + } + segm_res.append(coco_res) + return segm_res diff --git a/dygraph/ppdet/utils/checkpoint.py b/dygraph/ppdet/utils/checkpoint.py index dbb20eb0a3e1ef20faeef95f9bfa26917dab0048..05ce8171bb890e5cd88dbbd79ec3363942826604 100644 --- a/dygraph/ppdet/utils/checkpoint.py +++ b/dygraph/ppdet/utils/checkpoint.py @@ -92,8 +92,8 @@ def load_weight(model, weight, optimizer=None): param_state_dict = paddle.load(pdparam_path) model.set_dict(param_state_dict) + last_epoch = 0 if optimizer is not None and os.path.exists(path + '.pdopt'): - last_epoch = 0 optim_state_dict = paddle.load(path + '.pdopt') # to slove resume bug, will it be fixed in paddle 2.0 for key in optimizer.state_dict().keys(): @@ -102,8 +102,8 @@ def load_weight(model, weight, optimizer=None): if 'last_epoch' in optim_state_dict: last_epoch = optim_state_dict.pop('last_epoch') optimizer.set_state_dict(optim_state_dict) - return last_epoch - return + + return last_epoch def load_pretrain_weight(model, diff --git a/dygraph/ppdet/utils/visualizer.py b/dygraph/ppdet/utils/visualizer.py index e41db7b366fb42b7ff817fb39910560e76c90e10..5327fef1d2cc92910347ae96f014d79453f70802 100644 --- a/dygraph/ppdet/utils/visualizer.py +++ b/dygraph/ppdet/utils/visualizer.py @@ -19,6 +19,7 @@ from __future__ import unicode_literals import numpy as np from PIL import Image, ImageDraw +import cv2 from .colormap import colormap @@ -28,6 +29,7 @@ __all__ = ['visualize_results'] def visualize_results(image, bbox_res, mask_res, + segm_res, im_id, catid2name, threshold=0.5): @@ -38,6 +40,8 @@ def visualize_results(image, image = draw_bbox(image, im_id, catid2name, bbox_res, threshold) if mask_res is not None: image = draw_mask(image, im_id, mask_res, threshold) + if segm_res is not None: + image = draw_segm(image, im_id, catid2name, segm_res, threshold) return image @@ -106,3 +110,64 @@ def draw_bbox(image, im_id, catid2name, bboxes, threshold): draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255)) return image + + +def draw_segm(image, + im_id, + catid2name, + segms, + threshold, + alpha=0.7, + draw_box=True): + """ + Draw segmentation on image + """ + mask_color_id = 0 + w_ratio = .4 + color_list = colormap(rgb=True) + img_array = np.array(image).astype('float32') + for dt in np.array(segms): + if im_id != dt['image_id']: + continue + segm, score, catid = dt['segmentation'], dt['score'], dt['category_id'] + if score < threshold: + continue + import pycocotools.mask as mask_util + mask = mask_util.decode(segm) * 255 + color_mask = color_list[mask_color_id % len(color_list), 0:3] + mask_color_id += 1 + for c in range(3): + color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255 + idx = np.nonzero(mask) + img_array[idx[0], idx[1], :] *= 1.0 - alpha + img_array[idx[0], idx[1], :] += alpha * color_mask + + if not draw_box: + center_y, center_x = ndimage.measurements.center_of_mass(mask) + label_text = "{}".format(catid2name[catid]) + vis_pos = (max(int(center_x) - 10, 0), int(center_y)) + cv2.putText(img_array, label_text, vis_pos, + cv2.FONT_HERSHEY_COMPLEX, 0.3, (255, 255, 255)) + else: + mask = mask_util.decode(segm) * 255 + sum_x = np.sum(mask, axis=0) + x = np.where(sum_x > 0.5)[0] + sum_y = np.sum(mask, axis=1) + y = np.where(sum_y > 0.5)[0] + x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1] + cv2.rectangle(img_array, (x0, y0), (x1, y1), + tuple(color_mask.astype('int32').tolist()), 1) + bbox_text = '%s %.2f' % (catid2name[catid], score) + t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0] + cv2.rectangle(img_array, (x0, y0), (x0 + t_size[0], + y0 - t_size[1] - 3), + tuple(color_mask.astype('int32').tolist()), -1) + cv2.putText( + img_array, + bbox_text, (x0, y0 - 2), + cv2.FONT_HERSHEY_SIMPLEX, + 0.3, (0, 0, 0), + 1, + lineType=cv2.LINE_AA) + + return Image.fromarray(img_array.astype('uint8')) diff --git a/tools/configure.py b/tools/configure.py index 6ac1e40ffb146cc621a30e5311c562fdecb5c08b..fdf826a5521e080765fc30cffc5a4cefa1a9c56b 100644 --- a/tools/configure.py +++ b/tools/configure.py @@ -96,9 +96,8 @@ def list_modules(**kwargs): print("") max_len = max([len(mod.name) for mod in modules]) for mod in modules: - print( - color_tty.green(mod.name.ljust(max_len)), - mod.doc.split('\n')[0]) + print(color_tty.green(mod.name.ljust(max_len)), + mod.doc.split('\n')[0]) print("")