From b4727677751081b257c6fa23c3c124ab9e5a32a1 Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Thu, 11 Aug 2022 11:41:35 +0800 Subject: [PATCH] refactor s2anet (#6604) * refactor s2anet to support batch_size > 1 * fix problem of inference * support batch_size > 1 for training * fix empty results * fix dota eval * fix configs of s2anet_head * modify s2anet_spine_1x to 73 mAP --- configs/datasets/dota.yml | 1 + configs/datasets/spine_coco.yml | 1 + configs/dota/_base_/s2anet.yml | 21 +- configs/dota/_base_/s2anet_reader.yml | 44 +- configs/dota/s2anet_1x_spine.yml | 20 +- configs/dota/s2anet_alignconv_2x_dota.yml | 16 - configs/dota/s2anet_conv_2x_dota.yml | 14 +- ppdet/data/source/coco.py | 1 - ppdet/data/transform/__init__.py | 2 + ppdet/data/transform/batch_operators.py | 67 +- ppdet/data/transform/op_helper.py | 69 ++ ppdet/data/transform/operators.py | 64 -- ppdet/data/transform/rotated_operators.py | 487 +++++++++++ ppdet/metrics/map_utils.py | 8 +- ppdet/metrics/metrics.py | 70 +- ppdet/modeling/architectures/s2anet.py | 47 +- ppdet/modeling/heads/s2anet_head.py | 783 +++++------------- ppdet/modeling/layers.py | 75 ++ ppdet/modeling/post_process.py | 107 +-- .../proposal_generator/anchor_generator.py | 113 ++- .../proposal_generator/target_layer.py | 16 +- 21 files changed, 1122 insertions(+), 904 deletions(-) create mode 100644 ppdet/data/transform/rotated_operators.py diff --git a/configs/datasets/dota.yml b/configs/datasets/dota.yml index 5153163d9..2830b829c 100644 --- a/configs/datasets/dota.yml +++ b/configs/datasets/dota.yml @@ -13,6 +13,7 @@ EvalDataset: image_dir: trainval_split/images anno_path: trainval_split/s2anet_trainval_paddle_coco.json dataset_dir: dataset/DOTA_1024_s2anet/ + data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd', 'gt_rbox'] TestDataset: !ImageFolder diff --git a/configs/datasets/spine_coco.yml b/configs/datasets/spine_coco.yml index 743743dc5..41cf51e0e 100644 --- a/configs/datasets/spine_coco.yml +++ b/configs/datasets/spine_coco.yml @@ -13,6 +13,7 @@ EvalDataset: image_dir: images anno_path: annotations/valid.json dataset_dir: dataset/spine_coco + data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd', 'gt_rbox'] TestDataset: !ImageFolder diff --git a/configs/dota/_base_/s2anet.yml b/configs/dota/_base_/s2anet.yml index fb6064224..fc8b2e258 100644 --- a/configs/dota/_base_/s2anet.yml +++ b/configs/dota/_base_/s2anet.yml @@ -7,8 +7,7 @@ weights: output/s2anet_r50_fpn_1x_dota/model_final.pdparams S2ANet: backbone: ResNet neck: FPN - s2anet_head: S2ANetHead - s2anet_bbox_post_process: S2ANetBBoxPostProcess + head: S2ANetHead ResNet: depth: 50 @@ -33,23 +32,21 @@ S2ANetHead: stacked_convs: 2 feat_in: 256 feat_out: 256 - num_classes: 15 align_conv_type: 'AlignConv' # AlignConv Conv align_conv_size: 3 use_sigmoid_cls: True - -RBoxAssigner: - pos_iou_thr: 0.5 - neg_iou_thr: 0.4 - min_iou_thr: 0.0 - ignore_iof_thr: -2 - -S2ANetBBoxPostProcess: + reg_loss_weight: [1.0, 1.0, 1.0, 1.0, 1.1] + cls_loss_weight: [1.1, 1.05] nms_pre: 2000 - min_bbox_size: 0.0 nms: name: MultiClassNMS keep_top_k: -1 score_threshold: 0.05 nms_threshold: 0.1 normalized: False + +RBoxAssigner: + pos_iou_thr: 0.5 + neg_iou_thr: 0.4 + min_iou_thr: 0.0 + ignore_iof_thr: -2 diff --git a/configs/dota/_base_/s2anet_reader.yml b/configs/dota/_base_/s2anet_reader.yml index b28dd5aad..36ac1fd68 100644 --- a/configs/dota/_base_/s2anet_reader.yml +++ b/configs/dota/_base_/s2anet_reader.yml @@ -1,41 +1,43 @@ -worker_num: 0 +worker_num: 4 TrainReader: sample_transforms: - - Decode: {} - - Rbox2Poly: {} - # Resize can process rbox - - Resize: {target_size: [1024, 1024], interp: 2, keep_ratio: False} - - RandomFlip: {prob: 0.5} - - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - - Permute: {} + - Decode: {} + - Rbox2Poly: {} + - RandomRFlip: {} + - RResize: {target_size: [1024, 1024], keep_ratio: True, interp: 2} + - Poly2RBox: {rbox_type: 'le135'} batch_transforms: - - PadBatch: {pad_to_stride: 32} - batch_size: 1 + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True} + - Permute: {} + - PadRGT: {} + - PadBatch: {pad_to_stride: 32} + batch_size: 2 shuffle: true drop_last: true EvalReader: sample_transforms: - - Decode: {} - - Resize: {interp: 2, target_size: [1024, 1024], keep_ratio: True} - - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - - Permute: {} + - Decode: {} + - RResize: {target_size: [1024, 1024], keep_ratio: True, interp: 2} + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True} + - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: 32} - batch_size: 1 + - PadBatch: {pad_to_stride: 32} + batch_size: 2 shuffle: false drop_last: false + collate_batch: false TestReader: sample_transforms: - - Decode: {} - - Resize: {interp: 2, target_size: [1024, 1024], keep_ratio: True} - - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - - Permute: {} + - Decode: {} + - Resize: {interp: 2, target_size: [1024, 1024], keep_ratio: True} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: 32} + - PadBatch: {pad_to_stride: 32} batch_size: 1 shuffle: false drop_last: false diff --git a/configs/dota/s2anet_1x_spine.yml b/configs/dota/s2anet_1x_spine.yml index 5cf215b54..965db4d18 100644 --- a/configs/dota/s2anet_1x_spine.yml +++ b/configs/dota/s2anet_1x_spine.yml @@ -7,23 +7,19 @@ _BASE_: [ ] weights: output/s2anet_1x_spine/model_final +pretrain_weights: https://paddledet.bj.bcebos.com/models/s2anet_alignconv_2x_dota.pdparams # for 8 card LearningRate: base_lr: 0.01 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [7, 10] + - !LinearWarmup + start_factor: 0.3333333333333333 + epochs: 5 S2ANetHead: - anchor_strides: [8, 16, 32, 64, 128] - anchor_scales: [4] - anchor_ratios: [1.0] - anchor_assign: RBoxAssigner - stacked_convs: 2 - feat_in: 256 - feat_out: 256 - num_classes: 9 - align_conv_type: 'AlignConv' # AlignConv Conv - align_conv_size: 3 - use_sigmoid_cls: True reg_loss_weight: [1.0, 1.0, 1.0, 1.0, 1.05] cls_loss_weight: [1.05, 1.0] - reg_loss_type: 'l1' diff --git a/configs/dota/s2anet_alignconv_2x_dota.yml b/configs/dota/s2anet_alignconv_2x_dota.yml index a35bd2fe0..f2ecac202 100644 --- a/configs/dota/s2anet_alignconv_2x_dota.yml +++ b/configs/dota/s2anet_alignconv_2x_dota.yml @@ -8,19 +8,3 @@ _BASE_: [ pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams weights: output/s2anet_alignconv_2x_dota/model_final - -S2ANetHead: - anchor_strides: [8, 16, 32, 64, 128] - anchor_scales: [4] - anchor_ratios: [1.0] - anchor_assign: RBoxAssigner - stacked_convs: 2 - feat_in: 256 - feat_out: 256 - num_classes: 15 - align_conv_type: 'AlignConv' # AlignConv Conv - align_conv_size: 3 - use_sigmoid_cls: True - reg_loss_weight: [1.0, 1.0, 1.0, 1.0, 1.1] - cls_loss_weight: [1.1, 1.05] - reg_loss_type: 'l1' diff --git a/configs/dota/s2anet_conv_2x_dota.yml b/configs/dota/s2anet_conv_2x_dota.yml index 795ba242b..c8e0a1b84 100644 --- a/configs/dota/s2anet_conv_2x_dota.yml +++ b/configs/dota/s2anet_conv_2x_dota.yml @@ -16,16 +16,4 @@ ResNet: num_stages: 4 S2ANetHead: - anchor_strides: [8, 16, 32, 64, 128] - anchor_scales: [4] - anchor_ratios: [1.0] - anchor_assign: RBoxAssigner - stacked_convs: 2 - feat_in: 256 - feat_out: 256 - num_classes: 15 - align_conv_type: 'Conv' # AlignConv Conv - align_conv_size: 3 - use_sigmoid_cls: True - reg_loss_weight: [1.0, 1.0, 1.0, 1.0, 1.1] - cls_loss_weight: [1.1, 1.05] + align_conv_type: 'Conv' diff --git a/ppdet/data/source/coco.py b/ppdet/data/source/coco.py index 1f7c9b7bb..5f226ec0a 100644 --- a/ppdet/data/source/coco.py +++ b/ppdet/data/source/coco.py @@ -180,7 +180,6 @@ class COCODataSet(DetDataset): gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32) if is_rbox_anno: gt_rbox = np.zeros((num_bbox, 5), dtype=np.float32) - gt_theta = np.zeros((num_bbox, 1), dtype=np.int32) gt_class = np.zeros((num_bbox, 1), dtype=np.int32) is_crowd = np.zeros((num_bbox, 1), dtype=np.int32) gt_poly = [None] * num_bbox diff --git a/ppdet/data/transform/__init__.py b/ppdet/data/transform/__init__.py index fb8a1a449..a9bf0004a 100644 --- a/ppdet/data/transform/__init__.py +++ b/ppdet/data/transform/__init__.py @@ -16,11 +16,13 @@ from . import operators from . import batch_operators from . import keypoint_operators from . import mot_operators +from . import rotated_operators from .operators import * from .batch_operators import * from .keypoint_operators import * from .mot_operators import * +from .rotated_operators import * __all__ = [] __all__ += registered_ops diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py index 1a25237b4..b3b20278b 100644 --- a/ppdet/data/transform/batch_operators.py +++ b/ppdet/data/transform/batch_operators.py @@ -48,6 +48,7 @@ __all__ = [ 'Gt2GFLTarget', 'Gt2CenterNetTarget', 'PadGT', + 'PadRGT', ] @@ -109,12 +110,6 @@ class PadBatch(BaseOperator): padding_segm[:, :im_h, :im_w] = gt_segm data['gt_segm'] = padding_segm - if 'gt_rbox2poly' in data and data['gt_rbox2poly'] is not None: - # ploy to rbox - polys = data['gt_rbox2poly'] - rbox = bbox_utils.poly2rbox(polys) - data['gt_rbox'] = rbox - return samples @@ -981,12 +976,6 @@ class PadMaskBatch(BaseOperator): padding_mask[:im_h, :im_w] = 1. data['pad_mask'] = padding_mask - if 'gt_rbox2poly' in data and data['gt_rbox2poly'] is not None: - # ploy to rbox - polys = data['gt_rbox2poly'] - rbox = bbox_utils.poly2rbox(polys) - data['gt_rbox'] = rbox - return samples @@ -1122,3 +1111,57 @@ class PadGT(BaseOperator): pad_diff[:num_gt] = sample['difficult'] sample['difficult'] = pad_diff return samples + + +@register_op +class PadRGT(BaseOperator): + """ + Pad 0 to `gt_class`, `gt_bbox`, `gt_score`... + The num_max_boxes is the largest for batch. + Args: + return_gt_mask (bool): If true, return `pad_gt_mask`, + 1 means bbox, 0 means no bbox. + """ + + def __init__(self, return_gt_mask=True): + super(PadRGT, self).__init__() + self.return_gt_mask = return_gt_mask + + def pad_field(self, sample, field, num_gt): + name, shape, dtype = field + if name in sample: + pad_v = np.zeros(shape, dtype=dtype) + if num_gt > 0: + pad_v[:num_gt] = sample[name] + sample[name] = pad_v + + def __call__(self, samples, context=None): + num_max_boxes = max([len(s['gt_bbox']) for s in samples]) + for sample in samples: + if self.return_gt_mask: + sample['pad_gt_mask'] = np.zeros( + (num_max_boxes, 1), dtype=np.float32) + if num_max_boxes == 0: + continue + + num_gt = len(sample['gt_bbox']) + pad_gt_class = np.zeros((num_max_boxes, 1), dtype=np.int32) + pad_gt_bbox = np.zeros((num_max_boxes, 4), dtype=np.float32) + if num_gt > 0: + pad_gt_class[:num_gt] = sample['gt_class'] + pad_gt_bbox[:num_gt] = sample['gt_bbox'] + sample['gt_class'] = pad_gt_class + sample['gt_bbox'] = pad_gt_bbox + # pad_gt_mask + if 'pad_gt_mask' in sample: + sample['pad_gt_mask'][:num_gt] = 1 + # gt_score + names = ['gt_score', 'is_crowd', 'difficult', 'gt_poly', 'gt_rbox'] + dims = [1, 1, 1, 8, 5] + dtypes = [np.float32, np.int32, np.int32, np.float32, np.float32] + + for name, dim, dtype in zip(names, dims, dtypes): + self.pad_field(sample, [name, (num_max_boxes, dim), dtype], + num_gt) + + return samples diff --git a/ppdet/data/transform/op_helper.py b/ppdet/data/transform/op_helper.py index 6c400306d..eeb152541 100644 --- a/ppdet/data/transform/op_helper.py +++ b/ppdet/data/transform/op_helper.py @@ -492,3 +492,72 @@ def get_border(border, size): while size - border // i <= border // i: i *= 2 return border // i + + +def norm_angle(angle, range=[-np.pi / 4, np.pi]): + return (angle - range[0]) % range[1] + range[0] + + +def poly2rbox_le135(poly): + """convert poly to rbox [-pi / 4, 3 * pi / 4] + + Args: + poly: [x1, y1, x2, y2, x3, y3, x4, y4] + + Returns: + rbox: [cx, cy, w, h, angle] + """ + poly = np.array(poly[:8], dtype=np.float32) + + pt1 = (poly[0], poly[1]) + pt2 = (poly[2], poly[3]) + pt3 = (poly[4], poly[5]) + pt4 = (poly[6], poly[7]) + + edge1 = np.sqrt((pt1[0] - pt2[0]) * (pt1[0] - pt2[0]) + (pt1[1] - pt2[1]) * + (pt1[1] - pt2[1])) + edge2 = np.sqrt((pt2[0] - pt3[0]) * (pt2[0] - pt3[0]) + (pt2[1] - pt3[1]) * + (pt2[1] - pt3[1])) + + width = max(edge1, edge2) + height = min(edge1, edge2) + + rbox_angle = 0 + if edge1 > edge2: + rbox_angle = np.arctan2(float(pt2[1] - pt1[1]), float(pt2[0] - pt1[0])) + elif edge2 >= edge1: + rbox_angle = np.arctan2(float(pt4[1] - pt1[1]), float(pt4[0] - pt1[0])) + + rbox_angle = norm_angle(rbox_angle) + + x_ctr = float(pt1[0] + pt3[0]) / 2 + y_ctr = float(pt1[1] + pt3[1]) / 2 + return x_ctr, y_ctr, width, height, rbox_angle + + +def poly2rbox_oc(poly): + """convert poly to rbox (0, pi / 2] + + Args: + poly: [x1, y1, x2, y2, x3, y3, x4, y4] + + Returns: + rbox: [cx, cy, w, h, angle] + """ + points = np.array(poly, dtype=np.float32).reshape((-1, 2)) + (cx, cy), (w, h), angle = cv2.minAreaRect(points) + # using the new OpenCV Rotated BBox definition since 4.5.1 + # if angle < 0, opencv is older than 4.5.1, angle is in [-90, 0) + if angle < 0: + angle += 90 + w, h = h, w + + # convert angle to [0, 90) + if angle == -0.0: + angle = 0.0 + if angle == 90.0: + angle = 0.0 + w, h = h, w + + angle = angle / 180 * np.pi + return cx, cy, w, h, angle diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py index 09a87b128..ec4ef2dc9 100644 --- a/ppdet/data/transform/operators.py +++ b/ppdet/data/transform/operators.py @@ -41,7 +41,6 @@ import threading MUTEX = threading.Lock() from ppdet.core.workspace import serializable -from ppdet.modeling import bbox_utils from ..reader import Compose from .op_helper import (satisfy_sample_constraint, filter_and_process, @@ -657,18 +656,6 @@ class RandomFlip(BaseOperator): bbox[:, 2] = width - oldx1 return bbox - def apply_rbox(self, bbox, width): - oldx1 = bbox[:, 0].copy() - oldx2 = bbox[:, 2].copy() - oldx3 = bbox[:, 4].copy() - oldx4 = bbox[:, 6].copy() - bbox[:, 0] = width - oldx1 - bbox[:, 2] = width - oldx2 - bbox[:, 4] = width - oldx3 - bbox[:, 6] = width - oldx4 - bbox = [bbox_utils.get_best_begin_point_single(e) for e in bbox] - return bbox - def apply(self, sample, context=None): """Filp the image and bounding box. Operators: @@ -700,10 +687,6 @@ class RandomFlip(BaseOperator): if 'gt_segm' in sample and sample['gt_segm'].any(): sample['gt_segm'] = sample['gt_segm'][:, :, ::-1] - if 'gt_rbox2poly' in sample and sample['gt_rbox2poly'].any(): - sample['gt_rbox2poly'] = self.apply_rbox(sample['gt_rbox2poly'], - width) - sample['flipped'] = True sample['image'] = im return sample @@ -841,16 +824,6 @@ class Resize(BaseOperator): [im_scale_x, im_scale_y], [resize_w, resize_h]) - # apply rbox - if 'gt_rbox2poly' in sample: - if np.array(sample['gt_rbox2poly']).shape[1] != 8: - logger.warning( - "gt_rbox2poly's length shoule be 8, but actually is {}". - format(len(sample['gt_rbox2poly']))) - sample['gt_rbox2poly'] = self.apply_bbox(sample['gt_rbox2poly'], - [im_scale_x, im_scale_y], - [resize_w, resize_h]) - # apply polygon if 'gt_poly' in sample and len(sample['gt_poly']) > 0: sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_shape[:2], @@ -2111,33 +2084,6 @@ class Poly2Mask(BaseOperator): return sample -@register_op -class Rbox2Poly(BaseOperator): - """ - Convert rbbox format to poly format. - """ - - def __init__(self): - super(Rbox2Poly, self).__init__() - - def apply(self, sample, context=None): - assert 'gt_rbox' in sample - assert sample['gt_rbox'].shape[1] == 5 - rrects = sample['gt_rbox'] - x_ctr = rrects[:, 0] - y_ctr = rrects[:, 1] - width = rrects[:, 2] - height = rrects[:, 3] - x1 = x_ctr - width / 2.0 - y1 = y_ctr - height / 2.0 - x2 = x_ctr + width / 2.0 - y2 = y_ctr + height / 2.0 - sample['gt_bbox'] = np.stack([x1, y1, x2, y2], axis=1) - polys = bbox_utils.rbox2poly_np(rrects) - sample['gt_rbox2poly'] = polys - return sample - - @register_op class AugmentHSV(BaseOperator): """ @@ -2456,16 +2402,6 @@ class RandomResizeCrop(BaseOperator): [im_scale_x, im_scale_y], [resize_w, resize_h]) - # apply rbox - if 'gt_rbox2poly' in sample: - if np.array(sample['gt_rbox2poly']).shape[1] != 8: - logger.warn( - "gt_rbox2poly's length shoule be 8, but actually is {}". - format(len(sample['gt_rbox2poly']))) - sample['gt_rbox2poly'] = self.apply_bbox(sample['gt_rbox2poly'], - [im_scale_x, im_scale_y], - [resize_w, resize_h]) - # apply polygon if 'gt_poly' in sample and len(sample['gt_poly']) > 0: sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_shape[:2], diff --git a/ppdet/data/transform/rotated_operators.py b/ppdet/data/transform/rotated_operators.py new file mode 100644 index 000000000..ede34d639 --- /dev/null +++ b/ppdet/data/transform/rotated_operators.py @@ -0,0 +1,487 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +try: + from collections.abc import Sequence +except Exception: + from collections import Sequence + +from numbers import Number, Integral + +import cv2 +import numpy as np +import math +import copy + +from .operators import register_op, BaseOperator +from .op_helper import poly2rbox_le135, poly2rbox_oc +from ppdet.modeling import bbox_utils +from ppdet.utils.logger import setup_logger +logger = setup_logger(__name__) + + +@register_op +class RRotate(BaseOperator): + """ Rotate Image, Polygon, Box + + Args: + scale (float): rotate scale + angle (float): rotate angle + fill_value (int, tuple): fill color + auto_bound (bool): whether auto bound or not + """ + + def __init__(self, scale=1.0, angle=0., fill_value=0., auto_bound=True): + super(RRotate, self).__init__() + self.scale = scale + self.angle = angle + self.fill_value = fill_value + self.auto_bound = auto_bound + + def get_rotated_matrix(self, angle, scale, h, w): + center = ((w - 1) * 0.5, (h - 1) * 0.5) + matrix = cv2.getRotationMatrix2D(center, -angle, scale) + # calculate the new size + cos = np.abs(matrix[0, 0]) + sin = np.abs(matrix[0, 1]) + new_w = h * sin + w * cos + new_h = h * cos + w * sin + # calculate offset + n_w = int(np.round(new_w)) + n_h = int(np.round(new_h)) + if self.auto_bound: + ratio = min(w / n_w, h / n_h) + matrix = cv2.getRotationMatrix2D(center, -angle, ratio) + else: + matrix[0, 2] += (new_w - w) * 0.5 + matrix[1, 2] += (new_h - h) * 0.5 + w = n_w + h = n_h + return matrix, h, w + + def get_rect_from_pts(self, pts, h, w): + """ get minimum rectangle of points + """ + assert pts.shape[-1] % 2 == 0, 'the dim of input [pts] is not correct' + min_x, min_y = np.min(pts[:, 0::2], axis=1), np.min(pts[:, 1::2], + axis=1) + max_x, max_y = np.max(pts[:, 0::2], axis=1), np.max(pts[:, 1::2], + axis=1) + min_x, min_y = np.clip(min_x, 0, w), np.clip(min_y, 0, h) + max_x, max_y = np.clip(max_x, 0, w), np.clip(max_y, 0, h) + boxes = np.stack([min_x, min_y, max_x, max_y], axis=-1) + return boxes + + def apply_image(self, image, matrix, h, w): + return cv2.warpAffine( + image, matrix, (w, h), borderValue=self.fill_value) + + def apply_pts(self, pts, matrix, h, w): + assert pts.shape[-1] % 2 == 0, 'the dim of input [pts] is not correct' + # n is number of samples and m is two times the number of points due to (x, y) + _, m = pts.shape + # transpose points + pts_ = pts.reshape(-1, 2).T + # pad 1 to convert the points to homogeneous coordinates + padding = np.ones((1, pts_.shape[1]), pts.dtype) + rotated_pts = np.matmul(matrix, np.concatenate((pts_, padding), axis=0)) + return rotated_pts[:2, :].T.reshape(-1, m) + + def apply(self, sample, context=None): + image = sample['image'] + h, w = image.shape[:2] + matrix, h, w = self.get_rotated_matrix(self.angle, self.scale, h, w) + sample['image'] = self.apply_image(image, matrix, h, w) + polys = sample['gt_poly'] + # TODO: segment or keypoint to be processed + if len(polys) > 0: + pts = self.apply_pts(polys, matrix, h, w) + sample['gt_poly'] = pts + sample['gt_bbox'] = self.get_rect_from_pts(pts, h, w) + + return sample + + +@register_op +class RandomRRotate(BaseOperator): + """ Random Rotate Image + Args: + scale (float, tuple, list): rotate scale + scale_mode (str): mode of scale, [range, value, None] + angle (float, tuple, list): rotate angle + angle_mode (str): mode of angle, [range, value, None] + fill_value (float, tuple, list): fill value + rotate_prob (float): probability of rotation + auto_bound (bool): whether auto bound or not + """ + + def __init__(self, + scale=1.0, + scale_mode=None, + angle=0., + angle_mode=None, + fill_value=0., + rotate_prob=1.0, + auto_bound=True): + super(RandomRRotate, self).__init__() + self.scale = scale + self.scale_mode = scale_mode + self.angle = angle + self.angle_mode = angle_mode + self.fill_value = fill_value + self.rotate_prob = rotate_prob + self.auto_bound = auto_bound + + def get_angle(self, angle, angle_mode): + assert not angle_mode or angle_mode in [ + 'range', 'value' + ], 'angle mode should be in [range, value, None]' + if not angle_mode: + return angle + elif angle_mode == 'range': + low, high = angle + return np.random.rand() * (high - low) + low + elif angle_mode == 'value': + return np.random.choice(angle) + + def get_scale(self, scale, scale_mode): + assert not scale_mode or scale_mode in [ + 'range', 'value' + ], 'scale mode should be in [range, value, None]' + if not scale_mode: + return scale + elif scale_mode == 'range': + low, high = scale + return np.random.rand() * (high - low) + low + elif scale_mode == 'value': + return np.random.choice(scale) + + def apply(self, sample, context=None): + if np.random.rand() > self.rotate_prob: + return sample + + angle = self.get_angle(self.angle, self.angle_mode) + scale = self.get_scale(self.scale, self.scale_mode) + rotator = RRotate(scale, angle, self.fill_value, self.auto_bound) + return rotator(sample) + + +@register_op +class Poly2RBox(BaseOperator): + """ Polygon to Rotated Box, using new OpenCV definition since 4.5.1 + + Args: + filter_threshold (int, float): threshold to filter annotations + filter_mode (str): filter mode, ['area', 'edge'] + rbox_type (str): rbox type, ['le135', 'oc'] + + """ + + def __init__(self, filter_threshold=4, filter_mode=None, rbox_type='le135'): + super(Poly2RBox, self).__init__() + self.filter_fn = lambda size: self.filter(size, filter_threshold, filter_mode) + self.rbox_fn = poly2rbox_le135 if rbox_type == 'le135' else poly2rbox_oc + + def filter(self, size, threshold, mode): + if mode == 'area': + if size[0] * size[1] < threshold: + return True + elif mode == 'edge': + if min(size) < threshold: + return True + return False + + def get_rbox(self, polys): + valid_ids, rboxes, bboxes = [], [], [] + for i, poly in enumerate(polys): + cx, cy, w, h, angle = self.rbox_fn(poly) + if self.filter_fn((w, h)): + continue + rboxes.append(np.array([cx, cy, w, h, angle], dtype=np.float32)) + valid_ids.append(i) + xmin, ymin = min(poly[0::2]), min(poly[1::2]) + xmax, ymax = max(poly[0::2]), max(poly[1::2]) + bboxes.append(np.array([xmin, ymin, xmax, ymax], dtype=np.float32)) + + if len(valid_ids) == 0: + rboxes = np.zeros((0, 5), dtype=np.float32) + bboxes = np.zeros((0, 4), dtype=np.float32) + else: + rboxes = np.stack(rboxes) + bboxes = np.stack(bboxes) + + return rboxes, bboxes, valid_ids + + def apply(self, sample, context=None): + rboxes, bboxes, valid_ids = self.get_rbox(sample['gt_poly']) + sample['gt_rbox'] = rboxes + sample['gt_bbox'] = bboxes + for k in ['gt_class', 'gt_score', 'gt_poly', 'is_crowd', 'difficult']: + if k in sample: + sample[k] = sample[k][valid_ids] + + return sample + + +@register_op +class Poly2Array(BaseOperator): + """ convert gt_poly to np.array for rotated bboxes + """ + + def __init__(self): + super(Poly2Array, self).__init__() + + def apply(self, sample, context=None): + if 'gt_poly' in sample: + logger.info('gt_poly shape: {}'.format(sample['gt_poly'])) + sample['gt_poly'] = np.array( + sample['gt_poly'], dtype=np.float32).reshape((-1, 8)) + + return sample + + +@register_op +class RResize(BaseOperator): + def __init__(self, target_size, keep_ratio, interp=cv2.INTER_LINEAR): + """ + Resize image to target size. if keep_ratio is True, + resize the image's long side to the maximum of target_size + if keep_ratio is False, resize the image to target size(h, w) + Args: + target_size (int|list): image target size + keep_ratio (bool): whether keep_ratio or not, default true + interp (int): the interpolation method + """ + super(RResize, self).__init__() + self.keep_ratio = keep_ratio + self.interp = interp + if not isinstance(target_size, (Integral, Sequence)): + raise TypeError( + "Type of target_size is invalid. Must be Integer or List or Tuple, now is {}". + format(type(target_size))) + if isinstance(target_size, Integral): + target_size = [target_size, target_size] + self.target_size = target_size + + def apply_image(self, image, scale): + im_scale_x, im_scale_y = scale + + return cv2.resize( + image, + None, + None, + fx=im_scale_x, + fy=im_scale_y, + interpolation=self.interp) + + def apply_pts(self, pts, scale, size): + im_scale_x, im_scale_y = scale + resize_w, resize_h = size + pts[:, 0::2] *= im_scale_x + pts[:, 1::2] *= im_scale_y + pts[:, 0::2] = np.clip(pts[:, 0::2], 0, resize_w) + pts[:, 1::2] = np.clip(pts[:, 1::2], 0, resize_h) + return pts + + def apply(self, sample, context=None): + """ Resize the image numpy. + """ + im = sample['image'] + if not isinstance(im, np.ndarray): + raise TypeError("{}: image type is not numpy.".format(self)) + if len(im.shape) != 3: + raise ImageError('{}: image is not 3-dimensional.'.format(self)) + + # apply image + im_shape = im.shape + if self.keep_ratio: + + im_size_min = np.min(im_shape[0:2]) + im_size_max = np.max(im_shape[0:2]) + + target_size_min = np.min(self.target_size) + target_size_max = np.max(self.target_size) + + im_scale = min(target_size_min / im_size_min, + target_size_max / im_size_max) + + resize_h = im_scale * float(im_shape[0]) + resize_w = im_scale * float(im_shape[1]) + + im_scale_x = im_scale + im_scale_y = im_scale + else: + resize_h, resize_w = self.target_size + im_scale_y = resize_h / im_shape[0] + im_scale_x = resize_w / im_shape[1] + + im = self.apply_image(sample['image'], [im_scale_x, im_scale_y]) + sample['image'] = im.astype(np.float32) + sample['im_shape'] = np.asarray([resize_h, resize_w], dtype=np.float32) + if 'scale_factor' in sample: + scale_factor = sample['scale_factor'] + sample['scale_factor'] = np.asarray( + [scale_factor[0] * im_scale_y, scale_factor[1] * im_scale_x], + dtype=np.float32) + else: + sample['scale_factor'] = np.asarray( + [im_scale_y, im_scale_x], dtype=np.float32) + + # apply bbox + if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0: + sample['gt_bbox'] = self.apply_pts(sample['gt_bbox'], + [im_scale_x, im_scale_y], + [resize_w, resize_h]) + + # apply polygon + if 'gt_poly' in sample and len(sample['gt_poly']) > 0: + sample['gt_poly'] = self.apply_pts(sample['gt_poly'], + [im_scale_x, im_scale_y], + [resize_w, resize_h]) + + return sample + + +@register_op +class RandomRFlip(BaseOperator): + def __init__(self, prob=0.5): + """ + Args: + prob (float): the probability of flipping image + """ + super(RandomRFlip, self).__init__() + self.prob = prob + if not (isinstance(self.prob, float)): + raise TypeError("{}: input type is invalid.".format(self)) + + def apply_image(self, image): + return image[:, ::-1, :] + + def apply_pts(self, pts, width): + oldx = pts[:, 0::2].copy() + pts[:, 0::2] = width - oldx - 1 + return pts + + def apply(self, sample, context=None): + """Filp the image and bounding box. + Operators: + 1. Flip the image numpy. + 2. Transform the bboxes' x coordinates. + (Must judge whether the coordinates are normalized!) + 3. Transform the segmentations' x coordinates. + (Must judge whether the coordinates are normalized!) + Output: + sample: the image, bounding box and segmentation part + in sample are flipped. + """ + if np.random.uniform(0, 1) < self.prob: + im = sample['image'] + height, width = im.shape[:2] + im = self.apply_image(im) + if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0: + sample['gt_bbox'] = self.apply_pts(sample['gt_bbox'], width) + if 'gt_poly' in sample and len(sample['gt_poly']) > 0: + sample['gt_poly'] = self.apply_pts(sample['gt_poly'], width) + + sample['flipped'] = True + sample['image'] = im + return sample + + +@register_op +class VisibleRBox(BaseOperator): + """ + In debug mode, visualize images according to `gt_box`. + (Currently only supported when not cropping and flipping image.) + """ + + def __init__(self, output_dir='debug'): + super(VisibleRBox, self).__init__() + self.output_dir = output_dir + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + + def apply(self, sample, context=None): + image = Image.fromarray(sample['image'].astype(np.uint8)) + out_file_name = '{:012d}.jpg'.format(sample['im_id'][0]) + width = sample['w'] + height = sample['h'] + # gt_poly = sample['gt_rbox'] + gt_poly = sample['gt_poly'] + gt_class = sample['gt_class'] + draw = ImageDraw.Draw(image) + for i in range(gt_poly.shape[0]): + x1, y1, x2, y2, x3, y3, x4, y4 = gt_poly[i] + draw.line( + [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x1, y1)], + width=2, + fill='green') + # draw label + xmin = min(x1, x2, x3, x4) + ymin = min(y1, y2, y3, y4) + text = str(gt_class[i][0]) + tw, th = draw.textsize(text) + draw.rectangle( + [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill='green') + draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255)) + + if 'gt_keypoint' in sample.keys(): + gt_keypoint = sample['gt_keypoint'] + if self.is_normalized: + for i in range(gt_keypoint.shape[1]): + if i % 2: + gt_keypoint[:, i] = gt_keypoint[:, i] * height + else: + gt_keypoint[:, i] = gt_keypoint[:, i] * width + for i in range(gt_keypoint.shape[0]): + keypoint = gt_keypoint[i] + for j in range(int(keypoint.shape[0] / 2)): + x1 = round(keypoint[2 * j]).astype(np.int32) + y1 = round(keypoint[2 * j + 1]).astype(np.int32) + draw.ellipse( + (x1, y1, x1 + 5, y1 + 5), fill='green', outline='green') + save_path = os.path.join(self.output_dir, out_file_name) + image.save(save_path, quality=95) + return sample + + +@register_op +class Rbox2Poly(BaseOperator): + """ + Convert rbbox format to poly format. + """ + + def __init__(self): + super(Rbox2Poly, self).__init__() + + def apply(self, sample, context=None): + assert 'gt_rbox' in sample + assert sample['gt_rbox'].shape[1] == 5 + rrects = sample['gt_rbox'] + x_ctr = rrects[:, 0] + y_ctr = rrects[:, 1] + width = rrects[:, 2] + height = rrects[:, 3] + x1 = x_ctr - width / 2.0 + y1 = y_ctr - height / 2.0 + x2 = x_ctr + width / 2.0 + y2 = y_ctr + height / 2.0 + sample['gt_bbox'] = np.stack([x1, y1, x2, y2], axis=1) + polys = bbox_utils.rbox2poly_np(rrects) + sample['gt_poly'] = polys + return sample diff --git a/ppdet/metrics/map_utils.py b/ppdet/metrics/map_utils.py index 12fb9ba51..534c3b490 100644 --- a/ppdet/metrics/map_utils.py +++ b/ppdet/metrics/map_utils.py @@ -138,8 +138,7 @@ def calc_rbox_iou(pred, gt_rbox): def prune_zero_padding(gt_box, gt_label, difficult=None): valid_cnt = 0 for i in range(len(gt_box)): - if gt_box[i, 0] == 0 and gt_box[i, 1] == 0 and \ - gt_box[i, 2] == 0 and gt_box[i, 3] == 0: + if (gt_box[i] == 0).all(): break valid_cnt += 1 return (gt_box[:valid_cnt], gt_label[:valid_cnt], difficult[:valid_cnt] @@ -331,8 +330,9 @@ class DetectionMAP(object): num_columns = min(6, len(results_per_category) * 2) results_flatten = list(itertools.chain(*results_per_category)) headers = ['category', 'AP'] * (num_columns // 2) - results_2d = itertools.zip_longest( - *[results_flatten[i::num_columns] for i in range(num_columns)]) + results_2d = itertools.zip_longest(* [ + results_flatten[i::num_columns] for i in range(num_columns) + ]) table_data = [headers] table_data += [result for result in results_2d] table = AsciiTable(table_data) diff --git a/ppdet/metrics/metrics.py b/ppdet/metrics/metrics.py index b20a569a0..ace0944e3 100644 --- a/ppdet/metrics/metrics.py +++ b/ppdet/metrics/metrics.py @@ -347,22 +347,12 @@ class WiderFaceMetric(Metric): class RBoxMetric(Metric): def __init__(self, anno_file, **kwargs): - assert os.path.isfile(anno_file), \ - "anno_file {} not a file".format(anno_file) - assert os.path.exists(anno_file), "anno_file {} not exists".format( - anno_file) self.anno_file = anno_file - self.gt_anno = json.load(open(self.anno_file)) - cats = self.gt_anno['categories'] - self.clsid2catid = {i: cat['id'] for i, cat in enumerate(cats)} - self.catid2clsid = {cat['id']: i for i, cat in enumerate(cats)} - self.catid2name = {cat['id']: cat['name'] for cat in cats} + self.clsid2catid, self.catid2name = get_categories('COCO', anno_file) + self.catid2clsid = {v: k for k, v in self.clsid2catid.items()} self.classwise = kwargs.get('classwise', False) self.output_eval = kwargs.get('output_eval', None) - # TODO: bias should be unified - self.bias = kwargs.get('bias', 0) self.save_prediction_only = kwargs.get('save_prediction_only', False) - self.iou_type = kwargs.get('IouType', 'bbox') self.overlap_thresh = kwargs.get('overlap_thresh', 0.5) self.map_type = kwargs.get('map_type', '11point') self.evaluate_difficult = kwargs.get('evaluate_difficult', False) @@ -379,7 +369,7 @@ class RBoxMetric(Metric): self.reset() def reset(self): - self.result_bbox = [] + self.results = [] self.detection_map.reset() def update(self, inputs, outputs): @@ -389,35 +379,45 @@ class RBoxMetric(Metric): outs[k] = v.numpy() if isinstance(v, paddle.Tensor) else v im_id = inputs['im_id'] - outs['im_id'] = im_id.numpy() if isinstance(im_id, - paddle.Tensor) else im_id + im_id = im_id.numpy() if isinstance(im_id, paddle.Tensor) else im_id + outs['im_id'] = im_id - infer_results = get_infer_results( - outs, self.clsid2catid, bias=self.bias) - self.result_bbox += infer_results[ - 'bbox'] if 'bbox' in infer_results else [] - bbox = [b['bbox'] for b in self.result_bbox] - score = [b['score'] for b in self.result_bbox] - label = [b['category_id'] for b in self.result_bbox] - label = [self.catid2clsid[e] for e in label] - gt_box = [ - e['bbox'] for e in self.gt_anno['annotations'] - if e['image_id'] == outs['im_id'] - ] - gt_label = [ - e['category_id'] for e in self.gt_anno['annotations'] - if e['image_id'] == outs['im_id'] - ] - gt_label = [self.catid2clsid[e] for e in gt_label] - self.detection_map.update(bbox, score, label, gt_box, gt_label) + infer_results = get_infer_results(outs, self.clsid2catid) + infer_results = infer_results['bbox'] if 'bbox' in infer_results else [] + self.results += infer_results + if self.save_prediction_only: + return + + gt_boxes = inputs['gt_rbox'] + gt_labels = inputs['gt_class'] + for i in range(len(gt_boxes)): + gt_box = gt_boxes[i].numpy() if isinstance( + gt_boxes[i], paddle.Tensor) else gt_boxes[i] + gt_label = gt_labels[i].numpy() if isinstance( + gt_labels[i], paddle.Tensor) else gt_labels[i] + gt_box, gt_label, _ = prune_zero_padding(gt_box, gt_label) + bbox = [ + res['bbox'] for res in infer_results + if int(res['image_id']) == int(im_id[i]) + ] + score = [ + res['score'] for res in infer_results + if int(res['image_id']) == int(im_id[i]) + ] + label = [ + self.catid2clsid[int(res['category_id'])] + for res in infer_results + if int(res['image_id']) == int(im_id[i]) + ] + self.detection_map.update(bbox, score, label, gt_box, gt_label) def accumulate(self): - if len(self.result_bbox) > 0: + if len(self.results) > 0: output = "bbox.json" if self.output_eval: output = os.path.join(self.output_eval, output) with open(output, 'w') as f: - json.dump(self.result_bbox, f) + json.dump(self.results, f) logger.info('The bbox result is saved to bbox.json.') if self.save_prediction_only: diff --git a/ppdet/modeling/architectures/s2anet.py b/ppdet/modeling/architectures/s2anet.py index ecfc987f9..8fb71e205 100644 --- a/ppdet/modeling/architectures/s2anet.py +++ b/ppdet/modeling/architectures/s2anet.py @@ -26,26 +26,21 @@ __all__ = ['S2ANet'] @register class S2ANet(BaseArch): __category__ = 'architecture' - __inject__ = [ - 's2anet_head', - 's2anet_bbox_post_process', - ] + __inject__ = ['head'] - def __init__(self, backbone, neck, s2anet_head, s2anet_bbox_post_process): + def __init__(self, backbone, neck, head): """ S2ANet, see https://arxiv.org/pdf/2008.09397.pdf Args: backbone (object): backbone instance neck (object): `FPN` instance - s2anet_head (object): `S2ANetHead` instance - s2anet_bbox_post_process (object): `S2ANetBBoxPostProcess` instance + head (object): `Head` instance """ super(S2ANet, self).__init__() self.backbone = backbone self.neck = neck - self.s2anet_head = s2anet_head - self.s2anet_bbox_post_process = s2anet_bbox_post_process + self.s2anet_head = head @classmethod def from_config(cls, cfg, *args, **kwargs): @@ -55,42 +50,28 @@ class S2ANet(BaseArch): out_shape = neck and neck.out_shape or backbone.out_shape kwargs = {'input_shape': out_shape} - s2anet_head = create(cfg['s2anet_head'], **kwargs) - s2anet_bbox_post_process = create(cfg['s2anet_bbox_post_process'], - **kwargs) + head = create(cfg['head'], **kwargs) - return { - 'backbone': backbone, - 'neck': neck, - "s2anet_head": s2anet_head, - "s2anet_bbox_post_process": s2anet_bbox_post_process, - } + return {'backbone': backbone, 'neck': neck, "head": head} def _forward(self): body_feats = self.backbone(self.inputs) if self.neck is not None: body_feats = self.neck(body_feats) - self.s2anet_head(body_feats) if self.training: - loss = self.s2anet_head.get_loss(self.inputs) - total_loss = paddle.add_n(list(loss.values())) - loss.update({'loss': total_loss}) + loss = self.s2anet_head(body_feats, self.inputs) return loss else: - im_shape = self.inputs['im_shape'] - scale_factor = self.inputs['scale_factor'] - nms_pre = self.s2anet_bbox_post_process.nms_pre - pred_scores, pred_bboxes = self.s2anet_head.get_prediction(nms_pre) - + head_outs = self.s2anet_head(body_feats) # post_process - pred_bboxes, bbox_num = self.s2anet_bbox_post_process(pred_scores, - pred_bboxes) + bboxes, bbox_num = self.s2anet_head.get_bboxes(head_outs) # rescale the prediction back to origin image - pred_bboxes = self.s2anet_bbox_post_process.get_pred( - pred_bboxes, bbox_num, im_shape, scale_factor) - + im_shape = self.inputs['im_shape'] + scale_factor = self.inputs['scale_factor'] + bboxes = self.s2anet_head.get_pred(bboxes, bbox_num, im_shape, + scale_factor) # output - output = {'bbox': pred_bboxes, 'bbox_num': bbox_num} + output = {'bbox': bboxes, 'bbox_num': bbox_num} return output def get_loss(self, ): diff --git a/ppdet/modeling/heads/s2anet_head.py b/ppdet/modeling/heads/s2anet_head.py index e17023d67..53b16f5af 100644 --- a/ppdet/modeling/heads/s2anet_head.py +++ b/ppdet/modeling/heads/s2anet_head.py @@ -20,182 +20,14 @@ import paddle.nn as nn import paddle.nn.functional as F from paddle.nn.initializer import Normal, Constant from ppdet.core.workspace import register -from ppdet.modeling import ops -from ppdet.modeling import bbox_utils +from ppdet.modeling.bbox_utils import rbox2poly from ppdet.modeling.proposal_generator.target_layer import RBoxAssigner +from ppdet.modeling.proposal_generator.anchor_generator import S2ANetAnchorGenerator +from ppdet.modeling.layers import AlignConv from ..cls_utils import _get_class_default_kwargs import numpy as np -class S2ANetAnchorGenerator(nn.Layer): - """ - AnchorGenerator by paddle - """ - - def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None): - super(S2ANetAnchorGenerator, self).__init__() - self.base_size = base_size - self.scales = paddle.to_tensor(scales) - self.ratios = paddle.to_tensor(ratios) - self.scale_major = scale_major - self.ctr = ctr - self.base_anchors = self.gen_base_anchors() - - @property - def num_base_anchors(self): - return self.base_anchors.shape[0] - - def gen_base_anchors(self): - w = self.base_size - h = self.base_size - if self.ctr is None: - x_ctr = 0.5 * (w - 1) - y_ctr = 0.5 * (h - 1) - else: - x_ctr, y_ctr = self.ctr - - h_ratios = paddle.sqrt(self.ratios) - w_ratios = 1 / h_ratios - if self.scale_major: - ws = (w * w_ratios[:] * self.scales[:]).reshape([-1]) - hs = (h * h_ratios[:] * self.scales[:]).reshape([-1]) - else: - ws = (w * self.scales[:] * w_ratios[:]).reshape([-1]) - hs = (h * self.scales[:] * h_ratios[:]).reshape([-1]) - - base_anchors = paddle.stack( - [ - x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1), - x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1) - ], - axis=-1) - base_anchors = paddle.round(base_anchors) - return base_anchors - - def _meshgrid(self, x, y, row_major=True): - yy, xx = paddle.meshgrid(y, x) - yy = yy.reshape([-1]) - xx = xx.reshape([-1]) - if row_major: - return xx, yy - else: - return yy, xx - - def forward(self, featmap_size, stride=16): - # featmap_size*stride project it to original area - - feat_h = featmap_size[0] - feat_w = featmap_size[1] - shift_x = paddle.arange(0, feat_w, 1, 'int32') * stride - shift_y = paddle.arange(0, feat_h, 1, 'int32') * stride - shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) - shifts = paddle.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1) - - all_anchors = self.base_anchors[:, :] + shifts[:, :] - all_anchors = all_anchors.reshape([feat_h * feat_w, 4]) - return all_anchors - - def valid_flags(self, featmap_size, valid_size): - feat_h, feat_w = featmap_size - valid_h, valid_w = valid_size - assert valid_h <= feat_h and valid_w <= feat_w - valid_x = paddle.zeros([feat_w], dtype='int32') - valid_y = paddle.zeros([feat_h], dtype='int32') - valid_x[:valid_w] = 1 - valid_y[:valid_h] = 1 - valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) - valid = valid_xx & valid_yy - valid = paddle.reshape(valid, [-1, 1]) - valid = paddle.expand(valid, [-1, self.num_base_anchors]).reshape([-1]) - return valid - - -class AlignConv(nn.Layer): - def __init__(self, in_channels, out_channels, kernel_size=3, groups=1): - super(AlignConv, self).__init__() - self.kernel_size = kernel_size - self.align_conv = paddle.vision.ops.DeformConv2D( - in_channels, - out_channels, - kernel_size=self.kernel_size, - padding=(self.kernel_size - 1) // 2, - groups=groups, - weight_attr=ParamAttr(initializer=Normal(0, 0.01)), - bias_attr=None) - - @paddle.no_grad() - def get_offset(self, anchors, featmap_size, stride): - """ - Args: - anchors: [M,5] xc,yc,w,h,angle - featmap_size: (feat_h, feat_w) - stride: 8 - Returns: - - """ - anchors = paddle.reshape(anchors, [-1, 5]) # (NA,5) - dtype = anchors.dtype - feat_h = featmap_size[0] - feat_w = featmap_size[1] - pad = (self.kernel_size - 1) // 2 - idx = paddle.arange(-pad, pad + 1, dtype=dtype) - - yy, xx = paddle.meshgrid(idx, idx) - xx = paddle.reshape(xx, [-1]) - yy = paddle.reshape(yy, [-1]) - - # get sampling locations of default conv - xc = paddle.arange(0, feat_w, dtype=dtype) - yc = paddle.arange(0, feat_h, dtype=dtype) - yc, xc = paddle.meshgrid(yc, xc) - - xc = paddle.reshape(xc, [-1, 1]) - yc = paddle.reshape(yc, [-1, 1]) - x_conv = xc + xx - y_conv = yc + yy - - # get sampling locations of anchors - # x_ctr, y_ctr, w, h, a = np.unbind(anchors, dim=1) - x_ctr = anchors[:, 0] - y_ctr = anchors[:, 1] - w = anchors[:, 2] - h = anchors[:, 3] - a = anchors[:, 4] - - x_ctr = paddle.reshape(x_ctr, [-1, 1]) - y_ctr = paddle.reshape(y_ctr, [-1, 1]) - w = paddle.reshape(w, [-1, 1]) - h = paddle.reshape(h, [-1, 1]) - a = paddle.reshape(a, [-1, 1]) - - x_ctr = x_ctr / stride - y_ctr = y_ctr / stride - w_s = w / stride - h_s = h / stride - cos, sin = paddle.cos(a), paddle.sin(a) - dw, dh = w_s / self.kernel_size, h_s / self.kernel_size - x, y = dw * xx, dh * yy - xr = cos * x - sin * y - yr = sin * x + cos * y - x_anchor, y_anchor = xr + x_ctr, yr + y_ctr - # get offset filed - offset_x = x_anchor - x_conv - offset_y = y_anchor - y_conv - offset = paddle.stack([offset_y, offset_x], axis=-1) - offset = paddle.reshape( - offset, [feat_h * feat_w, self.kernel_size * self.kernel_size * 2]) - offset = paddle.transpose(offset, [1, 0]) - offset = paddle.reshape( - offset, - [1, self.kernel_size * self.kernel_size * 2, feat_h, feat_w]) - return offset - - def forward(self, x, refine_anchors, featmap_size, stride): - offset = self.get_offset(refine_anchors, featmap_size, stride) - x = F.relu(self.align_conv(x, offset)) - return x - - @register class S2ANetHead(nn.Layer): """ @@ -216,7 +48,7 @@ class S2ANetHead(nn.Layer): reg_loss_weight (list): loss weight for regression """ __shared__ = ['num_classes'] - __inject__ = ['anchor_assign'] + __inject__ = ['anchor_assign', 'nms'] def __init__(self, stacked_convs=2, @@ -234,7 +66,9 @@ class S2ANetHead(nn.Layer): anchor_assign=_get_class_default_kwargs(RBoxAssigner), reg_loss_weight=[1.0, 1.0, 1.0, 1.0, 1.1], cls_loss_weight=[1.1, 1.05], - reg_loss_type='l1'): + reg_loss_type='l1', + nms_pre=2000, + nms='MultiClassNMS'): super(S2ANetHead, self).__init__() self.stacked_convs = stacked_convs self.feat_in = feat_in @@ -252,7 +86,7 @@ class S2ANetHead(nn.Layer): self.align_conv_size = align_conv_size self.use_sigmoid_cls = use_sigmoid_cls - self.cls_out_channels = num_classes if self.use_sigmoid_cls else 1 + self.cls_out_channels = num_classes if self.use_sigmoid_cls else num_classes + 1 self.sampling = False self.anchor_assign = anchor_assign self.reg_loss_weight = reg_loss_weight @@ -260,7 +94,13 @@ class S2ANetHead(nn.Layer): self.alpha = 1.0 self.beta = 1.0 self.reg_loss_type = reg_loss_type - self.s2anet_head_out = None + self.nms_pre = nms_pre + self.nms = nms + self.fake_bbox = paddle.to_tensor( + np.array( + [[-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], + dtype='float32')) + self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32')) # anchor self.anchor_generators = [] @@ -403,64 +243,49 @@ class S2ANetHead(nn.Layer): weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)), bias_attr=ParamAttr(initializer=Constant(0))) - self.featmap_sizes = [] - self.base_anchors_list = [] - self.refine_anchor_list = [] + def forward(self, feats, targets=None): + fam_reg_list, fam_cls_list = [], [] + odm_reg_list, odm_cls_list = [], [] + num_anchors_list, base_anchors_list, refine_anchors_list = [], [], [] - def forward(self, feats): - fam_reg_branch_list = [] - fam_cls_branch_list = [] + for i, feat in enumerate(feats): + # get shape + B = feat.shape[0] + H, W = paddle.shape(feat)[2], paddle.shape(feat)[3] - odm_reg_branch_list = [] - odm_cls_branch_list = [] + NA = H * W + num_anchors_list.append(NA) - self.featmap_sizes_list = [] - self.base_anchors_list = [] - self.refine_anchor_list = [] - - for feat_idx in range(len(feats)): - feat = feats[feat_idx] fam_cls_feat = self.fam_cls_convs(feat) - fam_cls = self.fam_cls(fam_cls_feat) # [N, CLS, H, W] --> [N, H, W, CLS] - fam_cls = fam_cls.transpose([0, 2, 3, 1]) - fam_cls_reshape = paddle.reshape( - fam_cls, [fam_cls.shape[0], -1, self.cls_out_channels]) - fam_cls_branch_list.append(fam_cls_reshape) + fam_cls = fam_cls.transpose([0, 2, 3, 1]).reshape( + [B, NA, self.cls_out_channels]) + fam_cls_list.append(fam_cls) fam_reg_feat = self.fam_reg_convs(feat) - fam_reg = self.fam_reg(fam_reg_feat) # [N, 5, H, W] --> [N, H, W, 5] - fam_reg = fam_reg.transpose([0, 2, 3, 1]) - fam_reg_reshape = paddle.reshape(fam_reg, [fam_reg.shape[0], -1, 5]) - fam_reg_branch_list.append(fam_reg_reshape) + fam_reg = fam_reg.transpose([0, 2, 3, 1]).reshape([B, NA, 5]) + fam_reg_list.append(fam_reg) # prepare anchor - featmap_size = (paddle.shape(feat)[2], paddle.shape(feat)[3]) - self.featmap_sizes_list.append(featmap_size) - init_anchors = self.anchor_generators[feat_idx]( - featmap_size, self.anchor_strides[feat_idx]) - - init_anchors = paddle.to_tensor(init_anchors, dtype='float32') - NA = featmap_size[0] * featmap_size[1] - init_anchors = paddle.reshape(init_anchors, [NA, 4]) - init_anchors = self.rect2rbox(init_anchors) - self.base_anchors_list.append(init_anchors) + init_anchors = self.anchor_generators[i]((H, W), + self.anchor_strides[i]) + init_anchors = init_anchors.reshape([1, NA, 5]) + base_anchors_list.append(init_anchors.squeeze(0)) if self.training: refine_anchor = self.bbox_decode(fam_reg.detach(), init_anchors) else: refine_anchor = self.bbox_decode(fam_reg, init_anchors) - self.refine_anchor_list.append(refine_anchor) + refine_anchors_list.append(refine_anchor) if self.align_conv_type == 'AlignConv': align_feat = self.align_conv(feat, - refine_anchor.clone(), - featmap_size, - self.anchor_strides[feat_idx]) + refine_anchor.clone(), (H, W), + self.anchor_strides[i]) elif self.align_conv_type == 'DCN': align_offset = self.align_conv_offset(feat) align_feat = self.align_conv(feat, align_offset) @@ -474,39 +299,140 @@ class S2ANetHead(nn.Layer): odm_reg_feat = self.odm_reg_convs(odm_reg_feat) odm_cls_feat = self.odm_cls_convs(odm_cls_feat) - odm_cls_score = self.odm_cls(odm_cls_feat) + odm_cls = self.odm_cls(odm_cls_feat) # [N, CLS, H, W] --> [N, H, W, CLS] - odm_cls_score = odm_cls_score.transpose([0, 2, 3, 1]) - odm_cls_score_shape = odm_cls_score.shape - odm_cls_score_reshape = paddle.reshape(odm_cls_score, [ - odm_cls_score_shape[0], odm_cls_score_shape[1] * - odm_cls_score_shape[2], self.cls_out_channels + odm_cls = odm_cls.transpose([0, 2, 3, 1]).reshape( + [B, NA, self.cls_out_channels]) + odm_cls_list.append(odm_cls) + + odm_reg = self.odm_reg(odm_reg_feat) + # [N, 5, H, W] --> [N, H, W, 5] + odm_reg = odm_reg.transpose([0, 2, 3, 1]).reshape([B, NA, 5]) + odm_reg_list.append(odm_reg) + + if self.training: + return self.get_loss([ + fam_cls_list, fam_reg_list, odm_cls_list, odm_reg_list, + num_anchors_list, base_anchors_list, refine_anchors_list + ], targets) + else: + odm_bboxes_list = [] + for odm_reg, refine_anchor in zip(odm_reg_list, + refine_anchors_list): + odm_bboxes = self.bbox_decode(odm_reg, refine_anchor) + odm_bboxes_list.append(odm_bboxes) + return [odm_bboxes_list, odm_cls_list] + + def get_bboxes(self, head_outs): + perd_bboxes_list, pred_scores_list = head_outs + batch = paddle.shape(pred_scores_list[0])[0] + bboxes, bbox_num = [], [] + for i in range(batch): + pred_scores_per_image = [t[i] for t in pred_scores_list] + pred_bboxes_per_image = [t[i] for t in perd_bboxes_list] + bbox_per_image, bbox_num_per_image = self.get_bboxes_single( + pred_scores_per_image, pred_bboxes_per_image) + bboxes.append(bbox_per_image) + bbox_num.append(bbox_num_per_image) + + bboxes = paddle.concat(bboxes) + bbox_num = paddle.concat(bbox_num) + return bboxes, bbox_num + + def get_pred(self, bboxes, bbox_num, im_shape, scale_factor): + """ + Rescale, clip and filter the bbox from the output of NMS to + get final prediction. + Args: + bboxes(Tensor): bboxes [N, 10] + bbox_num(Tensor): bbox_num + im_shape(Tensor): [1 2] + scale_factor(Tensor): [1 2] + Returns: + bbox_pred(Tensor): The output is the prediction with shape [N, 8] + including labels, scores and bboxes. The size of + bboxes are corresponding to the original image. + """ + origin_shape = paddle.floor(im_shape / scale_factor + 0.5) + + origin_shape_list = [] + scale_factor_list = [] + # scale_factor: scale_y, scale_x + for i in range(bbox_num.shape[0]): + expand_shape = paddle.expand(origin_shape[i:i + 1, :], + [bbox_num[i], 2]) + scale_y, scale_x = scale_factor[i][0], scale_factor[i][1] + scale = paddle.concat([ + scale_x, scale_y, scale_x, scale_y, scale_x, scale_y, scale_x, + scale_y ]) + expand_scale = paddle.expand(scale, [bbox_num[i], 8]) + origin_shape_list.append(expand_shape) + scale_factor_list.append(expand_scale) + + origin_shape_list = paddle.concat(origin_shape_list) + scale_factor_list = paddle.concat(scale_factor_list) + + # bboxes: [N, 10], label, score, bbox + pred_label_score = bboxes[:, 0:2] + pred_bbox = bboxes[:, 2:] + + # rescale bbox to original image + pred_bbox = pred_bbox.reshape([-1, 8]) + scaled_bbox = pred_bbox / scale_factor_list + origin_h = origin_shape_list[:, 0] + origin_w = origin_shape_list[:, 1] + + bboxes = scaled_bbox + zeros = paddle.zeros_like(origin_h) + x1 = paddle.maximum(paddle.minimum(bboxes[:, 0], origin_w - 1), zeros) + y1 = paddle.maximum(paddle.minimum(bboxes[:, 1], origin_h - 1), zeros) + x2 = paddle.maximum(paddle.minimum(bboxes[:, 2], origin_w - 1), zeros) + y2 = paddle.maximum(paddle.minimum(bboxes[:, 3], origin_h - 1), zeros) + x3 = paddle.maximum(paddle.minimum(bboxes[:, 4], origin_w - 1), zeros) + y3 = paddle.maximum(paddle.minimum(bboxes[:, 5], origin_h - 1), zeros) + x4 = paddle.maximum(paddle.minimum(bboxes[:, 6], origin_w - 1), zeros) + y4 = paddle.maximum(paddle.minimum(bboxes[:, 7], origin_h - 1), zeros) + pred_bbox = paddle.stack([x1, y1, x2, y2, x3, y3, x4, y4], axis=-1) + pred_result = paddle.concat([pred_label_score, pred_bbox], axis=1) + return pred_result + + def get_bboxes_single(self, cls_score_list, bbox_pred_list): + mlvl_bboxes = [] + mlvl_scores = [] - odm_cls_branch_list.append(odm_cls_score_reshape) + for cls_score, bbox_pred in zip(cls_score_list, bbox_pred_list): + if self.use_sigmoid_cls: + scores = F.sigmoid(cls_score) + else: + scores = F.softmax(cls_score, axis=-1) - odm_bbox_pred = self.odm_reg(odm_reg_feat) - # [N, 5, H, W] --> [N, H, W, 5] - odm_bbox_pred = odm_bbox_pred.transpose([0, 2, 3, 1]) - odm_bbox_pred_reshape = paddle.reshape(odm_bbox_pred, [-1, 5]) - odm_bbox_pred_reshape = paddle.unsqueeze( - odm_bbox_pred_reshape, axis=0) - odm_reg_branch_list.append(odm_bbox_pred_reshape) - - self.s2anet_head_out = (fam_cls_branch_list, fam_reg_branch_list, - odm_cls_branch_list, odm_reg_branch_list) - return self.s2anet_head_out - - def get_prediction(self, nms_pre=2000): - refine_anchors = self.refine_anchor_list - fam_cls_branch_list = self.s2anet_head_out[0] - fam_reg_branch_list = self.s2anet_head_out[1] - odm_cls_branch_list = self.s2anet_head_out[2] - odm_reg_branch_list = self.s2anet_head_out[3] - pred_scores, pred_bboxes = self.get_bboxes( - odm_cls_branch_list, odm_reg_branch_list, refine_anchors, nms_pre, - self.cls_out_channels, self.use_sigmoid_cls) - return pred_scores, pred_bboxes + if scores.shape[0] > self.nms_pre: + # Get maximum scores for foreground classes. + if self.use_sigmoid_cls: + max_scores = paddle.max(scores, axis=1) + else: + max_scores = paddle.max(scores[:, :-1], axis=1) + + topk_val, topk_inds = paddle.topk(max_scores, self.nms_pre) + bbox_pred = paddle.gather(bbox_pred, topk_inds) + scores = paddle.gather(scores, topk_inds) + + mlvl_bboxes.append(bbox_pred) + mlvl_scores.append(scores) + + mlvl_bboxes = paddle.concat(mlvl_bboxes) + mlvl_scores = paddle.concat(mlvl_scores) + + mlvl_polys = rbox2poly(mlvl_bboxes).unsqueeze(0) + mlvl_scores = paddle.transpose(mlvl_scores, [1, 0]).unsqueeze(0) + + bbox, bbox_num, _ = self.nms(mlvl_polys, mlvl_scores) + if bbox.shape[0] <= 0: + bbox = self.fake_bbox + bbox_num = self.fake_bbox_num + + return bbox, bbox_num def smooth_l1_loss(self, pred, label, delta=1.0 / 9.0): """ @@ -523,10 +449,10 @@ class S2ANetHead(nn.Layer): diff - 0.5 * delta) return loss - def get_fam_loss(self, fam_target, s2anet_head_out, reg_loss_type='gwd'): + def get_fam_loss(self, fam_target, s2anet_head_out, reg_loss_type='l1'): (labels, label_weights, bbox_targets, bbox_weights, bbox_gt_bboxes, pos_inds, neg_inds) = fam_target - fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out + fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list, num_anchors_list = s2anet_head_out fam_cls_losses = [] fam_bbox_losses = [] @@ -535,9 +461,7 @@ class S2ANetHead(nn.Layer): neg_inds) if self.sampling else len(pos_inds) num_total_samples = max(1, num_total_samples) - for idx, feat_size in enumerate(self.featmap_sizes_list): - feat_anchor_num = feat_size[0] * feat_size[1] - + for idx, feat_anchor_num in enumerate(num_anchors_list): # step1: get data feat_labels = labels[st_idx:st_idx + feat_anchor_num] feat_label_weights = label_weights[st_idx:st_idx + feat_anchor_num] @@ -594,39 +518,8 @@ class S2ANetHead(nn.Layer): feat_bbox_weights = paddle.to_tensor( feat_bbox_weights, stop_gradient=True) - if reg_loss_type == 'l1': - fam_bbox = fam_bbox * feat_bbox_weights - fam_bbox_total = paddle.sum(fam_bbox) / num_total_samples - elif reg_loss_type == 'iou' or reg_loss_type == 'gwd': - fam_bbox = paddle.sum(fam_bbox, axis=-1) - feat_bbox_weights = paddle.sum(feat_bbox_weights, axis=-1) - try: - from ext_op import rbox_iou - except Exception as e: - print("import custom_ops error, try install ext_op " \ - "following ppdet/ext_op/README.md", e) - sys.stdout.flush() - sys.exit(-1) - # calc iou - fam_bbox_decode = self.delta2rbox(self.base_anchors_list[idx], - fam_bbox_pred) - bbox_gt_bboxes = paddle.to_tensor( - bbox_gt_bboxes, - dtype=fam_bbox_decode.dtype, - place=fam_bbox_decode.place) - bbox_gt_bboxes.stop_gradient = True - iou = rbox_iou(fam_bbox_decode, bbox_gt_bboxes) - iou = paddle.diag(iou) - - if reg_loss_type == 'gwd': - bbox_gt_bboxes_level = bbox_gt_bboxes[st_idx:st_idx + - feat_anchor_num, :] - fam_bbox_total = self.gwd_loss(fam_bbox_decode, - bbox_gt_bboxes_level) - fam_bbox_total = fam_bbox_total * feat_bbox_weights - fam_bbox_total = paddle.sum( - fam_bbox_total) / num_total_samples - + fam_bbox = fam_bbox * feat_bbox_weights + fam_bbox_total = paddle.sum(fam_bbox) / num_total_samples fam_bbox_losses.append(fam_bbox_total) st_idx += feat_anchor_num @@ -637,10 +530,10 @@ class S2ANetHead(nn.Layer): fam_reg_loss = paddle.add_n(fam_bbox_losses) return fam_cls_loss, fam_reg_loss - def get_odm_loss(self, odm_target, s2anet_head_out, reg_loss_type='gwd'): + def get_odm_loss(self, odm_target, s2anet_head_out, reg_loss_type='l1'): (labels, label_weights, bbox_targets, bbox_weights, bbox_gt_bboxes, pos_inds, neg_inds) = odm_target - fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out + fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list, num_anchors_list = s2anet_head_out odm_cls_losses = [] odm_bbox_losses = [] @@ -649,9 +542,7 @@ class S2ANetHead(nn.Layer): neg_inds) if self.sampling else len(pos_inds) num_total_samples = max(1, num_total_samples) - for idx, feat_size in enumerate(self.featmap_sizes_list): - feat_anchor_num = feat_size[0] * feat_size[1] - + for idx, feat_anchor_num in enumerate(num_anchors_list): # step1: get data feat_labels = labels[st_idx:st_idx + feat_anchor_num] feat_label_weights = label_weights[st_idx:st_idx + feat_anchor_num] @@ -709,38 +600,8 @@ class S2ANetHead(nn.Layer): feat_bbox_weights = paddle.to_tensor( feat_bbox_weights, stop_gradient=True) - if reg_loss_type == 'l1': - odm_bbox = odm_bbox * feat_bbox_weights - odm_bbox_total = paddle.sum(odm_bbox) / num_total_samples - elif reg_loss_type == 'iou' or reg_loss_type == 'gwd': - odm_bbox = paddle.sum(odm_bbox, axis=-1) - feat_bbox_weights = paddle.sum(feat_bbox_weights, axis=-1) - try: - from ext_op import rbox_iou - except Exception as e: - print("import custom_ops error, try install ext_op " \ - "following ppdet/ext_op/README.md", e) - sys.stdout.flush() - sys.exit(-1) - # calc iou - odm_bbox_decode = self.delta2rbox(self.refine_anchor_list[idx], - odm_bbox_pred) - bbox_gt_bboxes = paddle.to_tensor( - bbox_gt_bboxes, - dtype=odm_bbox_decode.dtype, - place=odm_bbox_decode.place) - bbox_gt_bboxes.stop_gradient = True - iou = rbox_iou(odm_bbox_decode, bbox_gt_bboxes) - iou = paddle.diag(iou) - - if reg_loss_type == 'gwd': - bbox_gt_bboxes_level = bbox_gt_bboxes[st_idx:st_idx + - feat_anchor_num, :] - odm_bbox_total = self.gwd_loss(odm_bbox_decode, - bbox_gt_bboxes_level) - odm_bbox_total = odm_bbox_total * feat_bbox_weights - odm_bbox_total = paddle.sum( - odm_bbox_total) / num_total_samples + odm_bbox = odm_bbox * feat_bbox_weights + odm_bbox_total = paddle.sum(odm_bbox) / num_total_samples odm_bbox_losses.append(odm_bbox_total) st_idx += feat_anchor_num @@ -752,8 +613,9 @@ class S2ANetHead(nn.Layer): odm_reg_loss = paddle.add_n(odm_bbox_losses) return odm_cls_loss, odm_reg_loss - def get_loss(self, inputs): - # inputs: im_id image im_shape scale_factor gt_bbox gt_class is_crowd + def get_loss(self, head_outs, inputs): + fam_cls_list, fam_reg_list, odm_cls_list, odm_reg_list, \ + num_anchors_list, base_anchors_list, refine_anchors_list = head_outs # compute loss fam_cls_loss_lst = [] @@ -761,29 +623,27 @@ class S2ANetHead(nn.Layer): odm_cls_loss_lst = [] odm_reg_loss_lst = [] - im_shape = inputs['im_shape'] - for im_id in range(im_shape.shape[0]): - np_im_shape = inputs['im_shape'][im_id].numpy() - np_scale_factor = inputs['scale_factor'][im_id].numpy() + batch = len(inputs['gt_rbox']) + for i in range(batch): # data_format: (xc, yc, w, h, theta) - gt_bboxes = inputs['gt_rbox'][im_id].numpy() - gt_labels = inputs['gt_class'][im_id].numpy() - is_crowd = inputs['is_crowd'][im_id].numpy() + gt_mask = inputs['pad_gt_mask'][i, :, 0] + gt_idx = paddle.nonzero(gt_mask).squeeze(-1) + gt_bboxes = paddle.gather(inputs['gt_rbox'][i], gt_idx).numpy() + gt_labels = paddle.gather(inputs['gt_class'][i], gt_idx).numpy() + is_crowd = paddle.gather(inputs['is_crowd'][i], gt_idx).numpy() gt_labels = gt_labels + 1 - # featmap_sizes - anchors_list_all = np.concatenate(self.base_anchors_list) - - # get im_feat - fam_cls_feats_list = [e[im_id] for e in self.s2anet_head_out[0]] - fam_reg_feats_list = [e[im_id] for e in self.s2anet_head_out[1]] - odm_cls_feats_list = [e[im_id] for e in self.s2anet_head_out[2]] - odm_reg_feats_list = [e[im_id] for e in self.s2anet_head_out[3]] - im_s2anet_head_out = (fam_cls_feats_list, fam_reg_feats_list, - odm_cls_feats_list, odm_reg_feats_list) + anchors_per_image = np.concatenate(base_anchors_list) + fam_cls_per_image = [t[i] for t in fam_cls_list] + fam_reg_per_image = [t[i] for t in fam_reg_list] + odm_cls_per_image = [t[i] for t in odm_cls_list] + odm_reg_per_image = [t[i] for t in odm_reg_list] + im_s2anet_head_out = (fam_cls_per_image, fam_reg_per_image, + odm_cls_per_image, odm_reg_per_image, + num_anchors_list) # FAM - im_fam_target = self.anchor_assign(anchors_list_all, gt_bboxes, + im_fam_target = self.anchor_assign(anchors_per_image, gt_bboxes, gt_labels, is_crowd) if im_fam_target is not None: im_fam_cls_loss, im_fam_reg_loss = self.get_fam_loss( @@ -792,11 +652,10 @@ class S2ANetHead(nn.Layer): fam_reg_loss_lst.append(im_fam_reg_loss) # ODM - np_refine_anchors_list = paddle.concat( - self.refine_anchor_list).numpy() - np_refine_anchors_list = np.concatenate(np_refine_anchors_list) - np_refine_anchors_list = np_refine_anchors_list.reshape(-1, 5) - im_odm_target = self.anchor_assign(np_refine_anchors_list, + refine_anchors_per_image = [t[i] for t in refine_anchors_list] + refine_anchors_per_image = paddle.concat( + refine_anchors_per_image).numpy() + im_odm_target = self.anchor_assign(refine_anchors_per_image, gt_bboxes, gt_labels, is_crowd) if im_odm_target is not None: @@ -804,116 +663,38 @@ class S2ANetHead(nn.Layer): im_odm_target, im_s2anet_head_out, self.reg_loss_type) odm_cls_loss_lst.append(im_odm_cls_loss) odm_reg_loss_lst.append(im_odm_reg_loss) - fam_cls_loss = paddle.add_n(fam_cls_loss_lst) - fam_reg_loss = paddle.add_n(fam_reg_loss_lst) - odm_cls_loss = paddle.add_n(odm_cls_loss_lst) - odm_reg_loss = paddle.add_n(odm_reg_loss_lst) + + fam_cls_loss = paddle.add_n(fam_cls_loss_lst) / batch + fam_reg_loss = paddle.add_n(fam_reg_loss_lst) / batch + odm_cls_loss = paddle.add_n(odm_cls_loss_lst) / batch + odm_reg_loss = paddle.add_n(odm_reg_loss_lst) / batch + loss = fam_cls_loss + fam_reg_loss + odm_cls_loss + odm_reg_loss + return { + 'loss': loss, 'fam_cls_loss': fam_cls_loss, 'fam_reg_loss': fam_reg_loss, 'odm_cls_loss': odm_cls_loss, 'odm_reg_loss': odm_reg_loss } - def get_bboxes(self, cls_score_list, bbox_pred_list, mlvl_anchors, nms_pre, - cls_out_channels, use_sigmoid_cls): - assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors) - - mlvl_bboxes = [] - mlvl_scores = [] - - idx = 0 - for cls_score, bbox_pred, anchors in zip(cls_score_list, bbox_pred_list, - mlvl_anchors): - cls_score = paddle.reshape(cls_score, [-1, cls_out_channels]) - if use_sigmoid_cls: - scores = F.sigmoid(cls_score) - else: - scores = F.softmax(cls_score, axis=-1) - - # bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 5) - bbox_pred = paddle.transpose(bbox_pred, [1, 2, 0]) - bbox_pred = paddle.reshape(bbox_pred, [-1, 5]) - anchors = paddle.reshape(anchors, [-1, 5]) - - if scores.shape[0] > nms_pre: - # Get maximum scores for foreground classes. - if use_sigmoid_cls: - max_scores = paddle.max(scores, axis=1) - else: - max_scores = paddle.max(scores[:, 1:], axis=1) - - topk_val, topk_inds = paddle.topk(max_scores, nms_pre) - anchors = paddle.gather(anchors, topk_inds) - bbox_pred = paddle.gather(bbox_pred, topk_inds) - scores = paddle.gather(scores, topk_inds) - - bbox_delta = paddle.reshape(bbox_pred, [-1, 5]) - bboxes = self.delta2rbox(anchors, bbox_delta) - mlvl_bboxes.append(bboxes) - mlvl_scores.append(scores) - - idx += 1 - - mlvl_bboxes = paddle.concat(mlvl_bboxes, axis=0) - mlvl_scores = paddle.concat(mlvl_scores) - - return mlvl_scores, mlvl_bboxes - - def rect2rbox(self, bboxes): - """ - :param bboxes: shape (n, 4) (xmin, ymin, xmax, ymax) - :return: dbboxes: shape (n, 5) (x_ctr, y_ctr, w, h, angle) - """ - bboxes = paddle.reshape(bboxes, [-1, 4]) - num_boxes = paddle.shape(bboxes)[0] - x_ctr = (bboxes[:, 2] + bboxes[:, 0]) / 2.0 - y_ctr = (bboxes[:, 3] + bboxes[:, 1]) / 2.0 - edges1 = paddle.abs(bboxes[:, 2] - bboxes[:, 0]) - edges2 = paddle.abs(bboxes[:, 3] - bboxes[:, 1]) - - rbox_w = paddle.maximum(edges1, edges2) - rbox_h = paddle.minimum(edges1, edges2) - - # set angle - inds = edges1 < edges2 - inds = paddle.cast(inds, 'int32') - rboxes_angle = inds * np.pi / 2.0 - - rboxes = paddle.stack( - (x_ctr, y_ctr, rbox_w, rbox_h, rboxes_angle), axis=1) - return rboxes - - # deltas to rbox - def delta2rbox(self, rrois, deltas, wh_ratio_clip=1e-6): - """ - :param rrois: (cx, cy, w, h, theta) - :param deltas: (dx, dy, dw, dh, dtheta) - :param means: means of anchor - :param stds: stds of anchor - :param wh_ratio_clip: clip threshold of wh_ratio - :return: + def bbox_decode(self, preds, anchors, wh_ratio_clip=1e-6): + """decode bbox from deltas + Args: + preds: [B, L, 5] + anchors: [1, L, 5] + return: + bboxes: [B, L, 5] """ - deltas = paddle.reshape(deltas, [-1, 5]) - rrois = paddle.reshape(rrois, [-1, 5]) - # fix dy2st bug denorm_deltas = deltas * self.stds + self.means - denorm_deltas = paddle.add( - paddle.multiply(deltas, self.stds), self.means) - - dx = denorm_deltas[:, 0] - dy = denorm_deltas[:, 1] - dw = denorm_deltas[:, 2] - dh = denorm_deltas[:, 3] - dangle = denorm_deltas[:, 4] + preds = paddle.add(paddle.multiply(preds, self.stds), self.means) + + dx, dy, dw, dh, dangle = paddle.split(preds, 5, axis=-1) max_ratio = np.abs(np.log(wh_ratio_clip)) dw = paddle.clip(dw, min=-max_ratio, max=max_ratio) dh = paddle.clip(dh, min=-max_ratio, max=max_ratio) - rroi_x = rrois[:, 0] - rroi_y = rrois[:, 1] - rroi_w = rrois[:, 2] - rroi_h = rrois[:, 3] - rroi_angle = rrois[:, 4] + rroi_x, rroi_y, rroi_w, rroi_h, rroi_angle = paddle.split( + anchors, 5, axis=-1) gx = dx * rroi_w * paddle.cos(rroi_angle) - dy * rroi_h * paddle.sin( rroi_angle) + rroi_x @@ -923,127 +704,5 @@ class S2ANetHead(nn.Layer): gh = rroi_h * dh.exp() ga = np.pi * dangle + rroi_angle ga = (ga + np.pi / 4) % np.pi - np.pi / 4 - ga = paddle.to_tensor(ga) - gw = paddle.to_tensor(gw, dtype='float32') - gh = paddle.to_tensor(gh, dtype='float32') - bboxes = paddle.stack([gx, gy, gw, gh, ga], axis=-1) - return bboxes - - def bbox_decode(self, bbox_preds, anchors): - """decode bbox from deltas - Args: - bbox_preds: [N,H,W,5] - anchors: [H*W,5] - return: - bboxes: [N,H,W,5] - """ - num_imgs, H, W, _ = bbox_preds.shape - bbox_delta = paddle.reshape(bbox_preds, [-1, 5]) - bboxes = self.delta2rbox(anchors, bbox_delta) + bboxes = paddle.concat([gx, gy, gw, gh, ga], axis=-1) return bboxes - - def trace(self, A): - tr = paddle.diagonal(A, axis1=-2, axis2=-1) - tr = paddle.sum(tr, axis=-1) - return tr - - def sqrt_newton_schulz_autograd(self, A, numIters): - A_shape = A.shape - batchSize = A_shape[0] - dim = A_shape[1] - - normA = A * A - normA = paddle.sum(normA, axis=1) - normA = paddle.sum(normA, axis=1) - normA = paddle.sqrt(normA) - normA1 = normA.reshape([batchSize, 1, 1]) - Y = paddle.divide(A, paddle.expand_as(normA1, A)) - I = paddle.eye(dim, dim).reshape([1, dim, dim]) - l0 = [] - for i in range(batchSize): - l0.append(I) - I = paddle.concat(l0, axis=0) - I.stop_gradient = False - Z = paddle.eye(dim, dim).reshape([1, dim, dim]) - l1 = [] - for i in range(batchSize): - l1.append(Z) - Z = paddle.concat(l1, axis=0) - Z.stop_gradient = False - - for i in range(numIters): - T = 0.5 * (3.0 * I - Z.bmm(Y)) - Y = Y.bmm(T) - Z = T.bmm(Z) - sA = Y * paddle.sqrt(normA1).reshape([batchSize, 1, 1]) - sA = paddle.expand_as(sA, A) - return sA - - def wasserstein_distance_sigma(sigma1, sigma2): - wasserstein_distance_item2 = paddle.matmul( - sigma1, sigma1) + paddle.matmul( - sigma2, sigma2) - 2 * self.sqrt_newton_schulz_autograd( - paddle.matmul( - paddle.matmul(sigma1, paddle.matmul(sigma2, sigma2)), - sigma1), 10) - wasserstein_distance_item2 = self.trace(wasserstein_distance_item2) - - return wasserstein_distance_item2 - - def xywhr2xyrs(self, xywhr): - xywhr = paddle.reshape(xywhr, [-1, 5]) - xy = xywhr[:, :2] - wh = paddle.clip(xywhr[:, 2:4], min=1e-7, max=1e7) - r = xywhr[:, 4] - cos_r = paddle.cos(r) - sin_r = paddle.sin(r) - R = paddle.stack( - (cos_r, -sin_r, sin_r, cos_r), axis=-1).reshape([-1, 2, 2]) - S = 0.5 * paddle.nn.functional.diag_embed(wh) - return xy, R, S - - def gwd_loss(self, - pred, - target, - fun='log', - tau=1.0, - alpha=1.0, - normalize=False): - - xy_p, R_p, S_p = self.xywhr2xyrs(pred) - xy_t, R_t, S_t = self.xywhr2xyrs(target) - - xy_distance = (xy_p - xy_t).square().sum(axis=-1) - - Sigma_p = R_p.matmul(S_p.square()).matmul(R_p.transpose([0, 2, 1])) - Sigma_t = R_t.matmul(S_t.square()).matmul(R_t.transpose([0, 2, 1])) - - whr_distance = paddle.diagonal( - S_p, axis1=-2, axis2=-1).square().sum(axis=-1) - - whr_distance = whr_distance + paddle.diagonal( - S_t, axis1=-2, axis2=-1).square().sum(axis=-1) - _t = Sigma_p.matmul(Sigma_t) - - _t_tr = paddle.diagonal(_t, axis1=-2, axis2=-1).sum(axis=-1) - _t_det_sqrt = paddle.diagonal(S_p, axis1=-2, axis2=-1).prod(axis=-1) - _t_det_sqrt = _t_det_sqrt * paddle.diagonal( - S_t, axis1=-2, axis2=-1).prod(axis=-1) - whr_distance = whr_distance + (-2) * ( - (_t_tr + 2 * _t_det_sqrt).clip(0).sqrt()) - - distance = (xy_distance + alpha * alpha * whr_distance).clip(0) - - if normalize: - wh_p = pred[..., 2:4].clip(min=1e-7, max=1e7) - wh_t = target[..., 2:4].clip(min=1e-7, max=1e7) - scale = ((wh_p.log() + wh_t.log()).sum(dim=-1) / 4).exp() - distance = distance / scale - - if fun == 'log': - distance = paddle.log1p(distance) - - if tau >= 1.0: - return 1 - 1 / (tau + distance) - - return distance diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py index 0fac4d9ca..a3253df77 100644 --- a/ppdet/modeling/layers.py +++ b/ppdet/modeling/layers.py @@ -39,6 +39,81 @@ def _to_list(l): return [l] +class AlignConv(nn.Layer): + def __init__(self, in_channels, out_channels, kernel_size=3, groups=1): + super(AlignConv, self).__init__() + self.kernel_size = kernel_size + self.align_conv = paddle.vision.ops.DeformConv2D( + in_channels, + out_channels, + kernel_size=self.kernel_size, + padding=(self.kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(initializer=Normal(0, 0.01)), + bias_attr=None) + + @paddle.no_grad() + def get_offset(self, anchors, featmap_size, stride): + """ + Args: + anchors: [B, L, 5] xc,yc,w,h,angle + featmap_size: (feat_h, feat_w) + stride: 8 + Returns: + + """ + batch = anchors.shape[0] + dtype = anchors.dtype + feat_h, feat_w = featmap_size + pad = (self.kernel_size - 1) // 2 + idx = paddle.arange(-pad, pad + 1, dtype=dtype) + + yy, xx = paddle.meshgrid(idx, idx) + xx = paddle.reshape(xx, [-1]) + yy = paddle.reshape(yy, [-1]) + + # get sampling locations of default conv + xc = paddle.arange(0, feat_w, dtype=dtype) + yc = paddle.arange(0, feat_h, dtype=dtype) + yc, xc = paddle.meshgrid(yc, xc) + + xc = paddle.reshape(xc, [-1, 1]) + yc = paddle.reshape(yc, [-1, 1]) + x_conv = xc + xx + y_conv = yc + yy + + # get sampling locations of anchors + x_ctr, y_ctr, w, h, a = paddle.split(anchors, 5, axis=-1) + x_ctr = x_ctr / stride + y_ctr = y_ctr / stride + w_s = w / stride + h_s = h / stride + cos, sin = paddle.cos(a), paddle.sin(a) + dw, dh = w_s / self.kernel_size, h_s / self.kernel_size + x, y = dw * xx, dh * yy + xr = cos * x - sin * y + yr = sin * x + cos * y + x_anchor, y_anchor = xr + x_ctr, yr + y_ctr + # get offset filed + offset_x = x_anchor - x_conv + offset_y = y_anchor - y_conv + offset = paddle.stack([offset_y, offset_x], axis=-1) + offset = offset.reshape( + [batch, feat_h, feat_w, self.kernel_size * self.kernel_size * 2]) + offset = offset.transpose([0, 3, 1, 2]) + + return offset + + def forward(self, x, refine_anchors, featmap_size, stride): + batch = paddle.shape(x)[0].numpy() + offset = self.get_offset(refine_anchors, featmap_size, stride) + if self.training: + x = F.relu(self.align_conv(x, offset.detach())) + else: + x = F.relu(self.align_conv(x, offset)) + return x + + class DeformableConvV2(nn.Layer): def __init__(self, in_channels, diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py index 27890c17e..15060e7a8 100644 --- a/ppdet/modeling/post_process.py +++ b/ppdet/modeling/post_process.py @@ -27,8 +27,8 @@ except Exception: __all__ = [ 'BBoxPostProcess', 'MaskPostProcess', 'FCOSPostProcess', - 'S2ANetBBoxPostProcess', 'JDEBBoxPostProcess', 'CenterNetPostProcess', - 'DETRBBoxPostProcess', 'SparsePostProcess' + 'JDEBBoxPostProcess', 'CenterNetPostProcess', 'DETRBBoxPostProcess', + 'SparsePostProcess' ] @@ -294,109 +294,6 @@ class FCOSPostProcess(object): return bbox_pred, bbox_num -@register -class S2ANetBBoxPostProcess(nn.Layer): - __shared__ = ['num_classes'] - __inject__ = ['nms'] - - def __init__(self, num_classes=15, nms_pre=2000, min_bbox_size=0, nms=None): - super(S2ANetBBoxPostProcess, self).__init__() - self.num_classes = num_classes - self.nms_pre = nms_pre - self.min_bbox_size = min_bbox_size - self.nms = nms - self.origin_shape_list = [] - self.fake_pred_cls_score_bbox = paddle.to_tensor( - np.array( - [[-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], - dtype='float32')) - self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32')) - - def forward(self, pred_scores, pred_bboxes): - """ - pred_scores : [N, M] score - pred_bboxes : [N, 5] xc, yc, w, h, a - im_shape : [N, 2] im_shape - scale_factor : [N, 2] scale_factor - """ - pred_ploys0 = rbox2poly(pred_bboxes) - pred_ploys = paddle.unsqueeze(pred_ploys0, axis=0) - - # pred_scores [NA, 16] --> [16, NA] - pred_scores0 = paddle.transpose(pred_scores, [1, 0]) - pred_scores = paddle.unsqueeze(pred_scores0, axis=0) - - pred_cls_score_bbox, bbox_num, _ = self.nms(pred_ploys, pred_scores, - self.num_classes) - # Prevent empty bbox_pred from decode or NMS. - # Bboxes and score before NMS may be empty due to the score threshold. - if pred_cls_score_bbox.shape[0] <= 0 or pred_cls_score_bbox.shape[ - 1] <= 1: - pred_cls_score_bbox = self.fake_pred_cls_score_bbox - bbox_num = self.fake_bbox_num - - pred_cls_score_bbox = paddle.reshape(pred_cls_score_bbox, [-1, 10]) - return pred_cls_score_bbox, bbox_num - - def get_pred(self, bboxes, bbox_num, im_shape, scale_factor): - """ - Rescale, clip and filter the bbox from the output of NMS to - get final prediction. - Args: - bboxes(Tensor): bboxes [N, 10] - bbox_num(Tensor): bbox_num - im_shape(Tensor): [1 2] - scale_factor(Tensor): [1 2] - Returns: - bbox_pred(Tensor): The output is the prediction with shape [N, 8] - including labels, scores and bboxes. The size of - bboxes are corresponding to the original image. - """ - origin_shape = paddle.floor(im_shape / scale_factor + 0.5) - - origin_shape_list = [] - scale_factor_list = [] - # scale_factor: scale_y, scale_x - for i in range(bbox_num.shape[0]): - expand_shape = paddle.expand(origin_shape[i:i + 1, :], - [bbox_num[i], 2]) - scale_y, scale_x = scale_factor[i][0], scale_factor[i][1] - scale = paddle.concat([ - scale_x, scale_y, scale_x, scale_y, scale_x, scale_y, scale_x, - scale_y - ]) - expand_scale = paddle.expand(scale, [bbox_num[i], 8]) - origin_shape_list.append(expand_shape) - scale_factor_list.append(expand_scale) - - origin_shape_list = paddle.concat(origin_shape_list) - scale_factor_list = paddle.concat(scale_factor_list) - - # bboxes: [N, 10], label, score, bbox - pred_label_score = bboxes[:, 0:2] - pred_bbox = bboxes[:, 2:] - - # rescale bbox to original image - pred_bbox = pred_bbox.reshape([-1, 8]) - scaled_bbox = pred_bbox / scale_factor_list - origin_h = origin_shape_list[:, 0] - origin_w = origin_shape_list[:, 1] - - bboxes = scaled_bbox - zeros = paddle.zeros_like(origin_h) - x1 = paddle.maximum(paddle.minimum(bboxes[:, 0], origin_w - 1), zeros) - y1 = paddle.maximum(paddle.minimum(bboxes[:, 1], origin_h - 1), zeros) - x2 = paddle.maximum(paddle.minimum(bboxes[:, 2], origin_w - 1), zeros) - y2 = paddle.maximum(paddle.minimum(bboxes[:, 3], origin_h - 1), zeros) - x3 = paddle.maximum(paddle.minimum(bboxes[:, 4], origin_w - 1), zeros) - y3 = paddle.maximum(paddle.minimum(bboxes[:, 5], origin_h - 1), zeros) - x4 = paddle.maximum(paddle.minimum(bboxes[:, 6], origin_w - 1), zeros) - y4 = paddle.maximum(paddle.minimum(bboxes[:, 7], origin_h - 1), zeros) - pred_bbox = paddle.stack([x1, y1, x2, y2, x3, y3, x4, y4], axis=-1) - pred_result = paddle.concat([pred_label_score, pred_bbox], axis=1) - return pred_result - - @register class JDEBBoxPostProcess(nn.Layer): __shared__ = ['num_classes'] diff --git a/ppdet/modeling/proposal_generator/anchor_generator.py b/ppdet/modeling/proposal_generator/anchor_generator.py index 94fd34600..9a8e24ea3 100644 --- a/ppdet/modeling/proposal_generator/anchor_generator.py +++ b/ppdet/modeling/proposal_generator/anchor_generator.py @@ -19,10 +19,11 @@ import math import paddle import paddle.nn as nn +import numpy as np from ppdet.core.workspace import register -__all__ = ['AnchorGenerator', 'RetinaAnchorGenerator'] +__all__ = ['AnchorGenerator', 'RetinaAnchorGenerator', 'S2ANetAnchorGenerator'] @register @@ -153,3 +154,113 @@ class RetinaAnchorGenerator(AnchorGenerator): strides=strides, variance=variance, offset=offset) + + +@register +class S2ANetAnchorGenerator(nn.Layer): + """ + AnchorGenerator by paddle + """ + + def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None): + super(S2ANetAnchorGenerator, self).__init__() + self.base_size = base_size + self.scales = paddle.to_tensor(scales) + self.ratios = paddle.to_tensor(ratios) + self.scale_major = scale_major + self.ctr = ctr + self.base_anchors = self.gen_base_anchors() + + @property + def num_base_anchors(self): + return self.base_anchors.shape[0] + + def gen_base_anchors(self): + w = self.base_size + h = self.base_size + if self.ctr is None: + x_ctr = 0.5 * (w - 1) + y_ctr = 0.5 * (h - 1) + else: + x_ctr, y_ctr = self.ctr + + h_ratios = paddle.sqrt(self.ratios) + w_ratios = 1 / h_ratios + if self.scale_major: + ws = (w * w_ratios[:] * self.scales[:]).reshape([-1]) + hs = (h * h_ratios[:] * self.scales[:]).reshape([-1]) + else: + ws = (w * self.scales[:] * w_ratios[:]).reshape([-1]) + hs = (h * self.scales[:] * h_ratios[:]).reshape([-1]) + + base_anchors = paddle.stack( + [ + x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1), + x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1) + ], + axis=-1) + base_anchors = paddle.round(base_anchors) + return base_anchors + + def _meshgrid(self, x, y, row_major=True): + yy, xx = paddle.meshgrid(y, x) + yy = yy.reshape([-1]) + xx = xx.reshape([-1]) + if row_major: + return xx, yy + else: + return yy, xx + + def forward(self, featmap_size, stride=16): + # featmap_size*stride project it to original area + + feat_h = featmap_size[0] + feat_w = featmap_size[1] + shift_x = paddle.arange(0, feat_w, 1, 'int32') * stride + shift_y = paddle.arange(0, feat_h, 1, 'int32') * stride + shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) + shifts = paddle.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1) + + all_anchors = self.base_anchors[:, :] + shifts[:, :] + all_anchors = all_anchors.cast(paddle.float32).reshape( + [feat_h * feat_w, 4]) + all_anchors = self.rect2rbox(all_anchors) + return all_anchors + + def valid_flags(self, featmap_size, valid_size): + feat_h, feat_w = featmap_size + valid_h, valid_w = valid_size + assert valid_h <= feat_h and valid_w <= feat_w + valid_x = paddle.zeros([feat_w], dtype='int32') + valid_y = paddle.zeros([feat_h], dtype='int32') + valid_x[:valid_w] = 1 + valid_y[:valid_h] = 1 + valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) + valid = valid_xx & valid_yy + valid = paddle.reshape(valid, [-1, 1]) + valid = paddle.expand(valid, [-1, self.num_base_anchors]).reshape([-1]) + return valid + + def rect2rbox(self, bboxes): + """ + :param bboxes: shape (L, 4) (xmin, ymin, xmax, ymax) + :return: dbboxes: shape (L, 5) (x_ctr, y_ctr, w, h, angle) + """ + x1, y1, x2, y2 = paddle.split(bboxes, 4, axis=-1) + + x_ctr = (x1 + x2) / 2.0 + y_ctr = (y1 + y2) / 2.0 + edges1 = paddle.abs(x2 - x1) + edges2 = paddle.abs(y2 - y1) + + rbox_w = paddle.maximum(edges1, edges2) + rbox_h = paddle.minimum(edges1, edges2) + + # set angle + inds = edges1 < edges2 + inds = paddle.cast(inds, paddle.float32) + rboxes_angle = inds * np.pi / 2.0 + + rboxes = paddle.concat( + (x_ctr, y_ctr, rbox_w, rbox_h, rboxes_angle), axis=-1) + return rboxes diff --git a/ppdet/modeling/proposal_generator/target_layer.py b/ppdet/modeling/proposal_generator/target_layer.py index 201c8bf86..edcf97359 100644 --- a/ppdet/modeling/proposal_generator/target_layer.py +++ b/ppdet/modeling/proposal_generator/target_layer.py @@ -365,21 +365,11 @@ class RBoxAssigner(object): def assign_anchor(self, anchors, gt_bboxes, - gt_lables, + gt_labels, pos_iou_thr, neg_iou_thr, min_iou_thr=0.0, ignore_iof_thr=-2): - """ - - Args: - anchors: - gt_bboxes:[M, 5] rc,yc,w,h,angle - gt_lables: - - Returns: - - """ assert anchors.shape[1] == 4 or anchors.shape[1] == 5 assert gt_bboxes.shape[1] == 4 or gt_bboxes.shape[1] == 5 anchors_xc_yc = anchors @@ -428,12 +418,12 @@ class RBoxAssigner(object): # (4) assign max_iou as pos_ids >=0 anchor_gt_bbox_iou_inds = anchor_gt_bbox_inds[gt_bbox_anchor_iou_inds] # gt_bbox_anchor_iou_inds = np.logical_and(gt_bbox_anchor_iou_inds, anchor_gt_bbox_iou >= min_iou_thr) - labels[gt_bbox_anchor_iou_inds] = gt_lables[anchor_gt_bbox_iou_inds] + labels[gt_bbox_anchor_iou_inds] = gt_labels[anchor_gt_bbox_iou_inds] # (5) assign >= pos_iou_thr as pos_ids iou_pos_iou_thr_ids = anchor_gt_bbox_iou >= pos_iou_thr iou_pos_iou_thr_ids_box_inds = anchor_gt_bbox_inds[iou_pos_iou_thr_ids] - labels[iou_pos_iou_thr_ids] = gt_lables[iou_pos_iou_thr_ids_box_inds] + labels[iou_pos_iou_thr_ids] = gt_labels[iou_pos_iou_thr_ids_box_inds] return anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels def __call__(self, anchors, gt_bboxes, gt_labels, is_crowd): -- GitLab