diff --git a/configs/datasets/mot.yml b/configs/datasets/mot.yml
index b5966f025e1465b64f6a9b746ba00d6c7f350ba7..c716d1c23e7fe76979c427f58c22d58f667048ae 100644
--- a/configs/datasets/mot.yml
+++ b/configs/datasets/mot.yml
@@ -12,7 +12,7 @@ MOTDataZoo: {
 TrainDataset:
   !MOTDataSet
     dataset_dir: dataset/mot
-    image_lists: ['mot17.train', 'caltech.train', 'cuhksysu.train', 'prw.train', 'citypersons.train', 'eth.train']
+    image_lists: ['mot17.train', 'caltech.all', 'cuhksysu.train', 'prw.train', 'citypersons.train', 'eth.train']
     data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
 
 EvalDataset:
diff --git a/configs/mot/fairmot/README.md b/configs/mot/fairmot/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1ddea059353d0883ec31e05dfae312a1d385ec74
--- /dev/null
+++ b/configs/mot/fairmot/README.md
@@ -0,0 +1,66 @@
+# FairMOT (FairMOT: On the Fairness of Detection and Re-Identification in Multiple Object Tracking)
+
+## Table of Contents
+- [Introduction](#Introduction)
+- [Model Zoo](#Model_Zoo)
+- [Getting Started](#Getting_Started)
+- [Citations](#Citations)
+
+## Introduction
+
+FairMOT performs detection and re-identification in a single network to improve inference speed. It presents a simple baseline consisting of two homogeneous branches that predict pixel-wise objectness scores and re-ID features. The fairness achieved between the two tasks allows FairMOT to reach high levels of both detection and tracking accuracy.
+
+
+## Model Zoo
+
+### Results on MOT-16 train set
+
+| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | download | config |
+| :-----------------| :------- | :----: | :----: | :---: | :----: | :---: |:---: | :---: |
+| DLA-34(paper) | 1088x608 | 83.3 | 81.9 | 544 | 3822 | 14095 | ---- | ---- |
+| DLA-34 | 1088x608 | 83.4 | 82.7 | 517 | 4077 | 13761 | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml) |
+
+
+### Results on MOT-16 test set
+
+| backbone | input shape | MOTA | IDF1 | IDS | MT | ML | download | config |
+| :-----------------| :------- | :----: | :----: | :---: | :----: | :---: | :---: | :---: |
+| DLA-34(paper) | 1088x608 | 74.9 | 72.8 | 1074 | 44.7% | 15.9% | ---- | ---- |
+| DLA-34 | 1088x608 | 74.7 | 72.8 | 1044 | 41.9% | 19.1% | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml) |
+
+**Notes:**
+
+FairMOT was trained for 30 epochs on 2 GPUs with a mini-batch size of 6 per GPU.
+
+## Getting Started
+
+### 1. Training
+
+Train FairMOT on 2 GPUs with the following command:
+
+```bash
+python -m paddle.distributed.launch --log_dir=./fairmot_dla34_30e_1088x608/ --gpus 0,1 tools/train.py -c configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml &>fairmot_dla34_30e_1088x608.log 2>&1 &
+```
+
+
+### 2. Evaluation
+
+Evaluate the tracking performance of FairMOT on the val dataset on a single GPU with the following commands:
+
+```bash
+# use weights released in PaddleDetection model zoo
+CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608.pdparams
+
+# use saved checkpoint in training
+CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml -o weights=output/fairmot_dla34_30e_1088x608/model_final
+```
+
+## Citations
+```
+@article{zhang2020fair,
+  title={FairMOT: On the Fairness of Detection and Re-Identification in Multiple Object Tracking},
+  author={Zhang, Yifu and Wang, Chunyu and Wang, Xinggang and Zeng, Wenjun and Liu, Wenyu},
+  journal={arXiv preprint arXiv:2004.01888},
+  year={2020}
+}
+```
diff --git a/configs/mot/fairmot/_base_/fairmot_dla34.yml b/configs/mot/fairmot/_base_/fairmot_dla34.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c5f07de702fbeb594c9eeda60d709c0c40af8b1b
--- /dev/null
+++ b/configs/mot/fairmot/_base_/fairmot_dla34.yml
@@ -0,0 +1,23 @@
+architecture: FairMOT
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/fairmot_dla34_crowdhuman_pretrained.pdparams
+
+FairMOT:
+  detector: CenterNet
+  reid: FairMOTEmbeddingHead
+  loss: FairMOTLoss
+  tracker: JDETracker
+
+CenterNet:
+  backbone: DLA
+  neck: CenterNetDLAFPN
+  head: CenterNetHead
+  post_process: CenterNetPostProcess
+  for_mot: True
+
+CenterNetPostProcess:
+  for_mot: True
+
+JDETracker:
+  conf_thres: 0.4
+  tracked_thresh: 0.4
+  metric_type: cosine
diff --git a/configs/mot/fairmot/_base_/fairmot_reader_1088x608.yml b/configs/mot/fairmot/_base_/fairmot_reader_1088x608.yml
new file mode 100644
index 0000000000000000000000000000000000000000..6dadac3a32f8fe46a06d73374ce147e553a21db2
--- /dev/null
+++ b/configs/mot/fairmot/_base_/fairmot_reader_1088x608.yml
@@ -0,0 +1,30 @@
+worker_num: 4
+TrainReader:
+  inputs_def:
+    image_shape: [3, 608, 1088]
+  sample_transforms:
+    - Decode: {to_rgb: False}
+    - AugmentHSV: {is_bgr: True}
+    - LetterBoxResize: {target_size: [608, 1088]}
+    - MOTRandomAffine: {reject_outside: False}
+    - RandomFlip: {}
+    - BboxXYXY2XYWH: {}
+    - NormalizeBox: {}
+    - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1]}
+    - Permute: {to_rgb: True}
+  batch_transforms:
+    - Gt2FairMOTTarget: {}
+  batch_size: 6
+  shuffle: True
+  drop_last: True
+
+
+EvalMOTReader:
+  inputs_def:
+    image_shape: [3, 608, 1088]
+  sample_transforms:
+    - Decode: {to_rgb: False}
+    - LetterBoxResize: {target_size: [608, 1088]}
+    - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1]}
+    - Permute: {to_rgb: True}
+  batch_size: 1
diff --git a/configs/mot/fairmot/_base_/optimizer_30e.yml b/configs/mot/fairmot/_base_/optimizer_30e.yml
new file mode 100644
index 0000000000000000000000000000000000000000..6e7ec0dc45e9180cf0e632bd19d0de66d619ec7d
--- /dev/null
+++ b/configs/mot/fairmot/_base_/optimizer_30e.yml
@@ -0,0 +1,14 @@
+epoch: 30
+
+LearningRate:
+  base_lr: 0.0001
+  schedulers:
+    - !PiecewiseDecay
+      gamma: 0.1
+      milestones: [20,]
+      use_warmup: False
+
+OptimizerBuilder:
+  optimizer:
+    type: Adam
+  regularizer: NULL
diff --git a/configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml b/configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8c1e708c01a9b0b0b32d16d9605095e04a5ff176
--- /dev/null
+++ b/configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml
@@ -0,0 +1,10 @@
+_BASE_: 
[ + '../../datasets/mot.yml', + '../../runtime.yml', + '_base_/optimizer_30e.yml', + '_base_/fairmot_dla34.yml', + '_base_/fairmot_reader_1088x608.yml', +] + +metric: MOT +weights: output/fairmot_dla34_30e_1088x608/model_final diff --git a/ppdet/data/source/mot.py b/ppdet/data/source/mot.py index ba7e9d20552e89199bbfde536928b43afaab751a..a8a11dc91d7f8d974d71b3207fd8686c03839625 100644 --- a/ppdet/data/source/mot.py +++ b/ppdet/data/source/mot.py @@ -347,6 +347,7 @@ class MOTVideoDataset(DetDataset): def _load_video_images(self): self.cap = cv2.VideoCapture(self.video_file) self.vn = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) + self.frame_rate = int(round(self.cap.get(cv2.CAP_PROP_FPS))) logger.info('Length of the video: {:d} frames'.format(self.vn)) res = True ct = 0 diff --git a/ppdet/data/transform/mot_operators.py b/ppdet/data/transform/mot_operators.py index 027804b1695b17054b7da37cb67e400d76df72da..bcd5c430473b7ff46df0c8ee72b76330135e5a17 100644 --- a/ppdet/data/transform/mot_operators.py +++ b/ppdet/data/transform/mot_operators.py @@ -25,14 +25,19 @@ from numbers import Integral import cv2 import copy import numpy as np +import math from .operators import BaseOperator, register_op +from .batch_operators import Gt2TTFTarget from ppdet.modeling.bbox_utils import bbox_iou_np_expand from ppdet.core.workspace import serializable from ppdet.utils.logger import setup_logger logger = setup_logger(__name__) -__all__ = ['LetterBoxResize', 'Gt2JDETargetThres', 'Gt2JDETargetMax'] +__all__ = [ + 'LetterBoxResize', 'Gt2JDETargetThres', 'Gt2JDETargetMax', + 'Gt2FairMOTTarget' +] @register_op @@ -367,3 +372,117 @@ class Gt2JDETargetMax(BaseOperator): sample['tbox{}'.format(i)] = tbox sample['tconf{}'.format(i)] = tconf sample['tide{}'.format(i)] = tid + + +class Gt2FairMOTTarget(Gt2TTFTarget): + __shared__ = ['num_classes'] + """ + Generate FairMOT targets by ground truth data. + Difference between Gt2FairMOTTarget and Gt2TTFTarget are: + 1. the gaussian kernal radius to generate a heatmap. + 2. the targets needed during traing. + + Args: + num_classes(int): the number of classes. + down_ratio(int): the down ratio from images to heatmap, 4 by default. + max_objs(int): the maximum number of ground truth objects in a image, 500 by default. + """ + + def __init__(self, num_classes=1, down_ratio=4, max_objs=500): + super(Gt2TTFTarget, self).__init__() + self.down_ratio = down_ratio + self.num_classes = num_classes + self.max_objs = max_objs + + def __call__(self, samples, context=None): + for b_id, sample in enumerate(samples): + output_h = sample['image'].shape[1] // self.down_ratio + output_w = sample['image'].shape[2] // self.down_ratio + + heatmap = np.zeros( + (self.num_classes, output_h, output_w), dtype='float32') + bbox_size = np.zeros((self.max_objs, 4), dtype=np.float32) + center_offset = np.zeros((self.max_objs, 2), dtype=np.float32) + index = np.zeros((self.max_objs, ), dtype=np.int64) + index_mask = np.zeros((self.max_objs, ), dtype=np.int32) + reid = np.zeros((self.max_objs, ), dtype=np.int64) + bbox_xys = np.zeros((self.max_objs, 4), dtype=np.float32) + + gt_bbox = sample['gt_bbox'] + gt_class = sample['gt_class'] + gt_ide = sample['gt_ide'] + + for k in range(len(gt_bbox)): + cls_id = gt_class[k][0] + bbox = gt_bbox[k] + ide = gt_ide[k][0] + bbox[[0, 2]] = bbox[[0, 2]] * output_w + bbox[[1, 3]] = bbox[[1, 3]] * output_h + bbox_amodal = copy.deepcopy(bbox) + bbox_amodal[0] = bbox_amodal[0] - bbox_amodal[2] / 2. + bbox_amodal[1] = bbox_amodal[1] - bbox_amodal[3] / 2. 
+ bbox_amodal[2] = bbox_amodal[0] + bbox_amodal[2] + bbox_amodal[3] = bbox_amodal[1] + bbox_amodal[3] + bbox[0] = np.clip(bbox[0], 0, output_w - 1) + bbox[1] = np.clip(bbox[1], 0, output_h - 1) + h = bbox[3] + w = bbox[2] + + bbox_xy = copy.deepcopy(bbox) + bbox_xy[0] = bbox_xy[0] - bbox_xy[2] / 2 + bbox_xy[1] = bbox_xy[1] - bbox_xy[3] / 2 + bbox_xy[2] = bbox_xy[0] + bbox_xy[2] + bbox_xy[3] = bbox_xy[1] + bbox_xy[3] + + if h > 0 and w > 0: + radius = self.gaussian_radius((math.ceil(h), math.ceil(w))) + radius = max(0, int(radius)) + ct = np.array([bbox[0], bbox[1]], dtype=np.float32) + ct_int = ct.astype(np.int32) + self.draw_truncate_gaussian(heatmap[cls_id], ct_int, radius, + radius) + bbox_size[k] = ct[0] - bbox_amodal[0], ct[1] - bbox_amodal[1], \ + bbox_amodal[2] - ct[0], bbox_amodal[3] - ct[1] + + index[k] = ct_int[1] * output_w + ct_int[0] + center_offset[k] = ct - ct_int + index_mask[k] = 1 + reid[k] = ide + bbox_xys[k] = bbox_xy + + sample['heatmap'] = heatmap + sample['index'] = index + sample['offset'] = center_offset + sample['size'] = bbox_size + sample['index_mask'] = index_mask + sample['reid'] = reid + sample['bbox_xys'] = bbox_xys + sample.pop('is_crowd', None) + sample.pop('difficult', None) + sample.pop('gt_class', None) + sample.pop('gt_bbox', None) + sample.pop('gt_score', None) + sample.pop('gt_ide', None) + return samples + + def gaussian_radius(self, det_size, min_overlap=0.7): + height, width = det_size + + a1 = 1 + b1 = (height + width) + c1 = width * height * (1 - min_overlap) / (1 + min_overlap) + sq1 = np.sqrt(b1**2 - 4 * a1 * c1) + r1 = (b1 + sq1) / 2 + + a2 = 4 + b2 = 2 * (height + width) + c2 = (1 - min_overlap) * width * height + sq2 = np.sqrt(b2**2 - 4 * a2 * c2) + r2 = (b2 + sq2) / 2 + + a3 = 4 * min_overlap + b3 = -2 * min_overlap * (height + width) + c3 = (min_overlap - 1) * width * height + sq3 = np.sqrt(b3**2 - 4 * a3 * c3) + r3 = (b3 + sq3) / 2 + return min(r1, r2, r3) diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py index cc6218857884a5d64957eca079898787a2f8b210..d370ce083b4f415da84a80621dad996b44ac0617 100644 --- a/ppdet/data/transform/operators.py +++ b/ppdet/data/transform/operators.py @@ -107,10 +107,12 @@ class BaseOperator(object): @register_op class Decode(BaseOperator): - def __init__(self): + def __init__(self, to_rgb=True): """ Transform the image data to numpy format following the rgb format """ super(Decode, self).__init__() + # TODO: remove this parameter + self.to_rgb = to_rgb def apply(self, sample, context=None): """ load image if 'im_file' field is not empty but 'image' is""" @@ -124,7 +126,8 @@ class Decode(BaseOperator): im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode if 'keep_ori_im' in sample and sample['keep_ori_im']: sample['ori_image'] = im - im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + if self.to_rgb: + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) sample['image'] = im if 'h' not in sample: @@ -151,14 +154,18 @@ class Decode(BaseOperator): @register_op class Permute(BaseOperator): - def __init__(self): + def __init__(self, to_rgb=False): """ Change the channel to be (C, H, W) """ super(Permute, self).__init__() + # TODO: remove this parameter + self.to_rgb = to_rgb def apply(self, sample, context=None): im = sample['image'] + if self.to_rgb: + im = np.ascontiguousarray(im[:, :, ::-1]) im = im.transpose((2, 0, 1)) sample['image'] = im return sample @@ -2076,29 +2083,39 @@ class Norm2PixelBbox(BaseOperator): @register_op class MOTRandomAffine(BaseOperator): + """ + Affine transform to image 
and coords to achieve the rotate, scale and + shift effect for training image. + + Args: + degrees (list[2]): the rotate range to apply, transform range is [min, max] + translate (list[2]): the translate range to apply, ransform range is [min, max] + scale (list[2]): the scale range to apply, transform range is [min, max] + shear (list[2]): the shear range to apply, transform range is [min, max] + borderValue (list[3]): value used in case of a constant border when appling + the perspective transformation + reject_outside (bool): reject warped bounding bboxes outside of image + + Returns: + records(dict): contain the image and coords after tranformed + + """ + def __init__(self, degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.50, 1.20), shear=(-2, 2), - borderValue=(127.5, 127.5, 127.5)): - """ - Affine transform to image and coords to achieve the rotate, scale and - shift effect for training image. + borderValue=(127.5, 127.5, 127.5), + reject_outside=True): - Args: - degrees (tuple): rotation value - translate (tuple): xy coords translation value - scale (tuple): scale value - shear (tuple): shear value - borderValue (tuple): border color value - """ super(MOTRandomAffine, self).__init__() self.degrees = degrees self.translate = translate self.scale = scale self.shear = shear self.borderValue = borderValue + self.reject_outside = reject_outside def apply(self, sample, context=None): # https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4 @@ -2171,10 +2188,11 @@ class MOTRandomAffine(BaseOperator): (x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T # reject warped points outside of image - np.clip(xy[:, 0], 0, width, out=xy[:, 0]) - np.clip(xy[:, 2], 0, width, out=xy[:, 2]) - np.clip(xy[:, 1], 0, height, out=xy[:, 1]) - np.clip(xy[:, 3], 0, height, out=xy[:, 3]) + if self.reject_outside: + np.clip(xy[:, 0], 0, width, out=xy[:, 0]) + np.clip(xy[:, 2], 0, width, out=xy[:, 2]) + np.clip(xy[:, 1], 0, height, out=xy[:, 1]) + np.clip(xy[:, 3], 0, height, out=xy[:, 3]) w = xy[:, 2] - xy[:, 0] h = xy[:, 3] - xy[:, 1] area = w * h diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 926a909c70b58c7c59f8c239fc6eeb48e42eae77..bc90f82c4374363060dd3aa4ef979e7cb9aad2de 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -66,6 +66,10 @@ class Trainer(object): cfg['JDEEmbeddingHead'][ 'num_identifiers'] = self.dataset.total_identities + if cfg.architecture == 'FairMOT' and self.mode == 'train': + cfg['FairMOTEmbeddingHead'][ + 'num_identifiers'] = self.dataset.total_identities + # build model if 'model' not in self.cfg: self.model = create(cfg.architecture) @@ -223,7 +227,10 @@ class Trainer(object): return self.start_epoch = 0 if hasattr(self.model, 'detector'): - load_pretrain_weight(self.model.detector, weights) + if self.model.__class__.__name__ == 'FairMOT': + load_pretrain_weight(self.model, weights) + else: + load_pretrain_weight(self.model.detector, weights) else: load_pretrain_weight(self.model, weights) logger.debug("Load weights {} to start training".format(weights)) diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py index 50566f7c92960180153c724685d3d19367b9a892..6b2a33ae57801caa618b4e4ec22d9872d0fe8d7a 100644 --- a/ppdet/modeling/architectures/__init__.py +++ b/ppdet/modeling/architectures/__init__.py @@ -19,6 +19,8 @@ from . import keypoint_hrhrnet from . import keypoint_hrnet from . import jde from . import deepsort +from . import fairmot +from . 
import centernet from .meta_arch import * from .faster_rcnn import * @@ -34,3 +36,5 @@ from .keypoint_hrhrnet import * from .keypoint_hrnet import * from .jde import * from .deepsort import * +from .fairmot import * +from .centernet import * diff --git a/ppdet/modeling/architectures/centernet.py b/ppdet/modeling/architectures/centernet.py new file mode 100755 index 0000000000000000000000000000000000000000..719e4c20b640d3b70f42fcff96ea8ede697813aa --- /dev/null +++ b/ppdet/modeling/architectures/centernet.py @@ -0,0 +1,103 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from ppdet.core.workspace import register, create +from .meta_arch import BaseArch + +__all__ = ['CenterNet'] + + +@register +class CenterNet(BaseArch): + """ + CenterNet network, see http://arxiv.org/abs/1904.07850 + + Args: + backbone (object): backbone instance + neck (object): 'CenterDLAFPN' instance + head (object): 'CenterHead' instance + post_process (object): 'CenterNetPostProcess' instance + for_mot (bool): whether return other features used in tracking model + + """ + __category__ = 'architecture' + __inject__ = ['post_process'] + + def __init__(self, + backbone='DLA', + neck='CenterDLAFPN', + head='CenterHead', + post_process='CenterNetPostProcess', + for_mot=False): + super(CenterNet, self).__init__() + self.backbone = backbone + self.neck = neck + self.head = head + self.post_process = post_process + self.for_mot = for_mot + + @classmethod + def from_config(cls, cfg, *args, **kwargs): + backbone = create(cfg['backbone']) + + kwargs = {'input_shape': backbone.out_shape} + neck = create(cfg['neck'], **kwargs) + + kwargs = {'input_shape': neck.out_shape} + head = create(cfg['head'], **kwargs) + + return {'backbone': backbone, 'neck': neck, "head": head} + + def _forward(self): + body_feats = self.backbone(self.inputs) + neck_feat = self.neck(body_feats) + head_out = self.head(neck_feat, self.inputs) + if self.for_mot: + head_out.update({'neck_feat': neck_feat}) + return head_out + + def get_pred(self): + head_out = self._forward() + if self.for_mot: + bbox, bbox_inds = self.post_process( + head_out['heatmap'], + head_out['size'], + head_out['offset'], + im_shape=self.inputs['im_shape'], + scale_factor=self.inputs['scale_factor']) + output = { + "bbox": bbox, + "bbox_inds": bbox_inds, + "neck_feat": head_out['neck_feat'] + } + else: + bbox, bbox_num = self.post_process( + head_out['heatmap'], + head_out['size'], + head_out['offset'], + im_shape=self.inputs['im_shape'], + scale_factor=self.inputs['scale_factor']) + output = { + "bbox": bbox, + "bbox_num": bbox_num, + } + return output + + def get_loss(self): + return self._forward() diff --git a/ppdet/modeling/architectures/fairmot.py b/ppdet/modeling/architectures/fairmot.py new file mode 100755 index 
0000000000000000000000000000000000000000..1a29e3f59bf5003bbf8f28053797262361fc8323 --- /dev/null +++ b/ppdet/modeling/architectures/fairmot.py @@ -0,0 +1,107 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from ppdet.core.workspace import register, create +from .meta_arch import BaseArch + +__all__ = ['FairMOT'] + + +@register +class FairMOT(BaseArch): + """ + FairMOT network, see http://arxiv.org/abs/2004.01888 + + Args: + detector (object): 'CenterNet' instance + reid (object): 'FairMOTEmbeddingHead' instance + tracker (object): 'JDETracker' instance + loss (object): 'FairMOTLoss' instance + + """ + + __category__ = 'architecture' + __inject__ = ['loss'] + + def __init__(self, + detector='CenterNet', + reid='FairMOTEmbeddingHead', + tracker='JDETracker', + loss='FairMOTLoss'): + super(FairMOT, self).__init__() + self.detector = detector + self.reid = reid + self.tracker = tracker + self.loss = loss + + @classmethod + def from_config(cls, cfg, *args, **kwargs): + detector = create(cfg['detector']) + + kwargs = {'input_shape': detector.neck.out_shape} + reid = create(cfg['reid'], **kwargs) + loss = create(cfg['loss']) + tracker = create(cfg['tracker']) + + return { + 'detector': detector, + 'reid': reid, + 'loss': loss, + 'tracker': tracker + } + + def _forward(self): + loss = dict() + # det_outs keys: + # train: det_loss, heatmap_loss, size_loss, offset_loss, neck_feat + # eval/infer: bbox, bbox_inds, neck_feat + det_outs = self.detector(self.inputs) + neck_feat = det_outs['neck_feat'] + if self.training: + reid_loss = self.reid(neck_feat, self.inputs) + + det_loss = det_outs['det_loss'] + loss = self.loss(det_loss, reid_loss) + loss.update({ + 'heatmap_loss': det_outs['heatmap_loss'], + 'size_loss': det_outs['size_loss'], + 'offset_loss': det_outs['offset_loss'], + 'reid_loss': reid_loss + }) + return loss + else: + embedding = self.reid(neck_feat, self.inputs) + bbox_inds = det_outs['bbox_inds'] + embedding = paddle.transpose(embedding, [0, 2, 3, 1]) + embedding = paddle.reshape(embedding, + [-1, paddle.shape(embedding)[-1]]) + id_feature = paddle.gather(embedding, bbox_inds) + dets = det_outs['bbox'] + id_feature = id_feature + # Note: the tracker only considers batch_size=1 and num_classses=1 + online_targets = self.tracker.update(dets, id_feature) + return online_targets + + def get_pred(self): + output = self._forward() + return output + + def get_loss(self): + loss = self._forward() + return loss diff --git a/ppdet/modeling/backbones/__init__.py b/ppdet/modeling/backbones/__init__.py index c1eb8e85dd53c9aee8319f5ea7f457a24d55ea3c..6d66690f2faf726fde5244373a7826cf09bbf5f7 100644 --- a/ppdet/modeling/backbones/__init__.py +++ b/ppdet/modeling/backbones/__init__.py @@ -22,6 +22,7 @@ from . import blazenet from . import ghostnet from . import senet from . import res2net +from . 
import dla from .vgg import * from .resnet import * @@ -33,3 +34,4 @@ from .blazenet import * from .ghostnet import * from .senet import * from .res2net import * +from .dla import * diff --git a/ppdet/modeling/backbones/dla.py b/ppdet/modeling/backbones/dla.py new file mode 100755 index 0000000000000000000000000000000000000000..4ab06ab7f763a55232b7cc182e2e9df89e99bb88 --- /dev/null +++ b/ppdet/modeling/backbones/dla.py @@ -0,0 +1,243 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from ppdet.core.workspace import register, serializable +from ppdet.modeling.layers import ConvNormLayer +from ..shape_spec import ShapeSpec + +DLA_cfg = {34: ([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512])} + + +class BasicBlock(nn.Layer): + def __init__(self, ch_in, ch_out, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = ConvNormLayer( + ch_in, + ch_out, + filter_size=3, + stride=stride, + bias_on=False, + norm_decay=None) + self.conv2 = ConvNormLayer( + ch_out, + ch_out, + filter_size=3, + stride=1, + bias_on=False, + norm_decay=None) + + def forward(self, inputs, residual=None): + if residual is None: + residual = inputs + + out = self.conv1(inputs) + out = F.relu(out) + + out = self.conv2(out) + + out = paddle.add(x=out, y=residual) + out = F.relu(out) + + return out + + +class Root(nn.Layer): + def __init__(self, ch_in, ch_out, kernel_size, residual): + super(Root, self).__init__() + self.conv = ConvNormLayer( + ch_in, + ch_out, + filter_size=1, + stride=1, + bias_on=False, + norm_decay=None) + self.residual = residual + + def forward(self, inputs): + children = inputs + out = self.conv(paddle.concat(inputs, axis=1)) + if self.residual: + out = paddle.add(x=out, y=children[0]) + out = F.relu(out) + + return out + + +class Tree(nn.Layer): + def __init__(self, + level, + block, + ch_in, + ch_out, + stride=1, + level_root=False, + root_dim=0, + root_kernel_size=1, + root_residual=False): + super(Tree, self).__init__() + if root_dim == 0: + root_dim = 2 * ch_out + if level_root: + root_dim += ch_in + if level == 1: + self.tree1 = block(ch_in, ch_out, stride) + self.tree2 = block(ch_out, ch_out, 1) + else: + self.tree1 = Tree( + level - 1, + block, + ch_in, + ch_out, + stride, + root_dim=0, + root_kernel_size=root_kernel_size, + root_residual=root_residual) + self.tree2 = Tree( + level - 1, + block, + ch_out, + ch_out, + 1, + root_dim=root_dim + ch_out, + root_kernel_size=root_kernel_size, + root_residual=root_residual) + + if level == 1: + self.root = Root(root_dim, ch_out, root_kernel_size, root_residual) + self.level_root = level_root + self.root_dim = root_dim + self.downsample = None + self.project = None + self.level = level + if stride > 1: + self.downsample = nn.MaxPool2D(stride, stride=stride) + if ch_in != ch_out: + self.project = ConvNormLayer( + ch_in, + ch_out, + filter_size=1, + stride=1, + bias_on=False, + norm_decay=None) + + def 
forward(self, x, residual=None, children=None): + children = [] if children is None else children + bottom = self.downsample(x) if self.downsample else x + residual = self.project(bottom) if self.project else bottom + if self.level_root: + children.append(bottom) + x1 = self.tree1(x, residual) + if self.level == 1: + x2 = self.tree2(x1) + x = self.root([x2, x1] + children) + else: + children.append(x1) + x = self.tree2(x1, children=children) + return x + + +@register +@serializable +class DLA(nn.Layer): + """ + DLA, see https://arxiv.org/pdf/1707.06484.pdf + + Args: + depth (int): DLA depth, should be 34. + residual_root (bool): whether use a reidual layer in the root block + + """ + + def __init__(self, depth=34, residual_root=False): + super(DLA, self).__init__() + levels, channels = DLA_cfg[depth] + if depth == 34: + block = BasicBlock + self.channels = channels + self.base_layer = nn.Sequential( + ConvNormLayer( + 3, + channels[0], + filter_size=7, + stride=1, + bias_on=False, + norm_decay=None), + nn.ReLU()) + self.level0 = self._make_conv_level(channels[0], channels[0], levels[0]) + self.level1 = self._make_conv_level( + channels[0], channels[1], levels[1], stride=2) + self.level2 = Tree( + levels[2], + block, + channels[1], + channels[2], + 2, + level_root=False, + root_residual=residual_root) + self.level3 = Tree( + levels[3], + block, + channels[2], + channels[3], + 2, + level_root=True, + root_residual=residual_root) + self.level4 = Tree( + levels[4], + block, + channels[3], + channels[4], + 2, + level_root=True, + root_residual=residual_root) + self.level5 = Tree( + levels[5], + block, + channels[4], + channels[5], + 2, + level_root=True, + root_residual=residual_root) + + def _make_conv_level(self, ch_in, ch_out, conv_num, stride=1): + modules = [] + for i in range(conv_num): + modules.extend([ + ConvNormLayer( + ch_in, + ch_out, + filter_size=3, + stride=stride if i == 0 else 1, + bias_on=False, + norm_decay=None), nn.ReLU() + ]) + ch_in = ch_out + return nn.Sequential(*modules) + + @property + def out_shape(self): + return [ShapeSpec(channels=self.channels[i]) for i in range(6)] + + def forward(self, inputs): + outs = [] + im = inputs['image'] + feats = self.base_layer(im) + for i in range(6): + feats = getattr(self, 'level{}'.format(i))(feats) + outs.append(feats) + + return outs diff --git a/ppdet/modeling/heads/__init__.py b/ppdet/modeling/heads/__init__.py index a6dfcabdaf5e0a5ce69124b65cca4339f8b62aec..04be00e9eba6b5e65f229af88ba82a6c4f613dbc 100644 --- a/ppdet/modeling/heads/__init__.py +++ b/ppdet/modeling/heads/__init__.py @@ -24,6 +24,7 @@ from . import cascade_head from . import face_head from . import s2anet_head from . import keypoint_hrhrnet_head +from . import centernet_head from .bbox_head import * from .mask_head import * @@ -37,3 +38,4 @@ from .cascade_head import * from .face_head import * from .s2anet_head import * from .keypoint_hrhrnet_head import * +from .centernet_head import * diff --git a/ppdet/modeling/heads/centernet_head.py b/ppdet/modeling/heads/centernet_head.py new file mode 100755 index 0000000000000000000000000000000000000000..00ac01ec2da416a0849901f32347d4eef76c6c34 --- /dev/null +++ b/ppdet/modeling/heads/centernet_head.py @@ -0,0 +1,194 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import math +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.nn.initializer import KaimingUniform +from ppdet.core.workspace import register +from ppdet.modeling.losses import CTFocalLoss + + +class ConvLayer(nn.Layer): + def __init__(self, + ch_in, + ch_out, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + super(ConvLayer, self).__init__() + bias_attr = False + fan_in = ch_in * kernel_size**2 + bound = 1 / math.sqrt(fan_in) + param_attr = paddle.ParamAttr(initializer=KaimingUniform()) + if bias: + bias_attr = paddle.ParamAttr( + initializer=nn.initializer.Uniform(-bound, bound)) + self.conv = nn.Conv2D( + in_channels=ch_in, + out_channels=ch_out, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + weight_attr=param_attr, + bias_attr=bias_attr) + + def forward(self, inputs): + out = self.conv(inputs) + + return out + + +@register +class CenterNetHead(nn.Layer): + """ + Args: + in_channels (int): the channel number of input to CenterNetHead. + num_classes (int): the number of classes, 80 by default. + head_planes (int): the channel number in all head, 256 by default. + heatmap_weight (float): the weight of heatmap loss, 1 by default. + regress_ltrb (bool): whether to regress left/top/right/bottom or + width/height for a box, true by default + size_weight (float): the weight of box size loss, 0.1 by default. + offset_weight (float): the weight of center offset loss, 1 by default. 
+ + """ + + __shared__ = ['num_classes'] + + def __init__(self, + in_channels, + num_classes=80, + head_planes=256, + heatmap_weight=1, + regress_ltrb=True, + size_weight=0.1, + offset_weight=1): + super(CenterNetHead, self).__init__() + self.weights = { + 'heatmap': heatmap_weight, + 'size': size_weight, + 'offset': offset_weight + } + self.heatmap = nn.Sequential( + ConvLayer( + in_channels, head_planes, kernel_size=3, padding=1, bias=True), + nn.ReLU(), + ConvLayer( + head_planes, + num_classes, + kernel_size=1, + stride=1, + padding=0, + bias=True)) + self.heatmap[2].conv.bias[:] = -2.19 + self.size = nn.Sequential( + ConvLayer( + in_channels, head_planes, kernel_size=3, padding=1, bias=True), + nn.ReLU(), + ConvLayer( + head_planes, + 4 if regress_ltrb else 2, + kernel_size=1, + stride=1, + padding=0, + bias=True)) + self.offset = nn.Sequential( + ConvLayer( + in_channels, head_planes, kernel_size=3, padding=1, bias=True), + nn.ReLU(), + ConvLayer( + head_planes, 2, kernel_size=1, stride=1, padding=0, bias=True)) + self.focal_loss = CTFocalLoss() + + @classmethod + def from_config(cls, cfg, input_shape): + if isinstance(input_shape, (list, tuple)): + input_shape = input_shape[0] + return {'in_channels': input_shape.channels} + + def forward(self, feat, inputs): + heatmap = self.heatmap(feat) + size = self.size(feat) + offset = self.offset(feat) + if self.training: + loss = self.get_loss(heatmap, size, offset, self.weights, inputs) + return loss + else: + heatmap = F.sigmoid(heatmap) + return {'heatmap': heatmap, 'size': size, 'offset': offset} + + def get_loss(self, heatmap, size, offset, weights, inputs): + heatmap_target = inputs['heatmap'] + size_target = inputs['size'] + offset_target = inputs['offset'] + index = inputs['index'] + mask = inputs['index_mask'] + heatmap = paddle.clip(F.sigmoid(heatmap), 1e-4, 1 - 1e-4) + heatmap_loss = self.focal_loss(heatmap, heatmap_target) + + size = paddle.transpose(size, perm=[0, 2, 3, 1]) + size_n, size_h, size_w, size_c = size.shape + size = paddle.reshape(size, shape=[size_n, -1, size_c]) + index = paddle.unsqueeze(index, 2) + batch_inds = list() + for i in range(size_n): + batch_ind = paddle.full( + shape=[1, index.shape[1], 1], fill_value=i, dtype='int64') + batch_inds.append(batch_ind) + batch_inds = paddle.concat(batch_inds, axis=0) + index = paddle.concat(x=[batch_inds, index], axis=2) + pos_size = paddle.gather_nd(size, index=index) + mask = paddle.unsqueeze(mask, axis=2) + size_mask = paddle.expand_as(mask, pos_size) + size_mask = paddle.cast(size_mask, dtype=pos_size.dtype) + pos_num = size_mask.sum() + size_mask.stop_gradient = True + size_target.stop_gradient = True + size_loss = F.l1_loss( + pos_size * size_mask, size_target * size_mask, reduction='sum') + size_loss = size_loss / (pos_num + 1e-4) + + offset = paddle.transpose(offset, perm=[0, 2, 3, 1]) + offset_n, offset_h, offset_w, offset_c = offset.shape + offset = paddle.reshape(offset, shape=[offset_n, -1, offset_c]) + pos_offset = paddle.gather_nd(offset, index=index) + offset_mask = paddle.expand_as(mask, pos_offset) + offset_mask = paddle.cast(offset_mask, dtype=pos_offset.dtype) + pos_num = offset_mask.sum() + offset_mask.stop_gradient = True + offset_target.stop_gradient = True + offset_loss = F.l1_loss( + pos_offset * offset_mask, + offset_target * offset_mask, + reduction='sum') + offset_loss = offset_loss / (pos_num + 1e-4) + + det_loss = weights['heatmap'] * heatmap_loss + weights[ + 'size'] * size_loss + weights['offset'] * offset_loss + + return { + 'det_loss': 
det_loss, + 'heatmap_loss': heatmap_loss, + 'size_loss': size_loss, + 'offset_loss': offset_loss + } diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py index b420d51fe27a31446263c57001d45cd4fd5226a9..75e554559fb1a54dd328b56c6f26ba70f4be4b90 100644 --- a/ppdet/modeling/layers.py +++ b/ppdet/modeling/layers.py @@ -52,7 +52,9 @@ class DeformableConvV2(nn.Layer): bias_attr=None, lr_scale=1, regularizer=None, - skip_quant=False): + skip_quant=False, + dcn_bias_regularizer=L2Decay(0.), + dcn_bias_lr_scale=2.): super(DeformableConvV2, self).__init__() self.offset_channel = 2 * kernel_size**2 self.mask_channel = kernel_size**2 @@ -79,8 +81,8 @@ class DeformableConvV2(nn.Layer): # in FCOS-DCN head, specifically need learning_rate and regularizer dcn_bias_attr = ParamAttr( initializer=Constant(value=0), - regularizer=L2Decay(0.), - learning_rate=2.) + regularizer=dcn_bias_regularizer, + learning_rate=dcn_bias_lr_scale) else: # in ResNet backbone, do not need bias dcn_bias_attr = False @@ -122,7 +124,9 @@ class ConvNormLayer(nn.Layer): freeze_norm=False, initializer=Normal( mean=0., std=0.01), - skip_quant=False): + skip_quant=False, + dcn_lr_scale=2., + dcn_regularizer=L2Decay(0.)): super(ConvNormLayer, self).__init__() assert norm_type in ['bn', 'sync_bn', 'gn'] @@ -157,15 +161,17 @@ class ConvNormLayer(nn.Layer): weight_attr=ParamAttr( initializer=initializer, learning_rate=1.), bias_attr=True, - lr_scale=2., - regularizer=L2Decay(norm_decay), + lr_scale=dcn_lr_scale, + regularizer=dcn_regularizer, skip_quant=skip_quant) norm_lr = 0. if freeze_norm else 1. param_attr = ParamAttr( - learning_rate=norm_lr, regularizer=L2Decay(norm_decay)) + learning_rate=norm_lr, + regularizer=L2Decay(norm_decay) if norm_decay is not None else None) bias_attr = ParamAttr( - learning_rate=norm_lr, regularizer=L2Decay(norm_decay)) + learning_rate=norm_lr, + regularizer=L2Decay(norm_decay) if norm_decay is not None else None) if norm_type == 'bn': self.norm = nn.BatchNorm2D( ch_out, weight_attr=param_attr, bias_attr=bias_attr) diff --git a/ppdet/modeling/losses/__init__.py b/ppdet/modeling/losses/__init__.py index 4cbef1cde3197886dc8d8392efe9c0c529dc1a82..7b7ecd63bd85b3b344865b20ed70f89ce60b30c0 100644 --- a/ppdet/modeling/losses/__init__.py +++ b/ppdet/modeling/losses/__init__.py @@ -21,6 +21,7 @@ from . import solov2_loss from . import ctfocal_loss from . import keypoint_loss from . import jde_loss +from . import fairmot_loss from .yolo_loss import * from .iou_aware_loss import * @@ -31,3 +32,4 @@ from .solov2_loss import * from .ctfocal_loss import * from .keypoint_loss import * from .jde_loss import * +from .fairmot_loss import * diff --git a/ppdet/modeling/losses/fairmot_loss.py b/ppdet/modeling/losses/fairmot_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..e24ff33fe341cce9bd4865807922a34bc2a91841 --- /dev/null +++ b/ppdet/modeling/losses/fairmot_loss.py @@ -0,0 +1,41 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn +from paddle.nn.initializer import Constant +from ppdet.core.workspace import register + +__all__ = ['FairMOTLoss'] + + +@register +class FairMOTLoss(nn.Layer): + def __init__(self): + super(FairMOTLoss, self).__init__() + self.det_weight = self.create_parameter( + shape=[1], default_initializer=Constant(-1.85)) + self.reid_weight = self.create_parameter( + shape=[1], default_initializer=Constant(-1.05)) + + def forward(self, det_loss, reid_loss): + loss = paddle.exp(-self.det_weight) * det_loss + paddle.exp( + -self.reid_weight) * reid_loss + (self.det_weight + self.reid_weight + ) + loss *= 0.5 + return {'loss': loss} diff --git a/ppdet/modeling/mot/tracker/jde_tracker.py b/ppdet/modeling/mot/tracker/jde_tracker.py index b0162f7b210ca3d703fc9312aeb23c4ba9178b63..9c9007e91c89bac8f271e8a1fe21f21ef9f365be 100644 --- a/ppdet/modeling/mot/tracker/jde_tracker.py +++ b/ppdet/modeling/mot/tracker/jde_tracker.py @@ -46,6 +46,9 @@ class JDETracker(object): unconfirmed_thresh (float): linear assignment threshold of unconfirmed stracks and unmatched detections motion (object): KalmanFilter instance + conf_thres (float): confidence threshold for tracking + metric_type (str): either "euclidean" or "cosine", the distance metric + used for measurement to track association. """ def __init__(self, @@ -55,7 +58,9 @@ class JDETracker(object): tracked_thresh=0.7, r_tracked_thresh=0.5, unconfirmed_thresh=0.7, - motion='KalmanFilter'): + motion='KalmanFilter', + conf_thres=0, + metric_type='euclidean'): self.det_thresh = det_thresh self.track_buffer = track_buffer self.min_box_area = min_box_area @@ -63,6 +68,8 @@ class JDETracker(object): self.r_tracked_thresh = r_tracked_thresh self.unconfirmed_thresh = unconfirmed_thresh self.motion = motion + self.conf_thres = conf_thres + self.metric_type = metric_type self.frame_id = 0 self.tracked_stracks = [] @@ -96,6 +103,14 @@ class JDETracker(object): # removed. (Lost for some time lesser than the threshold for removing) removed_stracks = [] + remain_inds = paddle.nonzero(pred_dets[:, 4] > self.conf_thres) + if remain_inds.shape[0] == 0: + pred_dets = paddle.zeros([0, 1]) + pred_embs = paddle.zeros([0, 1]) + else: + pred_dets = paddle.gather(pred_dets, remain_inds) + pred_embs = paddle.gather(pred_embs, remain_inds) + # Filter out the image with box_num = 0. pred_dets = [[0.0, 0.0, 0.0 ,0.0]] empty_pred = True if len(pred_dets) == 1 and paddle.sum( pred_dets) == 0.0 else False @@ -125,7 +140,8 @@ class JDETracker(object): # Predict the current location with KF STrack.multi_predict(strack_pool, self.motion) - dists = matching.embedding_distance(strack_pool, detections) + dists = matching.embedding_distance( + strack_pool, detections, metric=self.metric_type) dists = matching.fuse_motion(self.motion, dists, strack_pool, detections) # The dists is the list of distances of the detection with the tracks in strack_pool diff --git a/ppdet/modeling/necks/__init__.py b/ppdet/modeling/necks/__init__.py index 9a0f150570710bb5ef2252c577dcbf35b653aedb..6de12cffb3f0beb10de74597968cf157255377ce 100644 --- a/ppdet/modeling/necks/__init__.py +++ b/ppdet/modeling/necks/__init__.py @@ -16,8 +16,10 @@ from . import fpn from . import yolo_fpn from . import hrfpn from . import ttf_fpn +from . 
import centernet_fpn from .fpn import * from .yolo_fpn import * from .hrfpn import * from .ttf_fpn import * +from .centernet_fpn import * diff --git a/ppdet/modeling/necks/centernet_fpn.py b/ppdet/modeling/necks/centernet_fpn.py new file mode 100755 index 0000000000000000000000000000000000000000..0ecfbdd1d20c5ff31f8d3eeae975f74550bd68d8 --- /dev/null +++ b/ppdet/modeling/necks/centernet_fpn.py @@ -0,0 +1,185 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import math +import paddle +import paddle.nn.functional as F +from paddle import ParamAttr +import paddle.nn as nn +from paddle.nn.initializer import KaimingUniform +from ppdet.core.workspace import register, serializable +from ppdet.modeling.layers import ConvNormLayer +from ..shape_spec import ShapeSpec + + +def fill_up_weights(up): + weight = up.weight + f = math.ceil(weight.shape[2] / 2) + c = (2 * f - 1 - f % 2) / (2. * f) + for i in range(weight.shape[2]): + for j in range(weight.shape[3]): + weight[0, 0, i, j] = \ + (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c)) + for c in range(1, weight.shape[0]): + weight[c, 0, :, :] = weight[0, 0, :, :] + + +class IDAUp(nn.Layer): + def __init__(self, ch_ins, ch_out, up_strides, dcn_v2=True): + super(IDAUp, self).__init__() + for i in range(1, len(ch_ins)): + ch_in = ch_ins[i] + up_s = int(up_strides[i]) + proj = nn.Sequential( + ConvNormLayer( + ch_in, + ch_out, + filter_size=3, + stride=1, + use_dcn=dcn_v2, + bias_on=dcn_v2, + norm_decay=None, + dcn_lr_scale=1., + dcn_regularizer=None), + nn.ReLU()) + node = nn.Sequential( + ConvNormLayer( + ch_out, + ch_out, + filter_size=3, + stride=1, + use_dcn=dcn_v2, + bias_on=dcn_v2, + norm_decay=None, + dcn_lr_scale=1., + dcn_regularizer=None), + nn.ReLU()) + + param_attr = paddle.ParamAttr(initializer=KaimingUniform()) + up = nn.Conv2DTranspose( + ch_out, + ch_out, + kernel_size=up_s * 2, + weight_attr=param_attr, + stride=up_s, + padding=up_s // 2, + groups=ch_out, + bias_attr=False) + # TODO: uncomment fill_up_weights + #fill_up_weights(up) + setattr(self, 'proj_' + str(i), proj) + setattr(self, 'up_' + str(i), up) + setattr(self, 'node_' + str(i), node) + + def forward(self, inputs, start_level, end_level): + for i in range(start_level + 1, end_level): + upsample = getattr(self, 'up_' + str(i - start_level)) + project = getattr(self, 'proj_' + str(i - start_level)) + + inputs[i] = project(inputs[i]) + inputs[i] = upsample(inputs[i]) + node = getattr(self, 'node_' + str(i - start_level)) + inputs[i] = node(paddle.add(inputs[i], inputs[i - 1])) + + +class DLAUp(nn.Layer): + def __init__(self, start_level, channels, scales, ch_in=None, dcn_v2=True): + super(DLAUp, self).__init__() + self.start_level = start_level + if ch_in is None: + ch_in = channels + self.channels = channels + channels = list(channels) + scales = np.array(scales, dtype=int) + for i in range(len(channels) - 1): + j = -i - 2 + setattr( + self, + 'ida_{}'.format(i), + IDAUp( + ch_in[j:], + 
channels[j], + scales[j:] // scales[j], + dcn_v2=dcn_v2)) + scales[j + 1:] = scales[j] + ch_in[j + 1:] = [channels[j] for _ in channels[j + 1:]] + + def forward(self, inputs): + out = [inputs[-1]] # start with 32 + for i in range(len(inputs) - self.start_level - 1): + ida = getattr(self, 'ida_{}'.format(i)) + ida(inputs, len(inputs) - i - 2, len(inputs)) + out.insert(0, inputs[-1]) + return out + + +@register +@serializable +class CenterNetDLAFPN(nn.Layer): + """ + Args: + in_channels (list): number of input feature channels from backbone. + [16, 32, 64, 128, 256, 512] by default, means the channels of DLA-34 + down_ratio (int): the down ratio from images to heatmap, 4 by default + last_level (int): the last level of input feature fed into the upsamplng block + out_channel (int): the channel of the output feature, 0 by default means + the channel of the input feature whose down ratio is `down_ratio` + dcn_v2 (bool): whether use the DCNv2, true by default + + """ + + def __init__(self, + in_channels, + down_ratio=4, + last_level=5, + out_channel=0, + dcn_v2=True): + super(CenterNetDLAFPN, self).__init__() + self.first_level = int(np.log2(down_ratio)) + self.down_ratio = down_ratio + self.last_level = last_level + scales = [2**i for i in range(len(in_channels[self.first_level:]))] + self.dla_up = DLAUp( + self.first_level, + in_channels[self.first_level:], + scales, + dcn_v2=dcn_v2) + self.out_channel = out_channel + if out_channel == 0: + self.out_channel = in_channels[self.first_level] + self.ida_up = IDAUp( + in_channels[self.first_level:self.last_level], + self.out_channel, + [2**i for i in range(self.last_level - self.first_level)], + dcn_v2=dcn_v2) + + @classmethod + def from_config(cls, cfg, input_shape): + return {'in_channels': [i.channels for i in input_shape]} + + def forward(self, body_feats): + dla_up_feats = self.dla_up(body_feats) + + ida_up_feats = [] + for i in range(self.last_level - self.first_level): + ida_up_feats.append(dla_up_feats[i].clone()) + + self.ida_up(ida_up_feats, 0, len(ida_up_feats)) + + return ida_up_feats[-1] + + @property + def out_shape(self): + return [ShapeSpec(channels=self.out_channel, stride=self.down_ratio)] diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py index eb1f82f9e7275a6a0cfa11222670092243985150..4839328d3c68fe16b5be3d2677e564b63f7799d1 100644 --- a/ppdet/modeling/post_process.py +++ b/ppdet/modeling/post_process.py @@ -18,6 +18,7 @@ import paddle.nn as nn import paddle.nn.functional as F from ppdet.core.workspace import register from ppdet.modeling.bbox_utils import nonempty_bbox, rbox2poly, rbox2poly +from ppdet.modeling.layers import TTFBox try: from collections.abc import Sequence except Exception: @@ -29,6 +30,7 @@ __all__ = [ 'FCOSPostProcess', 'S2ANetBBoxPostProcess', 'JDEBBoxPostProcess', + 'CenterNetPostProcess', ] @@ -352,3 +354,95 @@ class JDEBBoxPostProcess(BBoxPostProcess): nms_keep_idx = paddle.to_tensor(np.array([[0]], dtype='int32')) return boxes_idx, bbox_pred, bbox_num, nms_keep_idx + + +@register +class CenterNetPostProcess(TTFBox): + """ + Postprocess the model outputs to get final prediction: + 1. Do NMS for heatmap to get top `max_per_img` bboxes. + 2. Decode bboxes using center offset and box size. + 3. Rescale decoded bboxes reference to the origin image shape. + + Args: + max_per_img(int): the maximum number of predicted objects in a image, + 500 by default. + down_ratio(int): the down ratio from images to heatmap, 4 by default. 
+ regress_ltrb (bool): whether to regress left/top/right/bottom or + width/height for a box, true by default. + for_mot (bool): whether return other features used in tracking model. + + """ + + __shared__ = ['down_ratio'] + + def __init__(self, + max_per_img=500, + down_ratio=4, + regress_ltrb=True, + for_mot=False): + super(TTFBox, self).__init__() + self.max_per_img = max_per_img + self.down_ratio = down_ratio + self.regress_ltrb = regress_ltrb + self.for_mot = for_mot + + def __call__(self, hm, wh, reg, im_shape, scale_factor): + heat = self._simple_nms(hm) + scores, inds, clses, ys, xs = self._topk(heat) + scores = paddle.tensor.unsqueeze(scores, [1]) + clses = paddle.tensor.unsqueeze(clses, [1]) + + reg_t = paddle.transpose(reg, [0, 2, 3, 1]) + # Like TTFBox, batch size is 1. + # TODO: support batch size > 1 + reg = paddle.reshape(reg_t, [-1, paddle.shape(reg_t)[-1]]) + reg = paddle.gather(reg, inds) + xs = paddle.cast(xs, 'float32') + ys = paddle.cast(ys, 'float32') + xs = xs + reg[:, 0:1] + ys = ys + reg[:, 1:2] + + wh_t = paddle.transpose(wh, [0, 2, 3, 1]) + wh = paddle.reshape(wh_t, [-1, paddle.shape(wh_t)[-1]]) + wh = paddle.gather(wh, inds) + + if self.regress_ltrb: + x1 = xs - wh[:, 0:1] + y1 = ys - wh[:, 1:2] + x2 = xs + wh[:, 2:3] + y2 = ys + wh[:, 3:4] + else: + x1 = xs - wh[:, 0:1] / 2 + y1 = ys - wh[:, 1:2] / 2 + x2 = xs + wh[:, 0:1] / 2 + y2 = ys + wh[:, 1:2] / 2 + + n, c, feat_h, feat_w = paddle.shape(hm) + padw = (feat_w * self.down_ratio - im_shape[0, 1]) / 2 + padh = (feat_h * self.down_ratio - im_shape[0, 0]) / 2 + x1 = x1 * self.down_ratio + y1 = y1 * self.down_ratio + x2 = x2 * self.down_ratio + y2 = y2 * self.down_ratio + + x1 = x1 - padw + y1 = y1 - padh + x2 = x2 - padw + y2 = y2 - padh + + bboxes = paddle.concat([x1, y1, x2, y2], axis=1) + scale_y = scale_factor[:, 0:1] + scale_x = scale_factor[:, 1:2] + scale_expand = paddle.concat( + [scale_x, scale_y, scale_x, scale_y], axis=1) + boxes_shape = paddle.shape(bboxes) + boxes_shape.stop_gradient = True + scale_expand = paddle.expand(scale_expand, shape=boxes_shape) + bboxes = paddle.divide(bboxes, scale_expand) + if self.for_mot: + results = paddle.concat([bboxes, scores, clses], axis=1) + return results, inds + else: + results = paddle.concat([clses, scores, bboxes], axis=1) + return results, paddle.shape(results)[0:1] diff --git a/ppdet/modeling/reid/__init__.py b/ppdet/modeling/reid/__init__.py index a000c95329bbe3da1398df8bcded72e94e505d6b..33309e244547acf5f8e372a8c4dd887da3003589 100644 --- a/ppdet/modeling/reid/__init__.py +++ b/ppdet/modeling/reid/__init__.py @@ -15,7 +15,9 @@ from . import jde_embedding_head from . import pyramidal_embedding from . import resnet +from . import fairmot_embedding_head from .jde_embedding_head import * from .pyramidal_embedding import * from .resnet import * +from .fairmot_embedding_head import * diff --git a/ppdet/modeling/reid/fairmot_embedding_head.py b/ppdet/modeling/reid/fairmot_embedding_head.py new file mode 100755 index 0000000000000000000000000000000000000000..19ca28080a997e87b2a7a3a3e7a1695d2c2beea4 --- /dev/null +++ b/ppdet/modeling/reid/fairmot_embedding_head.py @@ -0,0 +1,116 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import math +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.nn.initializer import KaimingUniform, Uniform +from ppdet.core.workspace import register +from ppdet.modeling.heads.centernet_head import ConvLayer + +__all__ = ['FairMOTEmbeddingHead'] + + +@register +class FairMOTEmbeddingHead(nn.Layer): + """ + Args: + in_channels (int): the channel number of input to FairMOTEmbeddingHead. + ch_head (int): the channel of features before fed into embedding, 256 by default. + ch_emb (int): the channel of the embedding feature, 128 by default. + num_identifiers (int): the number of identifiers, 14455 by default. + + """ + + def __init__(self, + in_channels, + ch_head=256, + ch_emb=128, + num_identifiers=14455): + super(FairMOTEmbeddingHead, self).__init__() + self.reid = nn.Sequential( + ConvLayer( + in_channels, ch_head, kernel_size=3, padding=1, bias=True), + nn.ReLU(), + ConvLayer( + ch_head, ch_emb, kernel_size=1, stride=1, padding=0, bias=True)) + param_attr = paddle.ParamAttr(initializer=KaimingUniform()) + bound = 1 / math.sqrt(ch_emb) + bias_attr = paddle.ParamAttr(initializer=Uniform(-bound, bound)) + self.classifier = nn.Linear( + ch_emb, + num_identifiers, + weight_attr=param_attr, + bias_attr=bias_attr) + self.reid_loss = nn.CrossEntropyLoss(ignore_index=-1, reduction='sum') + # When num_identifiers is 1, emb_scale is set as 1 + self.emb_scale = math.sqrt(2) * math.log( + num_identifiers - 1) if num_identifiers > 1 else 1 + + @classmethod + def from_config(cls, cfg, input_shape): + if isinstance(input_shape, (list, tuple)): + input_shape = input_shape[0] + return {'in_channels': input_shape.channels} + + def forward(self, feat, inputs): + reid_feat = self.reid(feat) + if self.training: + loss = self.get_loss(reid_feat, inputs) + return loss + else: + reid_feat = F.normalize(reid_feat) + return reid_feat + + def get_loss(self, feat, inputs): + index = inputs['index'] + mask = inputs['index_mask'] + target = inputs['reid'] + target = paddle.masked_select(target, mask > 0) + target = paddle.unsqueeze(target, 1) + + feat = paddle.transpose(feat, perm=[0, 2, 3, 1]) + feat_n, feat_h, feat_w, feat_c = feat.shape + feat = paddle.reshape(feat, shape=[feat_n, -1, feat_c]) + index = paddle.unsqueeze(index, 2) + batch_inds = list() + for i in range(feat_n): + batch_ind = paddle.full( + shape=[1, index.shape[1], 1], fill_value=i, dtype='int64') + batch_inds.append(batch_ind) + batch_inds = paddle.concat(batch_inds, axis=0) + index = paddle.concat(x=[batch_inds, index], axis=2) + feat = paddle.gather_nd(feat, index=index) + + mask = paddle.unsqueeze(mask, axis=2) + mask = paddle.expand_as(mask, feat) + mask.stop_gradient = True + feat = paddle.masked_select(feat, mask > 0) + feat = paddle.reshape(feat, shape=[-1, feat_c]) + feat = F.normalize(feat) + feat = self.emb_scale * feat + logit = self.classifier(feat) + target.stop_gradient = True + loss = self.reid_loss(logit, target) + valid = (target != self.reid_loss.ignore_index) + valid.stop_gradient = True + count = paddle.sum((paddle.cast(valid, 
dtype=np.int32))) + count.stop_gradient = True + if count > 0: + loss = loss / count + + return loss
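
For reference (an illustrative sketch, not part of the patch itself): the `FairMOTLoss` added above balances the detection and re-ID losses with learnable task-uncertainty weights, computing `loss = 0.5 * (exp(-w_det) * det_loss + exp(-w_reid) * reid_loss + w_det + w_reid)`. A minimal plain-Python version of that computation follows; the function name is chosen here only for illustration.

```python
import math

def fairmot_total_loss(det_loss, reid_loss, w_det=-1.85, w_reid=-1.05):
    # Uncertainty-weighted combination used by FairMOTLoss; w_det and w_reid
    # are learnable parameters in the real model, initialized to these values.
    return 0.5 * (math.exp(-w_det) * det_loss +
                  math.exp(-w_reid) * reid_loss +
                  w_det + w_reid)

# With the initial weights, the detection loss is scaled by exp(1.85) ~= 6.36
# and the re-ID loss by exp(1.05) ~= 2.86 before the regularizing terms are added.
print(fairmot_total_loss(det_loss=1.0, reid_loss=2.0))  # ~= 4.59
```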