Unverified commit 277bc13f authored by FlyingQianMM, committed by GitHub

[MOT] add FairMOT (#2994)

* add fairmot

* rename pretrain_weights

* add comments; reuse focal loss and convnorm

* update comments
Parent 5ad5a819
@@ -12,7 +12,7 @@ MOTDataZoo: {
 TrainDataset:
   !MOTDataSet
     dataset_dir: dataset/mot
-    image_lists: ['mot17.train', 'caltech.train', 'cuhksysu.train', 'prw.train', 'citypersons.train', 'eth.train']
+    image_lists: ['mot17.train', 'caltech.all', 'cuhksysu.train', 'prw.train', 'citypersons.train', 'eth.train']
     data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
 
 EvalDataset:
...
# FairMOT (On the Fairness of Detection and Re-Identification in Multiple Object Tracking)
## Table of Contents
- [Introduction](#Introduction)
- [Model Zoo](#Model_Zoo)
- [Getting Started](#Getting_Started)
- [Citations](#Citations)
## Introduction
FairMOT performs detection and re-identification in a single network to improve inference speed. It presents a simple baseline consisting of two homogeneous branches that predict pixel-wise objectness scores and re-ID features. The achieved fairness between the two tasks allows FairMOT to obtain high levels of both detection and tracking accuracy.
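As a concrete sketch of that two-branch idea (illustrative code only, not the PaddleDetection implementation; all names here are made up), a single backbone feature map feeds both heads at the same resolution:

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class TwoBranchSketch(nn.Layer):
    """Toy two-branch head: detection and re-ID share one backbone feature map."""

    def __init__(self, ch_in=64, num_classes=1, emb_dim=128):
        super().__init__()
        # pixel-wise objectness (heatmap) branch
        self.det = nn.Conv2D(ch_in, num_classes, kernel_size=1)
        # pixel-wise re-ID embedding branch, at the same stride-4 resolution
        self.reid = nn.Conv2D(ch_in, emb_dim, kernel_size=1)

    def forward(self, feat):
        return F.sigmoid(self.det(feat)), self.reid(feat)

feat = paddle.randn([1, 64, 152, 272])  # stride-4 feature map of a 608x1088 input
heatmap, embeddings = TwoBranchSketch()(feat)
```

Keeping the two branches homogeneous and anchor-free at this high resolution is what the paper means by fairness between the detection and re-ID tasks.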
## Model Zoo
### Results on MOT-16 train set
| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | download | config |
| :-----------------| :------- | :----: | :----: | :---: | :----: | :---: |:---: | :---: |
| DLA-34(paper) | 1088x608 | 83.3 | 81.9 | 544 | 3822 | 14095 | ---- | ---- |
| DLA-34 | 1088x608 | 83.4 | 82.7 | 517 | 4077 | 13761 | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml) |
### Results on MOT-16 test set
| backbone | input shape | MOTA | IDF1 | IDS | MT | ML | download | config |
| :-----------------| :------- | :----: | :----: | :---: | :----: | :---: | :---: | :---: |
| DLA-34(paper) | 1088x608 | 74.9 | 72.8 | 1074 | 44.7% | 15.9% | ---- | ---- |
| DLA-34 | 1088x608 | 74.7 | 72.8 | 1044 | 41.9% | 19.1% |[model](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml) |
**Notes:**
FairMOT was trained for 30 epochs on 2 GPUs with a mini-batch size of 6 per GPU.
## Getting Started
### 1. Training
Train FairMOT on 2 GPUs with the following command:
```bash
python -m paddle.distributed.launch --log_dir=./fairmot_dla34_30e_1088x608/ --gpus 0,1 tools/train.py -c configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml &>fairmot_dla34_30e_1088x608.log 2>&1 &
```
### 2. Evaluation
Evaluate the tracking performance of FairMOT on the val dataset on a single GPU with the following commands:
```bash
# use weights released in PaddleDetection model zoo
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608.pdparams
# use saved checkpoint in training
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml -o weights=output/fairmot_dla34_30e_1088x608/model_final
```
## Citations
```
@article{zhang2020fair,
title={FairMOT: On the Fairness of Detection and Re-Identification in Multiple Object Tracking},
author={Zhang, Yifu and Wang, Chunyu and Wang, Xinggang and Zeng, Wenjun and Liu, Wenyu},
journal={arXiv preprint arXiv:2004.01888},
year={2020}
}
```
architecture: FairMOT
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/fairmot_dla34_crowdhuman_pretrained.pdparams
FairMOT:
detector: CenterNet
reid: FairMOTEmbeddingHead
loss: FairMOTLoss
tracker: JDETracker
CenterNet:
backbone: DLA
neck: CenterNetDLAFPN
head: CenterNetHead
post_process: CenterNetPostProcess
for_mot: True
CenterNetPostProcess:
for_mot: True
JDETracker:
conf_thres: 0.4
tracked_thresh: 0.4
metric_type: cosine
worker_num: 4
TrainReader:
inputs_def:
image_shape: [3, 608, 1088]
  sample_transforms:
    # keep BGR here: AugmentHSV below expects BGR input (is_bgr: True)
    - Decode: {to_rgb: False}
    - AugmentHSV: {is_bgr: True}
    - LetterBoxResize: {target_size: [608, 1088]}
    - MOTRandomAffine: {reject_outside: False}
    - RandomFlip: {}
    - BboxXYXY2XYWH: {}
    - NormalizeBox: {}
    - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1]}
    # convert BGR to RGB only here, after all BGR-based transforms
    - Permute: {to_rgb: True}
batch_transforms:
- Gt2FairMOTTarget: {}
batch_size: 6
shuffle: True
drop_last: True
EvalMOTReader:
inputs_def:
image_shape: [3, 608, 1088]
sample_transforms:
- Decode: {to_rgb: False}
- LetterBoxResize: {target_size: [608, 1088]}
- NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1]}
- Permute: {to_rgb: True}
batch_size: 1
epoch: 30
LearningRate:
base_lr: 0.0001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [20,]
use_warmup: False
OptimizerBuilder:
optimizer:
type: Adam
regularizer: NULL
_BASE_: [
'../../datasets/mot.yml',
'../../runtime.yml',
'_base_/optimizer_30e.yml',
'_base_/fairmot_dla34.yml',
'_base_/fairmot_reader_1088x608.yml',
]
metric: MOT
weights: output/fairmot_dla34_30e_1088x608/model_final
@@ -347,6 +347,7 @@ class MOTVideoDataset(DetDataset):
     def _load_video_images(self):
         self.cap = cv2.VideoCapture(self.video_file)
         self.vn = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        self.frame_rate = int(round(self.cap.get(cv2.CAP_PROP_FPS)))
         logger.info('Length of the video: {:d} frames'.format(self.vn))
         res = True
         ct = 0
...
@@ -25,14 +25,19 @@ from numbers import Integral
 import cv2
 import copy
 import numpy as np
+import math
 
 from .operators import BaseOperator, register_op
+from .batch_operators import Gt2TTFTarget
 from ppdet.modeling.bbox_utils import bbox_iou_np_expand
 from ppdet.core.workspace import serializable
 from ppdet.utils.logger import setup_logger
 logger = setup_logger(__name__)
 
-__all__ = ['LetterBoxResize', 'Gt2JDETargetThres', 'Gt2JDETargetMax']
+__all__ = [
+    'LetterBoxResize', 'Gt2JDETargetThres', 'Gt2JDETargetMax',
+    'Gt2FairMOTTarget'
+]
 
 @register_op
@@ -367,3 +372,117 @@ class Gt2JDETargetMax(BaseOperator):
                 sample['tbox{}'.format(i)] = tbox
                 sample['tconf{}'.format(i)] = tconf
                 sample['tide{}'.format(i)] = tid
@register_op
class Gt2FairMOTTarget(Gt2TTFTarget):
__shared__ = ['num_classes']
"""
Generate FairMOT targets by ground truth data.
Difference between Gt2FairMOTTarget and Gt2TTFTarget are:
1. the gaussian kernal radius to generate a heatmap.
2. the targets needed during traing.
Args:
num_classes(int): the number of classes.
down_ratio(int): the down ratio from images to heatmap, 4 by default.
max_objs(int): the maximum number of ground truth objects in a image, 500 by default.
"""
def __init__(self, num_classes=1, down_ratio=4, max_objs=500):
super(Gt2TTFTarget, self).__init__()
self.down_ratio = down_ratio
self.num_classes = num_classes
self.max_objs = max_objs
def __call__(self, samples, context=None):
for b_id, sample in enumerate(samples):
output_h = sample['image'].shape[1] // self.down_ratio
output_w = sample['image'].shape[2] // self.down_ratio
heatmap = np.zeros(
(self.num_classes, output_h, output_w), dtype='float32')
bbox_size = np.zeros((self.max_objs, 4), dtype=np.float32)
center_offset = np.zeros((self.max_objs, 2), dtype=np.float32)
index = np.zeros((self.max_objs, ), dtype=np.int64)
index_mask = np.zeros((self.max_objs, ), dtype=np.int32)
reid = np.zeros((self.max_objs, ), dtype=np.int64)
bbox_xys = np.zeros((self.max_objs, 4), dtype=np.float32)
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
gt_ide = sample['gt_ide']
for k in range(len(gt_bbox)):
cls_id = gt_class[k][0]
bbox = gt_bbox[k]
ide = gt_ide[k][0]
bbox[[0, 2]] = bbox[[0, 2]] * output_w
bbox[[1, 3]] = bbox[[1, 3]] * output_h
bbox_amodal = copy.deepcopy(bbox)
bbox_amodal[0] = bbox_amodal[0] - bbox_amodal[2] / 2.
bbox_amodal[1] = bbox_amodal[1] - bbox_amodal[3] / 2.
bbox_amodal[2] = bbox_amodal[0] + bbox_amodal[2]
bbox_amodal[3] = bbox_amodal[1] + bbox_amodal[3]
bbox[0] = np.clip(bbox[0], 0, output_w - 1)
bbox[1] = np.clip(bbox[1], 0, output_h - 1)
h = bbox[3]
w = bbox[2]
bbox_xy = copy.deepcopy(bbox)
bbox_xy[0] = bbox_xy[0] - bbox_xy[2] / 2
bbox_xy[1] = bbox_xy[1] - bbox_xy[3] / 2
bbox_xy[2] = bbox_xy[0] + bbox_xy[2]
bbox_xy[3] = bbox_xy[1] + bbox_xy[3]
if h > 0 and w > 0:
radius = self.gaussian_radius((math.ceil(h), math.ceil(w)))
radius = max(0, int(radius))
ct = np.array([bbox[0], bbox[1]], dtype=np.float32)
ct_int = ct.astype(np.int32)
self.draw_truncate_gaussian(heatmap[cls_id], ct_int, radius,
radius)
bbox_size[k] = ct[0] - bbox_amodal[0], ct[1] - bbox_amodal[1], \
bbox_amodal[2] - ct[0], bbox_amodal[3] - ct[1]
index[k] = ct_int[1] * output_w + ct_int[0]
center_offset[k] = ct - ct_int
index_mask[k] = 1
reid[k] = ide
bbox_xys[k] = bbox_xy
sample['heatmap'] = heatmap
sample['index'] = index
sample['offset'] = center_offset
sample['size'] = bbox_size
sample['index_mask'] = index_mask
sample['reid'] = reid
sample['bbox_xys'] = bbox_xys
sample.pop('is_crowd', None)
sample.pop('difficult', None)
sample.pop('gt_class', None)
sample.pop('gt_bbox', None)
sample.pop('gt_score', None)
sample.pop('gt_ide', None)
return samples
    def gaussian_radius(self, det_size, min_overlap=0.7):
        """Compute the gaussian kernel radius as in CornerNet: for each of the
        three ways a corner shift can deform a (height, width) box, solve the
        quadratic that keeps the IoU with the ground truth above `min_overlap`,
        then take the tightest radius."""
        height, width = det_size

        # case 1: detection shifted diagonally, partially overlapping the gt box
        a1 = 1
        b1 = (height + width)
        c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
        sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
        r1 = (b1 + sq1) / 2

        # case 2: both corners shift inward, detection inside the gt box
        a2 = 4
        b2 = 2 * (height + width)
        c2 = (1 - min_overlap) * width * height
        sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
        r2 = (b2 + sq2) / 2

        # case 3: both corners shift outward, detection encloses the gt box
        a3 = 4 * min_overlap
        b3 = -2 * min_overlap * (height + width)
        c3 = (min_overlap - 1) * width * height
        sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
        r3 = (b3 + sq3) / 2
        # (b + sq) / 2 instead of / (2 * a) follows the original CornerNet code
        return min(r1, r2, r3)
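For reference, the case-1 coefficients above follow from requiring IoU of at least `min_overlap` $o$ for a $w \times h$ box whose corner shifts by $r$ (a sketch of the CornerNet derivation; cases 2 and 3 are analogous for boxes that shrink or grow on both corners):

```math
\frac{(w-r)(h-r)}{2wh-(w-r)(h-r)} \ge o
\quad\Longrightarrow\quad
r^2 - (h+w)\,r + \frac{wh\,(1-o)}{1+o} \ge 0,
```

i.e. exactly $a_1 = 1$, $b_1 = h + w$, $c_1 = wh(1-o)/(1+o)$.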
@@ -107,10 +107,12 @@ class BaseOperator(object):
 @register_op
 class Decode(BaseOperator):
-    def __init__(self):
+    def __init__(self, to_rgb=True):
         """ Transform the image data to numpy format following the rgb format
         """
         super(Decode, self).__init__()
+        # TODO: remove this parameter
+        self.to_rgb = to_rgb
 
     def apply(self, sample, context=None):
         """ load image if 'im_file' field is not empty but 'image' is"""

@@ -124,6 +126,7 @@ class Decode(BaseOperator):
         im = cv2.imdecode(data, 1)  # BGR mode, but need RGB mode
         if 'keep_ori_im' in sample and sample['keep_ori_im']:
             sample['ori_image'] = im
-        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+        if self.to_rgb:
+            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
         sample['image'] = im

@@ -151,14 +154,18 @@ class Decode(BaseOperator):
 @register_op
 class Permute(BaseOperator):
-    def __init__(self):
+    def __init__(self, to_rgb=False):
         """
         Change the channel to be (C, H, W)
         """
         super(Permute, self).__init__()
+        # TODO: remove this parameter
+        self.to_rgb = to_rgb
 
     def apply(self, sample, context=None):
         im = sample['image']
+        if self.to_rgb:
+            im = np.ascontiguousarray(im[:, :, ::-1])
         im = im.transpose((2, 0, 1))
         sample['image'] = im
         return sample
@@ -2076,29 +2083,39 @@ class Norm2PixelBbox(BaseOperator):
 @register_op
 class MOTRandomAffine(BaseOperator):
-    def __init__(self,
-                 degrees=(-5, 5),
-                 translate=(0.10, 0.10),
-                 scale=(0.50, 1.20),
-                 shear=(-2, 2),
-                 borderValue=(127.5, 127.5, 127.5)):
-        """
-        Affine transform to image and coords to achieve the rotate, scale and
-        shift effect for training image.
-
-        Args:
-            degrees (tuple): rotation value
-            translate (tuple): xy coords translation value
-            scale (tuple): scale value
-            shear (tuple): shear value
-            borderValue (tuple): border color value
-        """
+    """
+    Affine transform to image and coords to achieve the rotate, scale and
+    shift effect for training image.
+
+    Args:
+        degrees (list[2]): the rotate range to apply, transform range is [min, max]
+        translate (list[2]): the translate range to apply, transform range is [min, max]
+        scale (list[2]): the scale range to apply, transform range is [min, max]
+        shear (list[2]): the shear range to apply, transform range is [min, max]
+        borderValue (list[3]): value used in case of a constant border when applying
+            the perspective transformation
+        reject_outside (bool): reject warped bounding boxes outside of image
+
+    Returns:
+        records(dict): contain the image and coords after transformed
+    """
+
+    def __init__(self,
+                 degrees=(-5, 5),
+                 translate=(0.10, 0.10),
+                 scale=(0.50, 1.20),
+                 shear=(-2, 2),
+                 borderValue=(127.5, 127.5, 127.5),
+                 reject_outside=True):
         super(MOTRandomAffine, self).__init__()
         self.degrees = degrees
         self.translate = translate
         self.scale = scale
         self.shear = shear
         self.borderValue = borderValue
+        self.reject_outside = reject_outside
 
     def apply(self, sample, context=None):
         # https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4

@@ -2171,6 +2188,7 @@ class MOTRandomAffine(BaseOperator):
                 (x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T
 
             # reject warped points outside of image
-            np.clip(xy[:, 0], 0, width, out=xy[:, 0])
-            np.clip(xy[:, 2], 0, width, out=xy[:, 2])
-            np.clip(xy[:, 1], 0, height, out=xy[:, 1])
+            if self.reject_outside:
+                np.clip(xy[:, 0], 0, width, out=xy[:, 0])
+                np.clip(xy[:, 2], 0, width, out=xy[:, 2])
+                np.clip(xy[:, 1], 0, height, out=xy[:, 1])
...
@@ -66,6 +66,10 @@ class Trainer(object):
                 cfg['JDEEmbeddingHead'][
                     'num_identifiers'] = self.dataset.total_identities
 
+            if cfg.architecture == 'FairMOT' and self.mode == 'train':
+                cfg['FairMOTEmbeddingHead'][
+                    'num_identifiers'] = self.dataset.total_identities
+
         # build model
         if 'model' not in self.cfg:
             self.model = create(cfg.architecture)

@@ -223,6 +227,9 @@ class Trainer(object):
             return
         self.start_epoch = 0
         if hasattr(self.model, 'detector'):
-            load_pretrain_weight(self.model.detector, weights)
+            if self.model.__class__.__name__ == 'FairMOT':
+                load_pretrain_weight(self.model, weights)
+            else:
+                load_pretrain_weight(self.model.detector, weights)
         else:
             load_pretrain_weight(self.model, weights)
...
@@ -19,6 +19,8 @@ from . import keypoint_hrhrnet
 from . import keypoint_hrnet
 from . import jde
 from . import deepsort
+from . import fairmot
+from . import centernet
 
 from .meta_arch import *
 from .faster_rcnn import *

@@ -34,3 +36,5 @@ from .keypoint_hrhrnet import *
 from .keypoint_hrnet import *
 from .jde import *
 from .deepsort import *
+from .fairmot import *
+from .centernet import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
__all__ = ['CenterNet']
@register
class CenterNet(BaseArch):
"""
CenterNet network, see http://arxiv.org/abs/1904.07850
    Args:
        backbone (object): backbone instance
        neck (object): 'CenterNetDLAFPN' instance
        head (object): 'CenterNetHead' instance
        post_process (object): 'CenterNetPostProcess' instance
        for_mot (bool): whether to return other features used in the tracking model
"""
__category__ = 'architecture'
__inject__ = ['post_process']
    def __init__(self,
                 backbone='DLA',
                 neck='CenterNetDLAFPN',
                 head='CenterNetHead',
                 post_process='CenterNetPostProcess',
                 for_mot=False):
super(CenterNet, self).__init__()
self.backbone = backbone
self.neck = neck
self.head = head
self.post_process = post_process
self.for_mot = for_mot
@classmethod
def from_config(cls, cfg, *args, **kwargs):
backbone = create(cfg['backbone'])
kwargs = {'input_shape': backbone.out_shape}
neck = create(cfg['neck'], **kwargs)
kwargs = {'input_shape': neck.out_shape}
head = create(cfg['head'], **kwargs)
return {'backbone': backbone, 'neck': neck, "head": head}
def _forward(self):
body_feats = self.backbone(self.inputs)
neck_feat = self.neck(body_feats)
head_out = self.head(neck_feat, self.inputs)
if self.for_mot:
head_out.update({'neck_feat': neck_feat})
return head_out
def get_pred(self):
head_out = self._forward()
if self.for_mot:
bbox, bbox_inds = self.post_process(
head_out['heatmap'],
head_out['size'],
head_out['offset'],
im_shape=self.inputs['im_shape'],
scale_factor=self.inputs['scale_factor'])
output = {
"bbox": bbox,
"bbox_inds": bbox_inds,
"neck_feat": head_out['neck_feat']
}
else:
bbox, bbox_num = self.post_process(
head_out['heatmap'],
head_out['size'],
head_out['offset'],
im_shape=self.inputs['im_shape'],
scale_factor=self.inputs['scale_factor'])
output = {
"bbox": bbox,
"bbox_num": bbox_num,
}
return output
def get_loss(self):
return self._forward()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
__all__ = ['FairMOT']
@register
class FairMOT(BaseArch):
"""
FairMOT network, see http://arxiv.org/abs/2004.01888
Args:
detector (object): 'CenterNet' instance
reid (object): 'FairMOTEmbeddingHead' instance
tracker (object): 'JDETracker' instance
loss (object): 'FairMOTLoss' instance
"""
__category__ = 'architecture'
__inject__ = ['loss']
def __init__(self,
detector='CenterNet',
reid='FairMOTEmbeddingHead',
tracker='JDETracker',
loss='FairMOTLoss'):
super(FairMOT, self).__init__()
self.detector = detector
self.reid = reid
self.tracker = tracker
self.loss = loss
@classmethod
def from_config(cls, cfg, *args, **kwargs):
detector = create(cfg['detector'])
kwargs = {'input_shape': detector.neck.out_shape}
reid = create(cfg['reid'], **kwargs)
loss = create(cfg['loss'])
tracker = create(cfg['tracker'])
return {
'detector': detector,
'reid': reid,
'loss': loss,
'tracker': tracker
}
def _forward(self):
loss = dict()
# det_outs keys:
# train: det_loss, heatmap_loss, size_loss, offset_loss, neck_feat
# eval/infer: bbox, bbox_inds, neck_feat
det_outs = self.detector(self.inputs)
neck_feat = det_outs['neck_feat']
if self.training:
reid_loss = self.reid(neck_feat, self.inputs)
det_loss = det_outs['det_loss']
loss = self.loss(det_loss, reid_loss)
loss.update({
'heatmap_loss': det_outs['heatmap_loss'],
'size_loss': det_outs['size_loss'],
'offset_loss': det_outs['offset_loss'],
'reid_loss': reid_loss
})
return loss
else:
embedding = self.reid(neck_feat, self.inputs)
bbox_inds = det_outs['bbox_inds']
embedding = paddle.transpose(embedding, [0, 2, 3, 1])
embedding = paddle.reshape(embedding,
[-1, paddle.shape(embedding)[-1]])
id_feature = paddle.gather(embedding, bbox_inds)
dets = det_outs['bbox']
            # Note: the tracker only considers batch_size=1 and num_classes=1
online_targets = self.tracker.update(dets, id_feature)
return online_targets
def get_pred(self):
output = self._forward()
return output
def get_loss(self):
loss = self._forward()
return loss
@@ -22,6 +22,7 @@ from . import blazenet
 from . import ghostnet
 from . import senet
 from . import res2net
+from . import dla
 
 from .vgg import *
 from .resnet import *

@@ -33,3 +34,4 @@ from .blazenet import *
 from .ghostnet import *
 from .senet import *
 from .res2net import *
+from .dla import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import ConvNormLayer
from ..shape_spec import ShapeSpec
DLA_cfg = {34: ([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512])}
class BasicBlock(nn.Layer):
def __init__(self, ch_in, ch_out, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = ConvNormLayer(
ch_in,
ch_out,
filter_size=3,
stride=stride,
bias_on=False,
norm_decay=None)
self.conv2 = ConvNormLayer(
ch_out,
ch_out,
filter_size=3,
stride=1,
bias_on=False,
norm_decay=None)
def forward(self, inputs, residual=None):
if residual is None:
residual = inputs
out = self.conv1(inputs)
out = F.relu(out)
out = self.conv2(out)
out = paddle.add(x=out, y=residual)
out = F.relu(out)
return out
class Root(nn.Layer):
def __init__(self, ch_in, ch_out, kernel_size, residual):
super(Root, self).__init__()
self.conv = ConvNormLayer(
ch_in,
ch_out,
filter_size=1,
stride=1,
bias_on=False,
norm_decay=None)
self.residual = residual
def forward(self, inputs):
children = inputs
out = self.conv(paddle.concat(inputs, axis=1))
if self.residual:
out = paddle.add(x=out, y=children[0])
out = F.relu(out)
return out
class Tree(nn.Layer):
def __init__(self,
level,
block,
ch_in,
ch_out,
stride=1,
level_root=False,
root_dim=0,
root_kernel_size=1,
root_residual=False):
super(Tree, self).__init__()
if root_dim == 0:
root_dim = 2 * ch_out
if level_root:
root_dim += ch_in
if level == 1:
self.tree1 = block(ch_in, ch_out, stride)
self.tree2 = block(ch_out, ch_out, 1)
else:
self.tree1 = Tree(
level - 1,
block,
ch_in,
ch_out,
stride,
root_dim=0,
root_kernel_size=root_kernel_size,
root_residual=root_residual)
self.tree2 = Tree(
level - 1,
block,
ch_out,
ch_out,
1,
root_dim=root_dim + ch_out,
root_kernel_size=root_kernel_size,
root_residual=root_residual)
if level == 1:
self.root = Root(root_dim, ch_out, root_kernel_size, root_residual)
self.level_root = level_root
self.root_dim = root_dim
self.downsample = None
self.project = None
self.level = level
if stride > 1:
self.downsample = nn.MaxPool2D(stride, stride=stride)
if ch_in != ch_out:
self.project = ConvNormLayer(
ch_in,
ch_out,
filter_size=1,
stride=1,
bias_on=False,
norm_decay=None)
def forward(self, x, residual=None, children=None):
children = [] if children is None else children
bottom = self.downsample(x) if self.downsample else x
residual = self.project(bottom) if self.project else bottom
if self.level_root:
children.append(bottom)
x1 = self.tree1(x, residual)
if self.level == 1:
x2 = self.tree2(x1)
x = self.root([x2, x1] + children)
else:
children.append(x1)
x = self.tree2(x1, children=children)
return x
@register
@serializable
class DLA(nn.Layer):
"""
DLA, see https://arxiv.org/pdf/1707.06484.pdf
Args:
depth (int): DLA depth, should be 34.
        residual_root (bool): whether to use a residual layer in the root block
"""
def __init__(self, depth=34, residual_root=False):
super(DLA, self).__init__()
levels, channels = DLA_cfg[depth]
if depth == 34:
block = BasicBlock
self.channels = channels
self.base_layer = nn.Sequential(
ConvNormLayer(
3,
channels[0],
filter_size=7,
stride=1,
bias_on=False,
norm_decay=None),
nn.ReLU())
self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
self.level1 = self._make_conv_level(
channels[0], channels[1], levels[1], stride=2)
self.level2 = Tree(
levels[2],
block,
channels[1],
channels[2],
2,
level_root=False,
root_residual=residual_root)
self.level3 = Tree(
levels[3],
block,
channels[2],
channels[3],
2,
level_root=True,
root_residual=residual_root)
self.level4 = Tree(
levels[4],
block,
channels[3],
channels[4],
2,
level_root=True,
root_residual=residual_root)
self.level5 = Tree(
levels[5],
block,
channels[4],
channels[5],
2,
level_root=True,
root_residual=residual_root)
def _make_conv_level(self, ch_in, ch_out, conv_num, stride=1):
modules = []
for i in range(conv_num):
modules.extend([
ConvNormLayer(
ch_in,
ch_out,
filter_size=3,
stride=stride if i == 0 else 1,
bias_on=False,
norm_decay=None), nn.ReLU()
])
ch_in = ch_out
return nn.Sequential(*modules)
@property
def out_shape(self):
return [ShapeSpec(channels=self.channels[i]) for i in range(6)]
def forward(self, inputs):
outs = []
im = inputs['image']
feats = self.base_layer(im)
for i in range(6):
feats = getattr(self, 'level{}'.format(i))(feats)
outs.append(feats)
return outs
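As a quick sanity check of the configuration above (a sketch; it assumes PaddleDetection with this patch is importable):

```python
import paddle
from ppdet.modeling.backbones.dla import DLA

model = DLA(depth=34)
feats = model({'image': paddle.randn([1, 3, 608, 1088])})
print([list(f.shape) for f in feats])
# six levels with strides 1, 2, 4, 8, 16, 32 and channels
# [16, 32, 64, 128, 256, 512], matching DLA_cfg and out_shape
```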
@@ -24,6 +24,7 @@ from . import cascade_head
 from . import face_head
 from . import s2anet_head
 from . import keypoint_hrhrnet_head
+from . import centernet_head
 
 from .bbox_head import *
 from .mask_head import *

@@ -37,3 +38,4 @@ from .cascade_head import *
 from .face_head import *
 from .s2anet_head import *
 from .keypoint_hrhrnet_head import *
+from .centernet_head import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import KaimingUniform
from ppdet.core.workspace import register
from ppdet.modeling.losses import CTFocalLoss
class ConvLayer(nn.Layer):
def __init__(self,
ch_in,
ch_out,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=False):
super(ConvLayer, self).__init__()
bias_attr = False
fan_in = ch_in * kernel_size**2
bound = 1 / math.sqrt(fan_in)
param_attr = paddle.ParamAttr(initializer=KaimingUniform())
if bias:
bias_attr = paddle.ParamAttr(
initializer=nn.initializer.Uniform(-bound, bound))
self.conv = nn.Conv2D(
in_channels=ch_in,
out_channels=ch_out,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
weight_attr=param_attr,
bias_attr=bias_attr)
def forward(self, inputs):
out = self.conv(inputs)
return out
@register
class CenterNetHead(nn.Layer):
"""
Args:
in_channels (int): the channel number of input to CenterNetHead.
num_classes (int): the number of classes, 80 by default.
        head_planes (int): the channel number in all heads, 256 by default.
heatmap_weight (float): the weight of heatmap loss, 1 by default.
regress_ltrb (bool): whether to regress left/top/right/bottom or
width/height for a box, true by default
size_weight (float): the weight of box size loss, 0.1 by default.
offset_weight (float): the weight of center offset loss, 1 by default.
"""
__shared__ = ['num_classes']
def __init__(self,
in_channels,
num_classes=80,
head_planes=256,
heatmap_weight=1,
regress_ltrb=True,
size_weight=0.1,
offset_weight=1):
super(CenterNetHead, self).__init__()
self.weights = {
'heatmap': heatmap_weight,
'size': size_weight,
'offset': offset_weight
}
self.heatmap = nn.Sequential(
ConvLayer(
in_channels, head_planes, kernel_size=3, padding=1, bias=True),
nn.ReLU(),
ConvLayer(
head_planes,
num_classes,
kernel_size=1,
stride=1,
padding=0,
bias=True))
        # init the heatmap head bias so the initial foreground probability is
        # low (sigmoid(-2.19) is about 0.1), as in the CenterNet reference code
        self.heatmap[2].conv.bias[:] = -2.19
self.size = nn.Sequential(
ConvLayer(
in_channels, head_planes, kernel_size=3, padding=1, bias=True),
nn.ReLU(),
ConvLayer(
head_planes,
4 if regress_ltrb else 2,
kernel_size=1,
stride=1,
padding=0,
bias=True))
self.offset = nn.Sequential(
ConvLayer(
in_channels, head_planes, kernel_size=3, padding=1, bias=True),
nn.ReLU(),
ConvLayer(
head_planes, 2, kernel_size=1, stride=1, padding=0, bias=True))
self.focal_loss = CTFocalLoss()
@classmethod
def from_config(cls, cfg, input_shape):
if isinstance(input_shape, (list, tuple)):
input_shape = input_shape[0]
return {'in_channels': input_shape.channels}
def forward(self, feat, inputs):
heatmap = self.heatmap(feat)
size = self.size(feat)
offset = self.offset(feat)
if self.training:
loss = self.get_loss(heatmap, size, offset, self.weights, inputs)
return loss
else:
heatmap = F.sigmoid(heatmap)
return {'heatmap': heatmap, 'size': size, 'offset': offset}
def get_loss(self, heatmap, size, offset, weights, inputs):
heatmap_target = inputs['heatmap']
size_target = inputs['size']
offset_target = inputs['offset']
index = inputs['index']
mask = inputs['index_mask']
heatmap = paddle.clip(F.sigmoid(heatmap), 1e-4, 1 - 1e-4)
heatmap_loss = self.focal_loss(heatmap, heatmap_target)
size = paddle.transpose(size, perm=[0, 2, 3, 1])
size_n, size_h, size_w, size_c = size.shape
size = paddle.reshape(size, shape=[size_n, -1, size_c])
index = paddle.unsqueeze(index, 2)
batch_inds = list()
for i in range(size_n):
batch_ind = paddle.full(
shape=[1, index.shape[1], 1], fill_value=i, dtype='int64')
batch_inds.append(batch_ind)
batch_inds = paddle.concat(batch_inds, axis=0)
index = paddle.concat(x=[batch_inds, index], axis=2)
pos_size = paddle.gather_nd(size, index=index)
mask = paddle.unsqueeze(mask, axis=2)
size_mask = paddle.expand_as(mask, pos_size)
size_mask = paddle.cast(size_mask, dtype=pos_size.dtype)
pos_num = size_mask.sum()
size_mask.stop_gradient = True
size_target.stop_gradient = True
size_loss = F.l1_loss(
pos_size * size_mask, size_target * size_mask, reduction='sum')
size_loss = size_loss / (pos_num + 1e-4)
offset = paddle.transpose(offset, perm=[0, 2, 3, 1])
offset_n, offset_h, offset_w, offset_c = offset.shape
offset = paddle.reshape(offset, shape=[offset_n, -1, offset_c])
pos_offset = paddle.gather_nd(offset, index=index)
offset_mask = paddle.expand_as(mask, pos_offset)
offset_mask = paddle.cast(offset_mask, dtype=pos_offset.dtype)
pos_num = offset_mask.sum()
offset_mask.stop_gradient = True
offset_target.stop_gradient = True
offset_loss = F.l1_loss(
pos_offset * offset_mask,
offset_target * offset_mask,
reduction='sum')
offset_loss = offset_loss / (pos_num + 1e-4)
det_loss = weights['heatmap'] * heatmap_loss + weights[
'size'] * size_loss + weights['offset'] * offset_loss
return {
'det_loss': det_loss,
'heatmap_loss': heatmap_loss,
'size_loss': size_loss,
'offset_loss': offset_loss
}
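For reference, the reused `CTFocalLoss` follows the penalty-reduced pixel-wise focal loss of CornerNet/CenterNet, where $\hat{Y}$ is the predicted heatmap, $Y$ the gaussian-splatted target produced by `Gt2FairMOTTarget`, $N$ the number of object centers, and $\alpha = 2$, $\beta = 4$:

```math
L_{hm} = -\frac{1}{N} \sum_{xyc}
\begin{cases}
(1 - \hat{Y}_{xyc})^{\alpha} \log \hat{Y}_{xyc} & \text{if } Y_{xyc} = 1 \\
(1 - Y_{xyc})^{\beta}\, \hat{Y}_{xyc}^{\alpha} \log (1 - \hat{Y}_{xyc}) & \text{otherwise}
\end{cases}
```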
@@ -52,7 +52,9 @@ class DeformableConvV2(nn.Layer):
                  bias_attr=None,
                  lr_scale=1,
                  regularizer=None,
-                 skip_quant=False):
+                 skip_quant=False,
+                 dcn_bias_regularizer=L2Decay(0.),
+                 dcn_bias_lr_scale=2.):
         super(DeformableConvV2, self).__init__()
         self.offset_channel = 2 * kernel_size**2
         self.mask_channel = kernel_size**2

@@ -79,8 +81,8 @@ class DeformableConvV2(nn.Layer):
             # in FCOS-DCN head, specifically need learning_rate and regularizer
             dcn_bias_attr = ParamAttr(
                 initializer=Constant(value=0),
-                regularizer=L2Decay(0.),
-                learning_rate=2.)
+                regularizer=dcn_bias_regularizer,
+                learning_rate=dcn_bias_lr_scale)
         else:
             # in ResNet backbone, do not need bias
             dcn_bias_attr = False

@@ -122,7 +124,9 @@ class ConvNormLayer(nn.Layer):
                  freeze_norm=False,
                  initializer=Normal(
                      mean=0., std=0.01),
-                 skip_quant=False):
+                 skip_quant=False,
+                 dcn_lr_scale=2.,
+                 dcn_regularizer=L2Decay(0.)):
         super(ConvNormLayer, self).__init__()
         assert norm_type in ['bn', 'sync_bn', 'gn']

@@ -157,15 +161,17 @@ class ConvNormLayer(nn.Layer):
                 weight_attr=ParamAttr(
                     initializer=initializer, learning_rate=1.),
                 bias_attr=True,
-                lr_scale=2.,
-                regularizer=L2Decay(norm_decay),
+                lr_scale=dcn_lr_scale,
+                regularizer=dcn_regularizer,
                 skip_quant=skip_quant)
 
         norm_lr = 0. if freeze_norm else 1.
         param_attr = ParamAttr(
-            learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
+            learning_rate=norm_lr,
+            regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
         bias_attr = ParamAttr(
-            learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
+            learning_rate=norm_lr,
+            regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
         if norm_type == 'bn':
             self.norm = nn.BatchNorm2D(
                 ch_out, weight_attr=param_attr, bias_attr=bias_attr)
...
@@ -21,6 +21,7 @@ from . import solov2_loss
 from . import ctfocal_loss
 from . import keypoint_loss
 from . import jde_loss
+from . import fairmot_loss
 
 from .yolo_loss import *
 from .iou_aware_loss import *

@@ -31,3 +32,4 @@ from .solov2_loss import *
 from .ctfocal_loss import *
 from .keypoint_loss import *
 from .jde_loss import *
+from .fairmot_loss import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
from paddle.nn.initializer import Constant
from ppdet.core.workspace import register
__all__ = ['FairMOTLoss']
@register
class FairMOTLoss(nn.Layer):
def __init__(self):
super(FairMOTLoss, self).__init__()
self.det_weight = self.create_parameter(
shape=[1], default_initializer=Constant(-1.85))
self.reid_weight = self.create_parameter(
shape=[1], default_initializer=Constant(-1.05))
def forward(self, det_loss, reid_loss):
loss = paddle.exp(-self.det_weight) * det_loss + paddle.exp(
-self.reid_weight) * reid_loss + (self.det_weight + self.reid_weight
)
loss *= 0.5
return {'loss': loss}
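The forward pass above is the homoscedastic-uncertainty task weighting used by FairMOT: with learnable scalars $w_1$ (`det_weight`) and $w_2$ (`reid_weight`),

```math
L = \tfrac{1}{2} \left( e^{-w_1} L_{det} + e^{-w_2} L_{reid} + w_1 + w_2 \right),
```

so each task weight $e^{-w_i}$ is learned jointly with the model, and the $w_i$ terms keep those weights from growing unboundedly. With the initializations $-1.85$ and $-1.05$, detection starts off weighted more heavily ($e^{1.85} \approx 6.4$ vs. $e^{1.05} \approx 2.9$).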
@@ -46,6 +46,9 @@ class JDETracker(object):
         unconfirmed_thresh (float): linear assignment threshold of
             unconfirmed stracks and unmatched detections
         motion (object): KalmanFilter instance
+        conf_thres (float): confidence threshold for tracking
+        metric_type (str): either "euclidean" or "cosine", the distance metric
+            used for measurement to track association
     """
 
     def __init__(self,

@@ -55,7 +58,9 @@ class JDETracker(object):
                  tracked_thresh=0.7,
                  r_tracked_thresh=0.5,
                  unconfirmed_thresh=0.7,
-                 motion='KalmanFilter'):
+                 motion='KalmanFilter',
+                 conf_thres=0,
+                 metric_type='euclidean'):
         self.det_thresh = det_thresh
         self.track_buffer = track_buffer
         self.min_box_area = min_box_area

@@ -63,6 +68,8 @@ class JDETracker(object):
         self.r_tracked_thresh = r_tracked_thresh
         self.unconfirmed_thresh = unconfirmed_thresh
         self.motion = motion
+        self.conf_thres = conf_thres
+        self.metric_type = metric_type
 
         self.frame_id = 0
         self.tracked_stracks = []

@@ -96,6 +103,14 @@ class JDETracker(object):
         # removed. (Lost for some time lesser than the threshold for removing)
         removed_stracks = []
 
+        remain_inds = paddle.nonzero(pred_dets[:, 4] > self.conf_thres)
+        if remain_inds.shape[0] == 0:
+            pred_dets = paddle.zeros([0, 1])
+            pred_embs = paddle.zeros([0, 1])
+        else:
+            pred_dets = paddle.gather(pred_dets, remain_inds)
+            pred_embs = paddle.gather(pred_embs, remain_inds)
+
         # Filter out the image with box_num = 0. pred_dets = [[0.0, 0.0, 0.0 ,0.0]]
         empty_pred = True if len(pred_dets) == 1 and paddle.sum(
             pred_dets) == 0.0 else False

@@ -125,7 +140,8 @@ class JDETracker(object):
         # Predict the current location with KF
         STrack.multi_predict(strack_pool, self.motion)
 
-        dists = matching.embedding_distance(strack_pool, detections)
+        dists = matching.embedding_distance(
+            strack_pool, detections, metric=self.metric_type)
         dists = matching.fuse_motion(self.motion, dists, strack_pool,
                                      detections)
         # The dists is the list of distances of the detection with the tracks in strack_pool
...
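With `metric_type: cosine`, as the FairMOT config above sets, the appearance cost between a detection embedding $e_i$ and a track embedding $t_j$ is the cosine distance (assuming the usual `cdist`-style implementation inside `matching.embedding_distance`):

```math
d_{ij} = 1 - \frac{e_i^{\top} t_j}{\lVert e_i \rVert\, \lVert t_j \rVert},
```

which `fuse_motion` then combines with the Kalman-filter gating distance before linear assignment.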
@@ -16,8 +16,10 @@ from . import fpn
 from . import yolo_fpn
 from . import hrfpn
 from . import ttf_fpn
+from . import centernet_fpn
 
 from .fpn import *
 from .yolo_fpn import *
 from .hrfpn import *
 from .ttf_fpn import *
+from .centernet_fpn import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import math
import paddle
import paddle.nn.functional as F
from paddle import ParamAttr
import paddle.nn as nn
from paddle.nn.initializer import KaimingUniform
from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import ConvNormLayer
from ..shape_spec import ShapeSpec
def fill_up_weights(up):
    # initialize a Conv2DTranspose kernel with bilinear-upsampling weights;
    # every output channel gets the same separable bilinear filter
    weight = up.weight
f = math.ceil(weight.shape[2] / 2)
c = (2 * f - 1 - f % 2) / (2. * f)
for i in range(weight.shape[2]):
for j in range(weight.shape[3]):
weight[0, 0, i, j] = \
(1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
for c in range(1, weight.shape[0]):
weight[c, 0, :, :] = weight[0, 0, :, :]
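# A quick sanity check for fill_up_weights (illustrative numbers only, not
# executed by the model): for a stride-2 deconv, kernel_size = 4, so
#   f = math.ceil(4 / 2)                      # 2
#   c = (2 * f - 1 - f % 2) / (2. * f)        # 0.75
#   [1 - abs(i / f - c) for i in range(4)]    # [0.25, 0.75, 0.75, 0.25]
# and the kernel is the outer product of that vector with itself, i.e. the
# classic bilinear-upsampling initialization, copied to every channel.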
class IDAUp(nn.Layer):
def __init__(self, ch_ins, ch_out, up_strides, dcn_v2=True):
super(IDAUp, self).__init__()
for i in range(1, len(ch_ins)):
ch_in = ch_ins[i]
up_s = int(up_strides[i])
proj = nn.Sequential(
ConvNormLayer(
ch_in,
ch_out,
filter_size=3,
stride=1,
use_dcn=dcn_v2,
bias_on=dcn_v2,
norm_decay=None,
dcn_lr_scale=1.,
dcn_regularizer=None),
nn.ReLU())
node = nn.Sequential(
ConvNormLayer(
ch_out,
ch_out,
filter_size=3,
stride=1,
use_dcn=dcn_v2,
bias_on=dcn_v2,
norm_decay=None,
dcn_lr_scale=1.,
dcn_regularizer=None),
nn.ReLU())
param_attr = paddle.ParamAttr(initializer=KaimingUniform())
up = nn.Conv2DTranspose(
ch_out,
ch_out,
kernel_size=up_s * 2,
weight_attr=param_attr,
stride=up_s,
padding=up_s // 2,
groups=ch_out,
bias_attr=False)
# TODO: uncomment fill_up_weights
#fill_up_weights(up)
setattr(self, 'proj_' + str(i), proj)
setattr(self, 'up_' + str(i), up)
setattr(self, 'node_' + str(i), node)
def forward(self, inputs, start_level, end_level):
for i in range(start_level + 1, end_level):
upsample = getattr(self, 'up_' + str(i - start_level))
project = getattr(self, 'proj_' + str(i - start_level))
inputs[i] = project(inputs[i])
inputs[i] = upsample(inputs[i])
node = getattr(self, 'node_' + str(i - start_level))
inputs[i] = node(paddle.add(inputs[i], inputs[i - 1]))
class DLAUp(nn.Layer):
def __init__(self, start_level, channels, scales, ch_in=None, dcn_v2=True):
super(DLAUp, self).__init__()
self.start_level = start_level
if ch_in is None:
ch_in = channels
self.channels = channels
channels = list(channels)
scales = np.array(scales, dtype=int)
for i in range(len(channels) - 1):
j = -i - 2
setattr(
self,
'ida_{}'.format(i),
IDAUp(
ch_in[j:],
channels[j],
scales[j:] // scales[j],
dcn_v2=dcn_v2))
scales[j + 1:] = scales[j]
ch_in[j + 1:] = [channels[j] for _ in channels[j + 1:]]
def forward(self, inputs):
out = [inputs[-1]] # start with 32
for i in range(len(inputs) - self.start_level - 1):
ida = getattr(self, 'ida_{}'.format(i))
ida(inputs, len(inputs) - i - 2, len(inputs))
out.insert(0, inputs[-1])
return out
@register
@serializable
class CenterNetDLAFPN(nn.Layer):
"""
Args:
in_channels (list): number of input feature channels from backbone.
[16, 32, 64, 128, 256, 512] by default, means the channels of DLA-34
down_ratio (int): the down ratio from images to heatmap, 4 by default
        last_level (int): the last level of input feature fed into the upsampling block
out_channel (int): the channel of the output feature, 0 by default means
the channel of the input feature whose down ratio is `down_ratio`
dcn_v2 (bool): whether use the DCNv2, true by default
"""
def __init__(self,
in_channels,
down_ratio=4,
last_level=5,
out_channel=0,
dcn_v2=True):
super(CenterNetDLAFPN, self).__init__()
self.first_level = int(np.log2(down_ratio))
self.down_ratio = down_ratio
self.last_level = last_level
scales = [2**i for i in range(len(in_channels[self.first_level:]))]
self.dla_up = DLAUp(
self.first_level,
in_channels[self.first_level:],
scales,
dcn_v2=dcn_v2)
self.out_channel = out_channel
if out_channel == 0:
self.out_channel = in_channels[self.first_level]
self.ida_up = IDAUp(
in_channels[self.first_level:self.last_level],
self.out_channel,
[2**i for i in range(self.last_level - self.first_level)],
dcn_v2=dcn_v2)
@classmethod
def from_config(cls, cfg, input_shape):
return {'in_channels': [i.channels for i in input_shape]}
def forward(self, body_feats):
dla_up_feats = self.dla_up(body_feats)
ida_up_feats = []
for i in range(self.last_level - self.first_level):
ida_up_feats.append(dla_up_feats[i].clone())
self.ida_up(ida_up_feats, 0, len(ida_up_feats))
return ida_up_feats[-1]
@property
def out_shape(self):
return [ShapeSpec(channels=self.out_channel, stride=self.down_ratio)]
@@ -18,6 +18,7 @@ import paddle.nn as nn
 import paddle.nn.functional as F
 from ppdet.core.workspace import register
 from ppdet.modeling.bbox_utils import nonempty_bbox, rbox2poly, rbox2poly
+from ppdet.modeling.layers import TTFBox
 try:
     from collections.abc import Sequence
 except Exception:

@@ -29,6 +30,7 @@ __all__ = [
     'FCOSPostProcess',
     'S2ANetBBoxPostProcess',
     'JDEBBoxPostProcess',
+    'CenterNetPostProcess',
 ]

@@ -352,3 +354,95 @@ class JDEBBoxPostProcess(BBoxPostProcess):
             nms_keep_idx = paddle.to_tensor(np.array([[0]], dtype='int32'))
         return boxes_idx, bbox_pred, bbox_num, nms_keep_idx
@register
class CenterNetPostProcess(TTFBox):
"""
    Postprocess the model outputs to get the final prediction:
        1. Do NMS on the heatmap to get the top `max_per_img` bboxes.
        2. Decode bboxes using the center offset and box size.
        3. Rescale the decoded bboxes to the original image shape.
    Args:
        max_per_img(int): the maximum number of predicted objects in an image,
500 by default.
down_ratio(int): the down ratio from images to heatmap, 4 by default.
regress_ltrb (bool): whether to regress left/top/right/bottom or
width/height for a box, true by default.
for_mot (bool): whether return other features used in tracking model.
"""
__shared__ = ['down_ratio']
def __init__(self,
max_per_img=500,
down_ratio=4,
regress_ltrb=True,
for_mot=False):
super(TTFBox, self).__init__()
self.max_per_img = max_per_img
self.down_ratio = down_ratio
self.regress_ltrb = regress_ltrb
self.for_mot = for_mot
def __call__(self, hm, wh, reg, im_shape, scale_factor):
heat = self._simple_nms(hm)
scores, inds, clses, ys, xs = self._topk(heat)
scores = paddle.tensor.unsqueeze(scores, [1])
clses = paddle.tensor.unsqueeze(clses, [1])
reg_t = paddle.transpose(reg, [0, 2, 3, 1])
# Like TTFBox, batch size is 1.
# TODO: support batch size > 1
reg = paddle.reshape(reg_t, [-1, paddle.shape(reg_t)[-1]])
reg = paddle.gather(reg, inds)
xs = paddle.cast(xs, 'float32')
ys = paddle.cast(ys, 'float32')
xs = xs + reg[:, 0:1]
ys = ys + reg[:, 1:2]
wh_t = paddle.transpose(wh, [0, 2, 3, 1])
wh = paddle.reshape(wh_t, [-1, paddle.shape(wh_t)[-1]])
wh = paddle.gather(wh, inds)
if self.regress_ltrb:
x1 = xs - wh[:, 0:1]
y1 = ys - wh[:, 1:2]
x2 = xs + wh[:, 2:3]
y2 = ys + wh[:, 3:4]
else:
x1 = xs - wh[:, 0:1] / 2
y1 = ys - wh[:, 1:2] / 2
x2 = xs + wh[:, 0:1] / 2
y2 = ys + wh[:, 1:2] / 2
n, c, feat_h, feat_w = paddle.shape(hm)
padw = (feat_w * self.down_ratio - im_shape[0, 1]) / 2
padh = (feat_h * self.down_ratio - im_shape[0, 0]) / 2
x1 = x1 * self.down_ratio
y1 = y1 * self.down_ratio
x2 = x2 * self.down_ratio
y2 = y2 * self.down_ratio
x1 = x1 - padw
y1 = y1 - padh
x2 = x2 - padw
y2 = y2 - padh
bboxes = paddle.concat([x1, y1, x2, y2], axis=1)
scale_y = scale_factor[:, 0:1]
scale_x = scale_factor[:, 1:2]
scale_expand = paddle.concat(
[scale_x, scale_y, scale_x, scale_y], axis=1)
boxes_shape = paddle.shape(bboxes)
boxes_shape.stop_gradient = True
scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
bboxes = paddle.divide(bboxes, scale_expand)
if self.for_mot:
results = paddle.concat([bboxes, scores, clses], axis=1)
return results, inds
else:
results = paddle.concat([clses, scores, bboxes], axis=1)
return results, paddle.shape(results)[0:1]
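Summarizing the decode path above for one top-K center $(x_c, y_c)$ with predicted offset $(\delta_x, \delta_y)$ and ltrb size $(l, t, r, b)$, down ratio $d$, letterbox padding $(p_w, p_h)$, and scale factor $(s_x, s_y)$:

```math
x_1 = \frac{(x_c + \delta_x - l)\, d - p_w}{s_x}, \qquad
y_1 = \frac{(y_c + \delta_y - t)\, d - p_h}{s_y},
```

and analogously $x_2$, $y_2$ with $+r$ and $+b$; when `regress_ltrb` is false the head predicts $(w, h)$ instead and half-sizes are used.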
@@ -15,7 +15,9 @@
 from . import jde_embedding_head
 from . import pyramidal_embedding
 from . import resnet
+from . import fairmot_embedding_head
 
 from .jde_embedding_head import *
 from .pyramidal_embedding import *
 from .resnet import *
+from .fairmot_embedding_head import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import KaimingUniform, Uniform
from ppdet.core.workspace import register
from ppdet.modeling.heads.centernet_head import ConvLayer
__all__ = ['FairMOTEmbeddingHead']
@register
class FairMOTEmbeddingHead(nn.Layer):
"""
Args:
in_channels (int): the channel number of input to FairMOTEmbeddingHead.
        ch_head (int): the channel number of features before being fed into the
            embedding layer, 256 by default.
        ch_emb (int): the channel number of the embedding feature, 128 by default.
num_identifiers (int): the number of identifiers, 14455 by default.
"""
def __init__(self,
in_channels,
ch_head=256,
ch_emb=128,
num_identifiers=14455):
super(FairMOTEmbeddingHead, self).__init__()
self.reid = nn.Sequential(
ConvLayer(
in_channels, ch_head, kernel_size=3, padding=1, bias=True),
nn.ReLU(),
ConvLayer(
ch_head, ch_emb, kernel_size=1, stride=1, padding=0, bias=True))
param_attr = paddle.ParamAttr(initializer=KaimingUniform())
bound = 1 / math.sqrt(ch_emb)
bias_attr = paddle.ParamAttr(initializer=Uniform(-bound, bound))
self.classifier = nn.Linear(
ch_emb,
num_identifiers,
weight_attr=param_attr,
bias_attr=bias_attr)
self.reid_loss = nn.CrossEntropyLoss(ignore_index=-1, reduction='sum')
# When num_identifiers is 1, emb_scale is set as 1
self.emb_scale = math.sqrt(2) * math.log(
num_identifiers - 1) if num_identifiers > 1 else 1
@classmethod
def from_config(cls, cfg, input_shape):
if isinstance(input_shape, (list, tuple)):
input_shape = input_shape[0]
return {'in_channels': input_shape.channels}
def forward(self, feat, inputs):
reid_feat = self.reid(feat)
if self.training:
loss = self.get_loss(reid_feat, inputs)
return loss
else:
reid_feat = F.normalize(reid_feat)
return reid_feat
def get_loss(self, feat, inputs):
index = inputs['index']
mask = inputs['index_mask']
target = inputs['reid']
target = paddle.masked_select(target, mask > 0)
target = paddle.unsqueeze(target, 1)
feat = paddle.transpose(feat, perm=[0, 2, 3, 1])
feat_n, feat_h, feat_w, feat_c = feat.shape
feat = paddle.reshape(feat, shape=[feat_n, -1, feat_c])
index = paddle.unsqueeze(index, 2)
batch_inds = list()
for i in range(feat_n):
batch_ind = paddle.full(
shape=[1, index.shape[1], 1], fill_value=i, dtype='int64')
batch_inds.append(batch_ind)
batch_inds = paddle.concat(batch_inds, axis=0)
index = paddle.concat(x=[batch_inds, index], axis=2)
feat = paddle.gather_nd(feat, index=index)
mask = paddle.unsqueeze(mask, axis=2)
mask = paddle.expand_as(mask, feat)
mask.stop_gradient = True
feat = paddle.masked_select(feat, mask > 0)
feat = paddle.reshape(feat, shape=[-1, feat_c])
feat = F.normalize(feat)
feat = self.emb_scale * feat
logit = self.classifier(feat)
target.stop_gradient = True
loss = self.reid_loss(logit, target)
valid = (target != self.reid_loss.ignore_index)
valid.stop_gradient = True
count = paddle.sum((paddle.cast(valid, dtype=np.int32)))
count.stop_gradient = True
if count > 0:
loss = loss / count
return loss
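One detail worth noting: before the ID classifier, the gathered embeddings are L2-normalized and multiplied by `emb_scale`, as in the JDE/FairMOT reference implementations:

```math
s = \sqrt{2}\, \ln(N_{id} - 1), \qquad f' = s\, \frac{f}{\lVert f \rVert_2},
```

which keeps the logits in a range where the cross-entropy over $N_{id}$ identities trains stably.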