diff --git a/configs/tood/README.md b/configs/tood/README.md index f23b1844ab5e1124f39dfd2bf4335b2f401b9db0..9eb87fdf92ba36901e8f50b0d5a8c6fdb0da4104 100644 --- a/configs/tood/README.md +++ b/configs/tood/README.md @@ -11,7 +11,7 @@ TOOD is an object detection model. We reproduced the model of the paper. | Backbone | Model | Images/GPU | Inf time (fps) | Box AP | Config | Download | |:------:|:--------:|:--------:|:--------------:|:------:|:------:|:--------:| -| R-50 | TOOD | 4 | --- | 42.8 | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/tood/tood_r50_fpn_1x_coco.yml) | [model](https://paddledet.bj.bcebos.com/models/tood_r50_fpn_1x_coco.pdparams) | +| R-50 | TOOD | 4 | --- | 42.5 | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/tood/tood_r50_fpn_1x_coco.yml) | [model](https://paddledet.bj.bcebos.com/models/tood_r50_fpn_1x_coco.pdparams) | **Notes:** diff --git a/configs/tood/_base_/tood_reader.yml b/configs/tood/_base_/tood_reader.yml index cda3cb80db4144fe3469c3844b3b698e1357e539..4bfe3b4f9fc1f42be0a65fa90ee97ebd67f9c6c3 100644 --- a/configs/tood/_base_/tood_reader.yml +++ b/configs/tood/_base_/tood_reader.yml @@ -3,7 +3,7 @@ TrainReader: sample_transforms: - Decode: {} - RandomFlip: {prob: 0.5} - - Resize: {target_size: [800, 1333], keep_ratio: true, interp: 1} + - Resize: {target_size: [800, 1333], keep_ratio: true} - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]} - Permute: {} batch_transforms: @@ -18,7 +18,7 @@ TrainReader: EvalReader: sample_transforms: - Decode: {} - - Resize: {interp: 1, target_size: [800, 1333], keep_ratio: True} + - Resize: {target_size: [800, 1333], keep_ratio: True} - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]} - Permute: {} batch_transforms: @@ -30,7 +30,7 @@ EvalReader: TestReader: sample_transforms: - Decode: {} - - Resize: {interp: 1, target_size: [800, 1333], keep_ratio: True} + - Resize: {target_size: [800, 1333], keep_ratio: True} - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]} - Permute: {} batch_transforms: diff --git a/deploy/python/infer.py b/deploy/python/infer.py index 752a07efa48d688fbde2523253a953b28060a879..c284bfe2d9d6797f7f608ffcb6b9dfde132e1dca 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -46,6 +46,7 @@ SUPPORT_MODELS = { 'GFL', 'PicoDet', 'CenterNet', + 'TOOD', } @@ -680,7 +681,7 @@ def predict_video(detector, camera_id): if not os.path.exists(FLAGS.output_dir): os.makedirs(FLAGS.output_dir) out_path = os.path.join(FLAGS.output_dir, video_out_name) - fourcc = cv2.VideoWriter_fourcc(* 'mp4v') + fourcc = cv2.VideoWriter_fourcc(*'mp4v') writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) index = 1 while (1): diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py index e1cf64638089c3e98fb7d08af7a3e39246cf6c48..4fe623d17fbd3b43cd4680f143d02598d5cac95b 100644 --- a/ppdet/engine/export_utils.py +++ b/ppdet/engine/export_utils.py @@ -46,6 +46,7 @@ TRT_MIN_SUBGRAPH = { 'GFL': 16, 'PicoDet': 3, 'CenterNet': 5, + 'TOOD': 5, } KEYPOINT_ARCH = ['HigherHRNet', 'TopDownHRNet'] diff --git a/ppdet/modeling/assigners/utils.py b/ppdet/modeling/assigners/utils.py index 3448d9d8ae5825bfab2814970e8c52b7ba54548b..4482f4c3f70b9ec4f0e1abb863a339170108fc4c 100644 --- a/ppdet/modeling/assigners/utils.py +++ b/ppdet/modeling/assigners/utils.py @@ -19,6 +19,12 @@ from __future__ import print_function import paddle import paddle.nn.functional as F +__all__ = [ + 'pad_gt', 'gather_topk_anchors', 'check_points_inside_bboxes', + 'compute_max_iou_anchor', 'compute_max_iou_gt', + 'generate_anchors_for_grid_cell' +] + def pad_gt(gt_labels, gt_bboxes, gt_scores=None): r""" Pad 0 in gt_labels and gt_bboxes. @@ -147,3 +153,42 @@ def compute_max_iou_gt(ious): max_iou_index = ious.argmax(axis=-1) is_max_iou = F.one_hot(max_iou_index, num_anchors) return is_max_iou.astype(ious.dtype) + + +def generate_anchors_for_grid_cell(feats, + fpn_strides, + grid_cell_size=5.0, + grid_cell_offset=0.5): + r""" + Like ATSS, generate anchors based on grid size. + Args: + feats (List[Tensor]): shape[s, (b, c, h, w)] + fpn_strides (tuple|list): shape[s], stride for each scale feature + grid_cell_size (float): anchor size + grid_cell_offset (float): The range is between 0 and 1. + Returns: + anchors (List[Tensor]): shape[s, (l, 4)] + num_anchors_list (List[int]): shape[s] + stride_tensor_list (List[Tensor]): shape[s, (l, 1)] + """ + assert len(feats) == len(fpn_strides) + anchors = [] + num_anchors_list = [] + stride_tensor_list = [] + for feat, stride in zip(feats, fpn_strides): + _, _, h, w = feat.shape + cell_half_size = grid_cell_size * stride * 0.5 + shift_x = (paddle.arange(end=w) + grid_cell_offset) * stride + shift_y = (paddle.arange(end=h) + grid_cell_offset) * stride + shift_y, shift_x = paddle.meshgrid(shift_y, shift_x) + anchor = paddle.stack( + [ + shift_x - cell_half_size, shift_y - cell_half_size, + shift_x + cell_half_size, shift_y + cell_half_size + ], + axis=-1).astype(feat.dtype) + anchors.append(anchor.reshape([-1, 4])) + num_anchors_list.append(len(anchors[-1])) + stride_tensor_list.append( + paddle.full([num_anchors_list[-1], 1], stride)) + return anchors, num_anchors_list, stride_tensor_list diff --git a/ppdet/modeling/bbox_utils.py b/ppdet/modeling/bbox_utils.py index e040ba69b755fdc5c329413343edaba463ba7925..49a2e281fa1485ad6209215e09dfd0e6dfc0a31b 100644 --- a/ppdet/modeling/bbox_utils.py +++ b/ppdet/modeling/bbox_utils.py @@ -748,6 +748,28 @@ def bbox_center(boxes): Returns: Tensor: boxes centers with shape (N, 2), "cx, cy" format. """ - boxes_cx = (boxes[:, 0] + boxes[:, 2]) / 2 - boxes_cy = (boxes[:, 1] + boxes[:, 3]) / 2 + boxes_cx = (boxes[..., 0] + boxes[..., 2]) / 2 + boxes_cy = (boxes[..., 1] + boxes[..., 3]) / 2 return paddle.stack([boxes_cx, boxes_cy], axis=-1) + + +def batch_distance2bbox(points, distance, max_shapes=None): + """Decode distance prediction to bounding box for batch. + Args: + points (Tensor): [B, ..., 2] + distance (Tensor): [B, ..., 4] + max_shapes (tuple): [B, 2], "h,w" format, Shape of the image. + Returns: + Tensor: Decoded bboxes. + """ + x1 = points[..., 0] - distance[..., 0] + y1 = points[..., 1] - distance[..., 1] + x2 = points[..., 0] + distance[..., 2] + y2 = points[..., 1] + distance[..., 3] + if max_shapes is not None: + for i, max_shape in enumerate(max_shapes): + x1[i] = x1[i].clip(min=0, max=max_shape[1]) + y1[i] = y1[i].clip(min=0, max=max_shape[0]) + x2[i] = x2[i].clip(min=0, max=max_shape[1]) + y2[i] = y2[i].clip(min=0, max=max_shape[0]) + return paddle.stack([x1, y1, x2, y2], -1) diff --git a/ppdet/modeling/heads/tood_head.py b/ppdet/modeling/heads/tood_head.py index b9dbd17e36f920d2df7b01d09eb7a461b9d68365..f2cb2970a2d30f33376f3b1ab7bf4f1dab4be5a1 100644 --- a/ppdet/modeling/heads/tood_head.py +++ b/ppdet/modeling/heads/tood_head.py @@ -24,10 +24,11 @@ from paddle.nn.initializer import Constant from ppdet.core.workspace import register from ..initializer import normal_, constant_, bias_init_with_prob -from ppdet.modeling.bbox_utils import bbox_center +from ppdet.modeling.bbox_utils import bbox_center, batch_distance2bbox from ..losses import GIoULoss -from paddle.vision.ops import deform_conv2d from ppdet.modeling.layers import ConvNormLayer +from ppdet.modeling.ops import get_static_shape +from ppdet.modeling.assigners.utils import generate_anchors_for_grid_cell class ScaleReg(nn.Layer): @@ -84,25 +85,13 @@ class TaskDecomposition(nn.Layer): normal_(self.la_conv1.weight, std=0.001) normal_(self.la_conv2.weight, std=0.001) - def forward(self, feat, avg_feat=None): - b, _, h, w = feat.shape - if avg_feat is None: - avg_feat = F.adaptive_avg_pool2d(feat, (1, 1)) + def forward(self, feat, avg_feat): + b, _, h, w = get_static_shape(feat) weight = F.relu(self.la_conv1(avg_feat)) - weight = F.sigmoid(self.la_conv2(weight)) - - # here new_conv_weight = layer_attention_weight * conv_weight - # in order to save memory and FLOPs. - conv_weight = weight.reshape([b, 1, self.stacked_convs, 1]) * \ - self.reduction_conv.conv.weight.reshape( - [1, self.feat_channels, self.stacked_convs, self.feat_channels]) - conv_weight = conv_weight.reshape( - [b, self.feat_channels, self.in_channels]) - feat = feat.reshape([b, self.in_channels, h * w]) - feat = paddle.bmm(conv_weight, feat).reshape( - [b, self.feat_channels, h, w]) - if self.norm_type is not None: - feat = self.reduction_conv.norm(feat) + weight = F.sigmoid(self.la_conv2(weight)).unsqueeze(-1) + feat = paddle.reshape( + feat, [b, self.stacked_convs, self.feat_channels, h, w]) * weight + feat = self.reduction_conv(feat.flatten(1, 2)) feat = F.relu(feat) return feat @@ -211,81 +200,32 @@ class TOODHead(nn.Layer): normal_(self.cls_prob_conv2.weight, std=0.01) constant_(self.cls_prob_conv2.bias, bias_cls) normal_(self.reg_offset_conv1.weight, std=0.001) - normal_(self.reg_offset_conv2.weight, std=0.001) + constant_(self.reg_offset_conv2.weight) constant_(self.reg_offset_conv2.bias) - def _generate_anchors(self, feats): - anchors, num_anchors_list = [], [] - stride_tensor_list = [] - for feat, stride in zip(feats, self.fpn_strides): - _, _, h, w = feat.shape - cell_half_size = self.grid_cell_scale * stride * 0.5 - shift_x = (paddle.arange(end=w) + self.grid_cell_offset) * stride - shift_y = (paddle.arange(end=h) + self.grid_cell_offset) * stride - shift_y, shift_x = paddle.meshgrid(shift_y, shift_x) - anchor = paddle.stack( - [ - shift_x - cell_half_size, shift_y - cell_half_size, - shift_x + cell_half_size, shift_y + cell_half_size - ], - axis=-1) - anchors.append(anchor.reshape([-1, 4])) - num_anchors_list.append(len(anchors[-1])) - stride_tensor_list.append( - paddle.full([num_anchors_list[-1], 1], stride)) - return anchors, num_anchors_list, stride_tensor_list - - @staticmethod - def _batch_distance2bbox(points, distance, max_shapes=None): - """Decode distance prediction to bounding box. - Args: - points (Tensor): [B, l, 2] - distance (Tensor): [B, l, 4] - max_shapes (tuple): [B, 2], "h w" format, Shape of the image. - Returns: - Tensor: Decoded bboxes. - """ - x1 = points[:, :, 0] - distance[:, :, 0] - y1 = points[:, :, 1] - distance[:, :, 1] - x2 = points[:, :, 0] + distance[:, :, 2] - y2 = points[:, :, 1] + distance[:, :, 3] - bboxes = paddle.stack([x1, y1, x2, y2], -1) - if max_shapes is not None: - out_bboxes = [] - for bbox, max_shape in zip(bboxes, max_shapes): - bbox[:, 0] = bbox[:, 0].clip(min=0, max=max_shape[1]) - bbox[:, 1] = bbox[:, 1].clip(min=0, max=max_shape[0]) - bbox[:, 2] = bbox[:, 2].clip(min=0, max=max_shape[1]) - bbox[:, 3] = bbox[:, 3].clip(min=0, max=max_shape[0]) - out_bboxes.append(bbox) - out_bboxes = paddle.stack(out_bboxes) - return out_bboxes - return bboxes - - @staticmethod - def _deform_sampling(feat, offset): - """ Sampling the feature according to offset. - Args: - feat (Tensor): Feature - offset (Tensor): Spatial offset for for feature sampliing - """ - # it is an equivalent implementation of bilinear interpolation - # you can also use F.grid_sample instead - c = feat.shape[1] - weight = paddle.ones([c, 1, 1, 1]) - y = deform_conv2d(feat, offset, weight, deformable_groups=c, groups=c) - return y + def _reg_grid_sample(self, feat, offset, anchor_points): + b, _, h, w = get_static_shape(feat) + feat = paddle.reshape(feat, [-1, 1, h, w]) + offset = paddle.reshape(offset, [-1, 2, h, w]).transpose([0, 2, 3, 1]) + grid_shape = paddle.concat([w, h]).astype('float32') + grid = (offset + anchor_points) / grid_shape + grid = 2 * grid.clip(0., 1.) - 1 + feat = F.grid_sample(feat, grid) + feat = paddle.reshape(feat, [b, -1, h, w]) + return feat def forward(self, feats): assert len(feats) == len(self.fpn_strides), \ "The size of feats is not equal to size of fpn_strides" - anchors, num_anchors_list, stride_tensor_list = self._generate_anchors( - feats) + anchors, num_anchors_list, stride_tensor_list = generate_anchors_for_grid_cell( + feats, self.fpn_strides, self.grid_cell_scale, + self.grid_cell_offset) + cls_score_list, bbox_pred_list = [], [] for feat, scale_reg, anchor, stride in zip(feats, self.scales_regs, anchors, self.fpn_strides): - b, _, h, w = feat.shape + b, _, h, w = get_static_shape(feat) inter_feats = [] for inter_conv in self.inter_convs: feat = F.relu(inter_conv(feat)) @@ -309,16 +249,16 @@ class TOODHead(nn.Layer): # reg prediction and alignment reg_dist = scale_reg(self.tood_reg(reg_feat).exp()) - reg_dist = reg_dist.transpose([0, 2, 3, 1]).reshape([b, -1, 4]) + reg_dist = reg_dist.flatten(2).transpose([0, 2, 1]) anchor_centers = bbox_center(anchor).unsqueeze(0) / stride - reg_bbox = self._batch_distance2bbox( - anchor_centers.tile([b, 1, 1]), reg_dist) + reg_bbox = batch_distance2bbox(anchor_centers, reg_dist) if self.use_align_head: - reg_bbox = reg_bbox.reshape([b, h, w, 4]).transpose( - [0, 3, 1, 2]) reg_offset = F.relu(self.reg_offset_conv1(feat)) reg_offset = self.reg_offset_conv2(reg_offset) - bbox_pred = self._deform_sampling(reg_bbox, reg_offset) + reg_bbox = reg_bbox.transpose([0, 2, 1]).reshape([b, 4, h, w]) + anchor_centers = anchor_centers.reshape([1, h, w, 2]) + bbox_pred = self._reg_grid_sample(reg_bbox, reg_offset, + anchor_centers) bbox_pred = bbox_pred.flatten(2).transpose([0, 2, 1]) else: bbox_pred = reg_bbox diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py index ba4321e6b58c288e5c0fd841e350de670c48c847..e6b8ad987fc853d8789d8fe35ea9369670617636 100644 --- a/ppdet/modeling/ops.py +++ b/ppdet/modeling/ops.py @@ -1600,3 +1600,9 @@ def channel_shuffle(x, groups): x = paddle.transpose(x=x, perm=[0, 2, 1, 3, 4]) x = paddle.reshape(x=x, shape=[batch_size, num_channels, height, width]) return x + + +def get_static_shape(tensor): + shape = paddle.shape(tensor) + shape.stop_gradient = True + return shape