diff --git a/configs/mot/README.md b/configs/mot/README.md
index 09ad7dcb5f8330788bb97fcbe17f21b73554f23f..cbf76371ae76504b1b6d02a47c3fb13997888f5c 100644
--- a/configs/mot/README.md
+++ b/configs/mot/README.md
@@ -140,14 +140,17 @@ If you use a stronger detection model, you can get better results. Each txt is t
 ### Results on MOT-16 Test Set
 | backbone       | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
 | :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: |
+| DLA-34         | 1088x608 | 75.9 | 74.7 | 1021 | 11425 | 31475 | - | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_dla34_60e_1088x608.pdparams) | [config](./fairmot/fairmot_enhance_dla34_60e_1088x608.yml) |
 | HarDNet-85     | 1088x608 | 75.0 | 70.0 | 1050 | 11837 | 32774 | - | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml) |

 ### Results on MOT-17 Test Set
 | backbone       | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
 | :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: |
+| DLA-34         | 1088x608 | 75.3 | 74.2 | 3270 | 29112 | 106749 | - | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_dla34_60e_1088x608.pdparams) | [config](./fairmot/fairmot_enhance_dla34_60e_1088x608.yml) |
 | HarDNet-85     | 1088x608 | 74.7 | 70.7 | 3210 | 29790 | 109914 | - | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml) |

-**注意:**
+**Notes:**
+ FairMOT enhance DLA-34 used 8 GPUs for training with a mini-batch size of 16 on each GPU, and trained for 60 epochs. The crowdhuman dataset is added to the train-set during training.
  FairMOT enhance HarDNet-85 used 8 GPUs for training and mini-batch size as 10 on each GPU,and trained for 30 epoches. The crowdhuman dataset is added to the train-set during training.
diff --git a/configs/mot/README_cn.md b/configs/mot/README_cn.md
index b03dc5cb4851e30f433e6a8d8519a247b321e55d..16b8775547eaea3b1269c2aca91e49c5c8eed0b4 100644
--- a/configs/mot/README_cn.md
+++ b/configs/mot/README_cn.md
@@ -140,14 +140,17 @@ wget https://dataset.bj.bcebos.com/mot/det_results_dir.zip
 ### Results on MOT-16 Test Set
 | backbone       | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
 | :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: |
+| DLA-34         | 1088x608 | 75.9 | 74.7 | 1021 | 11425 | 31475 | - | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_dla34_60e_1088x608.pdparams) | [config](./fairmot/fairmot_enhance_dla34_60e_1088x608.yml) |
 | HarDNet-85     | 1088x608 | 75.0 | 70.0 | 1050 | 11837 | 32774 | - | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml) |

 ### Results on MOT-17 Test Set
 | backbone       | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
 | :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: |
+| DLA-34         | 1088x608 | 75.3 | 74.2 | 3270 | 29112 | 106749 | - | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_dla34_60e_1088x608.pdparams) | [config](./fairmot/fairmot_enhance_dla34_60e_1088x608.yml) |
 | HarDNet-85     | 1088x608 | 74.7 | 70.7 | 3210 | 29790 | 109914 | - | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml) |

 **Notes:**
+ FairMOT enhance DLA-34 was trained on 8 GPUs with a batch size of 16 per GPU for 60 epochs; the crowdhuman dataset was added to the training set.
  FairMOT enhance HarDNet-85 was trained on 8 GPUs with a batch size of 10 per GPU for 30 epochs; the crowdhuman dataset was added to the training set.
diff --git a/configs/mot/fairmot/README.md b/configs/mot/fairmot/README.md
index cbef7a44114bc91f740b510a6c276b343725c226..353a9fce88b106998f983ea80a3eb96d3b3187cc 100644
--- a/configs/mot/fairmot/README.md
+++ b/configs/mot/fairmot/README.md
@@ -41,14 +41,17 @@ English | [简体中文](README_cn.md)
 ### Results on MOT-16 Test Set
 | backbone       | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
 | :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: |
+| DLA-34         | 1088x608 | 75.9 | 74.7 | 1021 | 11425 | 31475 | - | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_dla34_60e_1088x608.pdparams) | [config](./fairmot_enhance_dla34_60e_1088x608.yml) |
 | HarDNet-85     | 1088x608 | 75.0 | 70.0 | 1050 | 11837 | 32774 | - | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot_enhance_hardnet85_30e_1088x608.yml) |

 ### Results on MOT-17 Test Set
 | backbone       | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
 | :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: |
+| DLA-34         | 1088x608 | 75.3 | 74.2 | 3270 | 29112 | 106749 | - | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_dla34_60e_1088x608.pdparams) | [config](./fairmot_enhance_dla34_60e_1088x608.yml) |
 | HarDNet-85     | 1088x608 | 74.7 | 70.7 | 3210 | 29790 | 109914 | - | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot_enhance_hardnet85_30e_1088x608.yml) |

-**注意:**
+**Notes:**
+ FairMOT enhance DLA-34 used 8 GPUs for training with a mini-batch size of 16 on each GPU, and trained for 60 epochs. The crowdhuman dataset is added to the train-set during training.
  FairMOT enhance HarDNet-85 used 8 GPUs for training and mini-batch size as 10 on each GPU,and trained for 30 epoches. The crowdhuman dataset is added to the train-set during training.
@@ -64,7 +67,7 @@ English | [简体中文](README_cn.md)
 | HRNetV2-W18    | 1088x608 | 70.7 | 65.7 | 4281 | 22485 | 138468 | - | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_hrnetv2_w18_dlafpn_30e_1088x608.pdparams) | [config](./fairmot_hrnetv2_w18_dlafpn_30e_1088x608.yml) |

 **Notes:**
- FairMOT HRNetV2-W18 used 8 GPUs for training and mini-batch size as 6 on each GPU, and trained for 30 epoches. Only ImageNet pre-train model is used, and the optimizer adopts Momentum. The crowdhuman dataset is added to the train-set during training.
+ FairMOT HRNetV2-W18 used 8 GPUs for training with a mini-batch size of 4 on each GPU, and trained for 30 epochs. Only ImageNet pre-train model is used, and the optimizer adopts Momentum. The crowdhuman dataset is added to the train-set during training.

 ## Getting Start
diff --git a/configs/mot/fairmot/README_cn.md b/configs/mot/fairmot/README_cn.md
index c7a5a856ae092b21ff1e982f709d2390f066fea1..e85ebcc45813f178b44fc097b0668e2954374101 100644
--- a/configs/mot/fairmot/README_cn.md
+++ b/configs/mot/fairmot/README_cn.md
@@ -40,14 +40,17 @@
 ### Results on MOT-16 Test Set
 | backbone       | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
 | :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: |
+| DLA-34         | 1088x608 | 75.9 | 74.7 | 1021 | 11425 | 31475 | - | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_dla34_60e_1088x608.pdparams) | [config](./fairmot_enhance_dla34_60e_1088x608.yml) |
 | HarDNet-85     | 1088x608 | 75.0 | 70.0 | 1050 | 11837 | 32774 | - | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot_enhance_hardnet85_30e_1088x608.yml) |

 ### Results on MOT-17 Test Set
 | backbone       | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
 | :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: |
+| DLA-34         | 1088x608 | 75.3 | 74.2 | 3270 | 29112 | 106749 | - | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_dla34_60e_1088x608.pdparams) | [config](./fairmot_enhance_dla34_60e_1088x608.yml) |
 | HarDNet-85     | 1088x608 | 74.7 | 70.7 | 3210 | 29790 | 109914 | - | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot_enhance_hardnet85_30e_1088x608.yml) |

 **Notes:**
+ FairMOT enhance DLA-34 was trained on 8 GPUs with a batch size of 16 per GPU for 60 epochs; the crowdhuman dataset was added to the training set.
  FairMOT enhance HarDNet-85 was trained on 8 GPUs with a batch size of 10 per GPU for 30 epochs; the crowdhuman dataset was added to the training set.
diff --git a/configs/mot/fairmot/_base_/fairmot_dla34.yml b/configs/mot/fairmot/_base_/fairmot_dla34.yml
index 98b7f9ea1b0f7a7ce37fdee5898c1ca524f5c70f..e9959c120fcb9ac38758971f90e800db73e7af08 100644
--- a/configs/mot/fairmot/_base_/fairmot_dla34.yml
+++ b/configs/mot/fairmot/_base_/fairmot_dla34.yml
@@ -14,8 +14,31 @@ CenterNet:
   head: CenterNetHead
   post_process: CenterNetPostProcess

+CenterNetDLAFPN:
+  down_ratio: 4
+  last_level: 5
+  out_channel: 0
+  dcn_v2: True
+  with_sge: False
+
+CenterNetHead:
+  head_planes: 256
+  heatmap_weight: 1
+  regress_ltrb: True
+  size_weight: 0.1
+  size_loss: 'L1'
+  offset_weight: 1
+  iou_weight: 0
+
+FairMOTEmbeddingHead:
+  ch_head: 256
+  ch_emb: 128
+  num_identifiers: 14455  # for mix dataset (Caltech, CityPersons, CUHK-SYSU, PRW, ETHZ and MOT16)
+
 CenterNetPostProcess:
   max_per_img: 500
+  down_ratio: 4
+  regress_ltrb: True

 JDETracker:
   conf_thres: 0.4
diff --git a/configs/mot/fairmot/_base_/fairmot_enhance_hardnet85.yml b/configs/mot/fairmot/_base_/fairmot_hardnet85.yml
similarity index 100%
rename from configs/mot/fairmot/_base_/fairmot_enhance_hardnet85.yml
rename to configs/mot/fairmot/_base_/fairmot_hardnet85.yml
diff --git a/configs/mot/fairmot/fairmot_enhance_dla34_60e_1088x608.yml b/configs/mot/fairmot/fairmot_enhance_dla34_60e_1088x608.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0d8ea6afd47859ce00fb80bfb95fdd29b5850c1f
--- /dev/null
+++ b/configs/mot/fairmot/fairmot_enhance_dla34_60e_1088x608.yml
@@ -0,0 +1,49 @@
+_BASE_: [
+  '../../datasets/mot.yml',
+  '../../runtime.yml',
+  '_base_/optimizer_30e.yml',
+  '_base_/fairmot_dla34.yml',
+  '_base_/fairmot_reader_1088x608.yml',
+]
+norm_type: sync_bn
+use_ema: true
+ema_decay: 0.9998
+
+worker_num: 4
+TrainReader:
+  inputs_def:
+    image_shape: [3, 608, 1088]
+  sample_transforms:
+    - Decode: {}
+    - RGBReverse: {}
+    - AugmentHSV: {}
+    - LetterBoxResize: {target_size: [608, 1088]}
+    - MOTRandomAffine: {reject_outside: False}
+    - RandomFlip: {}
+    - BboxXYXY2XYWH: {}
+    - NormalizeBox: {}
+    - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1]}
+    - RGBReverse: {}
+    - Permute: {}
+  batch_transforms:
+    - Gt2FairMOTTarget: {}
+  batch_size: 16
+  shuffle: True
+  drop_last: True
+  use_shared_memory: True
+
+epoch: 60
+LearningRate:
+  base_lr: 0.0005
+  schedulers:
+    - !PiecewiseDecay
+      gamma: 0.1
+      milestones: [40,]
+      use_warmup: False
+
+OptimizerBuilder:
+  optimizer:
+    type: Adam
+  regularizer: NULL
+
+weights: output/fairmot_enhance_dla34_60e_1088x608/model_final
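For reference, the `LearningRate` block in the new config amounts to a single 10x step: with `base_lr: 0.0005`, `gamma: 0.1`, and one milestone at epoch 40, epochs 0-39 run at 5e-4 and epochs 40-59 at roughly 5e-5. A minimal sketch of how that resolves, in plain Python (illustrative only; `piecewise_lr` is a hypothetical helper, not a ppdet API, and PaddleDetection applies the decay at iteration granularity internally):

```python
# Hypothetical helper mirroring the PiecewiseDecay entries in the config above.
def piecewise_lr(epoch, base_lr=0.0005, gamma=0.1, milestones=(40,)):
    lr = base_lr
    for m in milestones:
        if epoch >= m:
            lr *= gamma  # one 10x drop per crossed milestone
    return lr

print(piecewise_lr(0), piecewise_lr(39), piecewise_lr(40))
# epochs 0-39 train at the base LR (5e-4); epoch 40 onwards at ~5e-5
```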
diff --git a/configs/mot/fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml b/configs/mot/fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml
index 77936b3c39a52196c4bbb8ab0e25feab5a7c3c49..1a7228cca8230d738b4510e1f4d071161a223bd3 100644
--- a/configs/mot/fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml
+++ b/configs/mot/fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml
@@ -2,7 +2,7 @@ _BASE_: [
   '../../datasets/mot.yml',
   '../../runtime.yml',
   '_base_/optimizer_30e.yml',
-  '_base_/fairmot_enhance_hardnet85.yml',
+  '_base_/fairmot_hardnet85.yml',
   '_base_/fairmot_reader_1088x608.yml',
 ]
 norm_type: sync_bn
diff --git a/ppdet/modeling/architectures/fairmot.py b/ppdet/modeling/architectures/fairmot.py
index 712ce3d6f3779d308850581cfc50d83cf75f68ac..e49c32f8c1f59d1ca40cc609a8aeb77ffeb2cffd 100755
--- a/ppdet/modeling/architectures/fairmot.py
+++ b/ppdet/modeling/architectures/fairmot.py
@@ -70,8 +70,8 @@ class FairMOT(BaseArch):
     def _forward(self):
         loss = dict()
         # det_outs keys:
-        # train: det_loss, heatmap_loss, size_loss, offset_loss, neck_feat
-        # eval/infer: bbox, bbox_inds, neck_feat
+        # train: neck_feat, det_loss, heatmap_loss, size_loss, offset_loss (optional: iou_loss)
+        # eval/infer: neck_feat, bbox, bbox_inds
         det_outs = self.detector(self.inputs)
         neck_feat = det_outs['neck_feat']
         if self.training:
@@ -79,12 +79,11 @@ class FairMOT(BaseArch):
             det_loss = det_outs['det_loss']
             loss = self.loss(det_loss, reid_loss)
-            loss.update({
-                'heatmap_loss': det_outs['heatmap_loss'],
-                'size_loss': det_outs['size_loss'],
-                'offset_loss': det_outs['offset_loss'],
-                'reid_loss': reid_loss
-            })
+            for k, v in det_outs.items():
+                if 'loss' not in k:
+                    continue
+                loss.update({k: v})
+            loss.update({'reid_loss': reid_loss})
             return loss
         else:
             embedding = self.reid(neck_feat, self.inputs)
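The rewritten `_forward` forwards every `*loss*` entry coming out of the detector instead of naming each key, so the optional `iou_loss` introduced in `centernet_head.py` below reaches the logged loss dict without any further change to `FairMOT`. A toy sketch of that filtering pattern (plain floats standing in for loss tensors; key names taken from the comment above):

```python
# Dummy detector outputs; in training these values are scalar loss tensors.
det_outs = {'neck_feat': None, 'det_loss': 1.0, 'heatmap_loss': 0.6,
            'size_loss': 0.3, 'offset_loss': 0.1, 'iou_loss': 0.2}
loss = {k: v for k, v in det_outs.items() if 'loss' in k}  # same filter as the loop
loss['reid_loss'] = 0.5
print(sorted(loss))
# ['det_loss', 'heatmap_loss', 'iou_loss', 'offset_loss', 'reid_loss', 'size_loss']
```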
diff --git a/ppdet/modeling/heads/centernet_head.py b/ppdet/modeling/heads/centernet_head.py
index 2d65b08767e3855e71502011ce046c879075e99c..3454542e31a6e38cd031f3d38848acb8ac93991b 100755
--- a/ppdet/modeling/heads/centernet_head.py
+++ b/ppdet/modeling/heads/centernet_head.py
@@ -18,7 +18,7 @@ import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle.nn.initializer import Constant, Uniform
 from ppdet.core.workspace import register
-from ppdet.modeling.losses import CTFocalLoss
+from ppdet.modeling.losses import CTFocalLoss, GIoULoss


 class ConvLayer(nn.Layer):
@@ -51,7 +51,6 @@ class ConvLayer(nn.Layer):

     def forward(self, inputs):
         out = self.conv(inputs)
-
         return out


@@ -66,8 +65,9 @@
         regress_ltrb (bool): whether to regress left/top/right/bottom or
             width/height for a box, true by default
         size_weight (float): the weight of box size loss, 0.1 by default.
+        size_loss (str): the type of size regression loss, 'L1' by default, 'giou' optional.
        offset_weight (float): the weight of center offset loss, 1 by default.
-
+        iou_weight (float): the weight of iou head loss, 0 by default.
     """

     __shared__ = ['num_classes']
@@ -79,13 +79,18 @@
                  heatmap_weight=1,
                  regress_ltrb=True,
                  size_weight=0.1,
-                 offset_weight=1):
+                 size_loss='L1',
+                 offset_weight=1,
+                 iou_weight=0):
         super(CenterNetHead, self).__init__()
         self.weights = {
             'heatmap': heatmap_weight,
             'size': size_weight,
-            'offset': offset_weight
+            'offset': offset_weight,
+            'iou': iou_weight
         }
+
+        # heatmap head
         self.heatmap = nn.Sequential(
             ConvLayer(
                 in_channels, head_planes, kernel_size=3, padding=1, bias=True),
@@ -99,6 +104,8 @@
                 bias=True))
         with paddle.no_grad():
             self.heatmap[2].conv.bias[:] = -2.19
+
+        # size(ltrb or wh) head
         self.size = nn.Sequential(
             ConvLayer(
                 in_channels, head_planes, kernel_size=3, padding=1, bias=True),
@@ -110,13 +117,33 @@
                 stride=1,
                 padding=0,
                 bias=True))
+        self.size_loss = size_loss
+
+        # offset head
         self.offset = nn.Sequential(
             ConvLayer(
                 in_channels, head_planes, kernel_size=3, padding=1, bias=True),
             nn.ReLU(),
             ConvLayer(
                 head_planes, 2, kernel_size=1, stride=1, padding=0, bias=True))
-        self.focal_loss = CTFocalLoss()
+
+        # iou head (optional)
+        if iou_weight > 0:
+            self.iou = nn.Sequential(
+                ConvLayer(
+                    in_channels,
+                    head_planes,
+                    kernel_size=3,
+                    padding=1,
+                    bias=True),
+                nn.ReLU(),
+                ConvLayer(
+                    head_planes,
+                    4 if regress_ltrb else 2,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                    bias=True))

     @classmethod
     def from_config(cls, cfg, input_shape):
@@ -128,22 +155,29 @@
         heatmap = self.heatmap(feat)
         size = self.size(feat)
         offset = self.offset(feat)
+        iou = self.iou(feat) if hasattr(self, 'iou') else None
+
         if self.training:
-            loss = self.get_loss(heatmap, size, offset, self.weights, inputs)
+            loss = self.get_loss(
+                inputs, self.weights, heatmap, size, offset, iou=iou)
             return loss
         else:
             heatmap = F.sigmoid(heatmap)
-            return {'heatmap': heatmap, 'size': size, 'offset': offset}
+            head_outs = {'heatmap': heatmap, 'size': size, 'offset': offset}
+            if iou is not None:
+                head_outs.update({'iou': iou})
+            return head_outs

-    def get_loss(self, heatmap, size, offset, weights, inputs):
+    def get_loss(self, inputs, weights, heatmap, size, offset, iou=None):
+        # heatmap head loss: CTFocalLoss
         heatmap_target = inputs['heatmap']
-        size_target = inputs['size']
-        offset_target = inputs['offset']
-        index = inputs['index']
-        mask = inputs['index_mask']
         heatmap = paddle.clip(F.sigmoid(heatmap), 1e-4, 1 - 1e-4)
-        heatmap_loss = self.focal_loss(heatmap, heatmap_target)
+        ctfocal_loss = CTFocalLoss()
+        heatmap_loss = ctfocal_loss(heatmap, heatmap_target)

+        # size head loss: L1 loss or GIoU loss
+        index = inputs['index']
+        mask = inputs['index_mask']
         size = paddle.transpose(size, perm=[0, 2, 3, 1])
         size_n, size_h, size_w, size_c = size.shape
         size = paddle.reshape(size, shape=[size_n, -1, size_c])
@@ -161,11 +195,32 @@
         size_mask = paddle.cast(size_mask, dtype=pos_size.dtype)
         pos_num = size_mask.sum()
         size_mask.stop_gradient = True
-        size_target.stop_gradient = True
-        size_loss = F.l1_loss(
-            pos_size * size_mask, size_target * size_mask, reduction='sum')
-        size_loss = size_loss / (pos_num + 1e-4)
+        if self.size_loss == 'L1':
+            size_target = inputs['size']
+            size_target.stop_gradient = True
+            size_loss = F.l1_loss(
+                pos_size * size_mask, size_target * size_mask, reduction='sum')
+            size_loss = size_loss / (pos_num + 1e-4)
+        elif self.size_loss == 'giou':
+            size_target = inputs['bbox_xys']
+            size_target.stop_gradient = True
+            centers_x = (size_target[:, :, 0:1] + size_target[:, :, 2:3]) / 2.0
+            centers_y = (size_target[:, :, 1:2] + size_target[:, :, 3:4]) / 2.0
+            x1 = centers_x - pos_size[:, :, 0:1]
+            y1 = centers_y - pos_size[:, :, 1:2]
+            x2 = centers_x + pos_size[:, :, 2:3]
+            y2 = centers_y + pos_size[:, :, 3:4]
+            pred_boxes = paddle.concat([x1, y1, x2, y2], axis=-1)
+            giou_loss = GIoULoss(reduction='sum')
+            size_loss = giou_loss(
+                pred_boxes * size_mask,
+                size_target * size_mask,
+                iou_weight=size_mask,
+                loc_reweight=None)
+            size_loss = size_loss / (pos_num + 1e-4)

+        # offset head loss: L1 loss
+        offset_target = inputs['offset']
         offset = paddle.transpose(offset, perm=[0, 2, 3, 1])
         offset_n, offset_h, offset_w, offset_c = offset.shape
         offset = paddle.reshape(offset, shape=[offset_n, -1, offset_c])
@@ -181,12 +236,43 @@
             reduction='sum')
         offset_loss = offset_loss / (pos_num + 1e-4)

-        det_loss = weights['heatmap'] * heatmap_loss + weights[
-            'size'] * size_loss + weights['offset'] * offset_loss
+        # iou head loss: GIoU loss
+        if iou is not None:
+            iou = paddle.transpose(iou, perm=[0, 2, 3, 1])
+            iou_n, iou_h, iou_w, iou_c = iou.shape
+            iou = paddle.reshape(iou, shape=[iou_n, -1, iou_c])
+            pos_iou = paddle.gather_nd(iou, index=index)
+            iou_mask = paddle.expand_as(mask, pos_iou)
+            iou_mask = paddle.cast(iou_mask, dtype=pos_iou.dtype)
+            pos_num = iou_mask.sum()
+            iou_mask.stop_gradient = True
+            gt_bbox_xys = inputs['bbox_xys']
+            gt_bbox_xys.stop_gradient = True
+            centers_x = (gt_bbox_xys[:, :, 0:1] + gt_bbox_xys[:, :, 2:3]) / 2.0
+            centers_y = (gt_bbox_xys[:, :, 1:2] + gt_bbox_xys[:, :, 3:4]) / 2.0
+            x1 = centers_x - pos_size[:, :, 0:1]
+            y1 = centers_y - pos_size[:, :, 1:2]
+            x2 = centers_x + pos_size[:, :, 2:3]
+            y2 = centers_y + pos_size[:, :, 3:4]
+            pred_boxes = paddle.concat([x1, y1, x2, y2], axis=-1)
+            giou_loss = GIoULoss(reduction='sum')
+            iou_loss = giou_loss(
+                pred_boxes * iou_mask,
+                gt_bbox_xys * iou_mask,
+                iou_weight=iou_mask,
+                loc_reweight=None)
+            iou_loss = iou_loss / (pos_num + 1e-4)

-        return {
-            'det_loss': det_loss,
+        losses = {
             'heatmap_loss': heatmap_loss,
             'size_loss': size_loss,
-            'offset_loss': offset_loss
+            'offset_loss': offset_loss,
         }
+        det_loss = weights['heatmap'] * heatmap_loss + weights[
+            'size'] * size_loss + weights['offset'] * offset_loss
+
+        if iou is not None:
+            losses.update({'iou_loss': iou_loss})
+            det_loss = det_loss + weights['iou'] * iou_loss
+
+        losses.update({'det_loss': det_loss})
+        return losses
diff --git a/ppdet/modeling/necks/centernet_fpn.py b/ppdet/modeling/necks/centernet_fpn.py
index fd40cf6e33152faa39713fb366f9f1adf4895c65..1ca1a4b58f165594f14264402579467da74c93eb 100755
--- a/ppdet/modeling/necks/centernet_fpn.py
+++ b/ppdet/modeling/necks/centernet_fpn.py
@@ -27,6 +27,74 @@ from ..shape_spec import ShapeSpec

 __all__ = ['CenterNetDLAFPN', 'CenterNetHarDNetFPN']

+# SGE attention
+class BasicConv(nn.Layer):
+    def __init__(self,
+                 in_planes,
+                 out_planes,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 relu=True,
+                 bn=True,
+                 bias_attr=False):
+        super(BasicConv, self).__init__()
+        self.out_channels = out_planes
+        self.conv = nn.Conv2D(
+            in_planes,
+            out_planes,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias_attr=bias_attr)
+        self.bn = nn.BatchNorm2D(
+            out_planes,
+            epsilon=1e-5,
+            momentum=0.01,
+            weight_attr=False,
+            bias_attr=False) if bn else None
+        self.relu = nn.ReLU() if relu else None
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.bn is not None:
+            x = self.bn(x)
+        if self.relu is not None:
+            x = self.relu(x)
+        return x
+
+
+class ChannelPool(nn.Layer):
+    def forward(self, x):
+        return paddle.concat(
+            (paddle.max(x, 1).unsqueeze(1), paddle.mean(x, 1).unsqueeze(1)),
+            axis=1)
+
+
+class SpatialGate(nn.Layer):
+    def __init__(self):
+        super(SpatialGate, self).__init__()
+        kernel_size = 7
+        self.compress = ChannelPool()
+        self.spatial = BasicConv(
+            2,
+            1,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+            relu=False)
+
+    def forward(self, x):
+        x_compress = self.compress(x)
+        x_out = self.spatial(x_compress)
+        scale = F.sigmoid(x_out)  # broadcasting
+        return x * scale
+
+
 def fill_up_weights(up):
     weight = up.weight.numpy()
     f = math.ceil(weight.shape[2] / 2)
@@ -145,10 +213,10 @@
         last_level (int): the last level of input feature fed into the upsamplng block
         out_channel (int): the channel of the output feature, 0 by default means
             the channel of the input feature whose down ratio is `down_ratio`
-        dcn_v2 (bool): whether use the DCNv2, true by default
-        first_level (int|None): the first level of input feature fed into the upsamplng block.
+        first_level (int|None): the first level of input feature fed into the upsampling block.
            if None, the first level stands for logs(down_ratio)
-
+        dcn_v2 (bool): whether to use DCNv2, True by default
+        with_sge (bool): whether to use SGE attention, False by default
     """

     def __init__(self,
@@ -156,8 +224,9 @@
                  down_ratio=4,
                  last_level=5,
                  out_channel=0,
+                 first_level=None,
                  dcn_v2=True,
-                 first_level=None):
+                 with_sge=False):
         super(CenterNetDLAFPN, self).__init__()
         self.first_level = int(np.log2(
             down_ratio)) if first_level is None else first_level
@@ -180,6 +249,10 @@
             [2**i for i in range(self.last_level - self.first_level)],
             dcn_v2=dcn_v2)

+        self.with_sge = with_sge
+        if self.with_sge:
+            self.sge_attention = SpatialGate()
+
     @classmethod
     def from_config(cls, cfg, input_shape):
         return {'in_channels': [i.channels for i in input_shape]}
@@ -194,7 +267,10 @@

         self.ida_up(ida_up_feats, 0, len(ida_up_feats))

-        return ida_up_feats[-1]
+        feat = ida_up_feats[-1]
+        if self.with_sge:
+            feat = self.sge_attention(feat)
+        return feat

     @property
     def out_shape(self):
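Both GIoU branches in `get_loss` above decode the head's ltrb regression back into corner-form boxes before applying `GIoULoss`: the ground-truth `bbox_xys` supply the box centers, and the gathered predictions supply the left/top/right/bottom distances. A standalone numeric sketch of just that decoding, with dummy values (the real code additionally multiplies by `size_mask`/`iou_mask` and normalizes by the positive count):

```python
import paddle

gt_xyxy = paddle.to_tensor([[[10., 20., 50., 80.]]])    # [bs, max_objs, 4] ground truth
pred_ltrb = paddle.to_tensor([[[19., 31., 21., 29.]]])  # predicted l, t, r, b distances

cx = (gt_xyxy[:, :, 0:1] + gt_xyxy[:, :, 2:3]) / 2.0    # center x = 30.
cy = (gt_xyxy[:, :, 1:2] + gt_xyxy[:, :, 3:4]) / 2.0    # center y = 50.
pred_xyxy = paddle.concat(
    [cx - pred_ltrb[:, :, 0:1],   # x1 = 11.
     cy - pred_ltrb[:, :, 1:2],   # y1 = 19.
     cx + pred_ltrb[:, :, 2:3],   # x2 = 51.
     cy + pred_ltrb[:, :, 3:4]],  # y2 = 79.
    axis=-1)
# pred_xyxy is now directly comparable with gt_xyxy under an IoU-based loss
# such as ppdet's GIoULoss.
```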
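The `SpatialGate` added above (enabled per-model with `with_sge: True` on `CenterNetDLAFPN`, and off by default in `fairmot_dla34.yml`) compresses the feature map to two channels via channel-wise max/mean pooling, runs a 7x7 conv, and rescales the input with the resulting sigmoid map; despite the `# SGE attention` comment, the structure resembles a CBAM-style spatial gate. A shape-level usage sketch (assumes the patched module is importable; `SpatialGate` is internal and not listed in `__all__`):

```python
import paddle
from ppdet.modeling.necks.centernet_fpn import SpatialGate  # class added in this diff

gate = SpatialGate()
feat = paddle.rand([2, 64, 152, 272])  # e.g. the stride-4 feature for a 1088x608 input
out = gate(feat)                       # ChannelPool -> 7x7 conv -> sigmoid scale
assert out.shape == feat.shape         # the gate rescales values, never reshapes
```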