Unverified commit 277bc13f, authored by FlyingQianMM, committed by GitHub

[MOT] add FairMOT (#2994)

* add fairmot

* rename pretrain_weights

* add comments; reuse focal loss and convnorm

* update comments
Parent 5ad5a819
......@@ -12,7 +12,7 @@ MOTDataZoo: {
TrainDataset:
!MOTDataSet
dataset_dir: dataset/mot
image_lists: ['mot17.train', 'caltech.train', 'cuhksysu.train', 'prw.train', 'citypersons.train', 'eth.train']
image_lists: ['mot17.train', 'caltech.all', 'cuhksysu.train', 'prw.train', 'citypersons.train', 'eth.train']
data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
EvalDataset:
......
# FairMOT (FairMOT: On the Fairness of Detection and Re-Identification in Multiple Object Tracking)
## Table of Contents
- [Introduction](#Introduction)
- [Model Zoo](#Model_Zoo)
- [Getting Started](#Getting_Started)
- [Citations](#Citations)
## Introduction
FairMOT accomplishes detection and re-identification in a single network to improve inference speed. It presents a simple baseline consisting of two homogeneous branches that predict pixel-wise objectness scores and re-ID features. The fairness achieved between the two tasks allows FairMOT to reach high levels of detection and tracking accuracy.
## Model Zoo
### Results on MOT-16 train set
| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | download | config |
| :-----------------| :------- | :----: | :----: | :---: | :----: | :---: |:---: | :---: |
| DLA-34(paper) | 1088x608 | 83.3 | 81.9 | 544 | 3822 | 14095 | ---- | ---- |
| DLA-34 | 1088x608 | 83.4 | 82.7 | 517 | 4077 | 13761 | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml) |
### Results on MOT-16 test set
| backbone | input shape | MOTA | IDF1 | IDS | MT | ML | download | config |
| :-----------------| :------- | :----: | :----: | :---: | :----: | :---: | :---: | :---: |
| DLA-34(paper) | 1088x608 | 74.9 | 72.8 | 1074 | 44.7% | 15.9% | ---- | ---- |
| DLA-34 | 1088x608 | 74.7 | 72.8 | 1044 | 41.9% | 19.1% |[model](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml) |
**Notes:**
FairMOT was trained on 2 GPUs with a mini-batch size of 6 per GPU for 30 epochs.
## Getting Started
### 1. Training
Train FairMOT on 2 GPUs with the following command:
```bash
python -m paddle.distributed.launch --log_dir=./fairmot_dla34_30e_1088x608/ --gpus 0,1 tools/train.py -c configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml &>fairmot_dla34_30e_1088x608.log 2>&1 &
```
### 2. Evaluation
Evaluate the tracking performance of FairMOT on the val dataset on a single GPU with the following commands:
```bash
# use weights released in PaddleDetection model zoo
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608.pdparams
# use saved checkpoint in training
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml -o weights=output/fairmot_dla34_30e_1088x608/model_final
```
## Citations
```
@article{zhang2020fair,
title={FairMOT: On the Fairness of Detection and Re-Identification in Multiple Object Tracking},
author={Zhang, Yifu and Wang, Chunyu and Wang, Xinggang and Zeng, Wenjun and Liu, Wenyu},
journal={arXiv preprint arXiv:2004.01888},
year={2020}
}
```
architecture: FairMOT
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/fairmot_dla34_crowdhuman_pretrained.pdparams
FairMOT:
detector: CenterNet
reid: FairMOTEmbeddingHead
loss: FairMOTLoss
tracker: JDETracker
CenterNet:
backbone: DLA
neck: CenterNetDLAFPN
head: CenterNetHead
post_process: CenterNetPostProcess
for_mot: True
CenterNetPostProcess:
for_mot: True
JDETracker:
conf_thres: 0.4
tracked_thresh: 0.4
metric_type: cosine
worker_num: 4
TrainReader:
inputs_def:
image_shape: [3, 608, 1088]
sample_transforms:
- Decode: {to_rgb: False}
- AugmentHSV: {is_bgr: True}
- LetterBoxResize: {target_size: [608, 1088]}
- MOTRandomAffine: {reject_outside: False}
- RandomFlip: {}
- BboxXYXY2XYWH: {}
- NormalizeBox: {}
- NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1]}
- Permute: {to_rgb: True}
batch_transforms:
- Gt2FairMOTTarget: {}
batch_size: 6
shuffle: True
drop_last: True
EvalMOTReader:
inputs_def:
image_shape: [3, 608, 1088]
sample_transforms:
- Decode: {to_rgb: False}
- LetterBoxResize: {target_size: [608, 1088]}
- NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1]}
- Permute: {to_rgb: True}
batch_size: 1
epoch: 30
LearningRate:
base_lr: 0.0001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [20,]
use_warmup: False
OptimizerBuilder:
optimizer:
type: Adam
regularizer: NULL
_BASE_: [
'../../datasets/mot.yml',
'../../runtime.yml',
'_base_/optimizer_30e.yml',
'_base_/fairmot_dla34.yml',
'_base_/fairmot_reader_1088x608.yml',
]
metric: MOT
weights: output/fairmot_dla34_30e_1088x608/model_final
......@@ -347,6 +347,7 @@ class MOTVideoDataset(DetDataset):
def _load_video_images(self):
self.cap = cv2.VideoCapture(self.video_file)
self.vn = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
self.frame_rate = int(round(self.cap.get(cv2.CAP_PROP_FPS)))
logger.info('Length of the video: {:d} frames'.format(self.vn))
res = True
ct = 0
......
......@@ -25,14 +25,19 @@ from numbers import Integral
import cv2
import copy
import numpy as np
import math
from .operators import BaseOperator, register_op
from .batch_operators import Gt2TTFTarget
from ppdet.modeling.bbox_utils import bbox_iou_np_expand
from ppdet.core.workspace import serializable
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = ['LetterBoxResize', 'Gt2JDETargetThres', 'Gt2JDETargetMax']
__all__ = [
'LetterBoxResize', 'Gt2JDETargetThres', 'Gt2JDETargetMax',
'Gt2FairMOTTarget'
]
@register_op
......@@ -367,3 +372,117 @@ class Gt2JDETargetMax(BaseOperator):
sample['tbox{}'.format(i)] = tbox
sample['tconf{}'.format(i)] = tconf
sample['tide{}'.format(i)] = tid
class Gt2FairMOTTarget(Gt2TTFTarget):
"""
Generate FairMOT targets from ground truth data.
Differences between Gt2FairMOTTarget and Gt2TTFTarget:
1. the gaussian kernel radius used to generate the heatmap.
2. the targets needed during training.
Args:
num_classes(int): the number of classes.
down_ratio(int): the down ratio from images to heatmap, 4 by default.
max_objs(int): the maximum number of ground truth objects in an image, 500 by default.
"""
__shared__ = ['num_classes']
def __init__(self, num_classes=1, down_ratio=4, max_objs=500):
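# Intentionally skip Gt2TTFTarget.__init__ (its signature differs) and run BaseOperator's init.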
super(Gt2TTFTarget, self).__init__()
self.down_ratio = down_ratio
self.num_classes = num_classes
self.max_objs = max_objs
def __call__(self, samples, context=None):
for b_id, sample in enumerate(samples):
output_h = sample['image'].shape[1] // self.down_ratio
output_w = sample['image'].shape[2] // self.down_ratio
heatmap = np.zeros(
(self.num_classes, output_h, output_w), dtype='float32')
bbox_size = np.zeros((self.max_objs, 4), dtype=np.float32)
center_offset = np.zeros((self.max_objs, 2), dtype=np.float32)
index = np.zeros((self.max_objs, ), dtype=np.int64)
index_mask = np.zeros((self.max_objs, ), dtype=np.int32)
reid = np.zeros((self.max_objs, ), dtype=np.int64)
bbox_xys = np.zeros((self.max_objs, 4), dtype=np.float32)
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
gt_ide = sample['gt_ide']
for k in range(len(gt_bbox)):
cls_id = gt_class[k][0]
bbox = gt_bbox[k]
ide = gt_ide[k][0]
bbox[[0, 2]] = bbox[[0, 2]] * output_w
bbox[[1, 3]] = bbox[[1, 3]] * output_h
bbox_amodal = copy.deepcopy(bbox)
bbox_amodal[0] = bbox_amodal[0] - bbox_amodal[2] / 2.
bbox_amodal[1] = bbox_amodal[1] - bbox_amodal[3] / 2.
bbox_amodal[2] = bbox_amodal[0] + bbox_amodal[2]
bbox_amodal[3] = bbox_amodal[1] + bbox_amodal[3]
bbox[0] = np.clip(bbox[0], 0, output_w - 1)
bbox[1] = np.clip(bbox[1], 0, output_h - 1)
h = bbox[3]
w = bbox[2]
bbox_xy = copy.deepcopy(bbox)
bbox_xy[0] = bbox_xy[0] - bbox_xy[2] / 2
bbox_xy[1] = bbox_xy[1] - bbox_xy[3] / 2
bbox_xy[2] = bbox_xy[0] + bbox_xy[2]
bbox_xy[3] = bbox_xy[1] + bbox_xy[3]
if h > 0 and w > 0:
radius = self.gaussian_radius((math.ceil(h), math.ceil(w)))
radius = max(0, int(radius))
ct = np.array([bbox[0], bbox[1]], dtype=np.float32)
ct_int = ct.astype(np.int32)
self.draw_truncate_gaussian(heatmap[cls_id], ct_int, radius,
radius)
bbox_size[k] = ct[0] - bbox_amodal[0], ct[1] - bbox_amodal[1], \
bbox_amodal[2] - ct[0], bbox_amodal[3] - ct[1]
index[k] = ct_int[1] * output_w + ct_int[0]
center_offset[k] = ct - ct_int
index_mask[k] = 1
reid[k] = ide
bbox_xys[k] = bbox_xy
sample['heatmap'] = heatmap
sample['index'] = index
sample['offset'] = center_offset
sample['size'] = bbox_size
sample['index_mask'] = index_mask
sample['reid'] = reid
sample['bbox_xys'] = bbox_xys
sample.pop('is_crowd', None)
sample.pop('difficult', None)
sample.pop('gt_class', None)
sample.pop('gt_bbox', None)
sample.pop('gt_score', None)
sample.pop('gt_ide', None)
return samples
def gaussian_radius(self, det_size, min_overlap=0.7):
height, width = det_size
a1 = 1
b1 = (height + width)
c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
r1 = (b1 + sq1) / 2
a2 = 4
b2 = 2 * (height + width)
c2 = (1 - min_overlap) * width * height
sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
r2 = (b2 + sq2) / 2
a3 = 4 * min_overlap
b3 = -2 * min_overlap * (height + width)
c3 = (min_overlap - 1) * width * height
sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
r3 = (b3 + sq3) / 2
return min(r1, r2, r3)
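For intuition, the snippet below (a standalone sketch, not part of this PR) mirrors what the inherited `draw_truncate_gaussian` does with the radius computed above: it stamps a 2D Gaussian on the class heatmap and keeps the maximum where nearby objects overlap. The `sigma = diameter / 6` convention is an assumption borrowed from common CenterNet implementations, not read from this diff.

```python
import numpy as np

def draw_gaussian(heatmap, center, radius):
    # stamp a (2r+1) x (2r+1) Gaussian peak at `center` (x, y), in place
    diameter = 2 * radius + 1
    sigma = diameter / 6.0  # assumption: common CenterNet convention
    y, x = np.ogrid[-radius:radius + 1, -radius:radius + 1]
    gaussian = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    cx, cy = center
    h, w = heatmap.shape
    # clip the stamp so it stays inside the heatmap bounds
    left, right = min(cx, radius), min(w - cx, radius + 1)
    top, bottom = min(cy, radius), min(h - cy, radius + 1)
    region = heatmap[cy - top:cy + bottom, cx - left:cx + right]
    stamp = gaussian[radius - top:radius + bottom, radius - left:radius + right]
    np.maximum(region, stamp, out=region)  # keep the max where stamps overlap
    return heatmap

hm = draw_gaussian(np.zeros((152, 272), np.float32), center=(100, 60), radius=4)
print(hm.max())  # 1.0 at the object center
```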
......@@ -107,10 +107,12 @@ class BaseOperator(object):
@register_op
class Decode(BaseOperator):
def __init__(self):
def __init__(self, to_rgb=True):
""" Transform the image data to numpy format following the rgb format
"""
super(Decode, self).__init__()
# TODO: remove this parameter
self.to_rgb = to_rgb
def apply(self, sample, context=None):
""" load image if 'im_file' field is not empty but 'image' is"""
......@@ -124,7 +126,8 @@ class Decode(BaseOperator):
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
if 'keep_ori_im' in sample and sample['keep_ori_im']:
sample['ori_image'] = im
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
if self.to_rgb:
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
sample['image'] = im
if 'h' not in sample:
......@@ -151,14 +154,18 @@ class Decode(BaseOperator):
@register_op
class Permute(BaseOperator):
def __init__(self):
def __init__(self, to_rgb=False):
"""
Change the data layout to (C, H, W); convert BGR to RGB first when to_rgb is True
"""
super(Permute, self).__init__()
# TODO: remove this parameter
self.to_rgb = to_rgb
def apply(self, sample, context=None):
im = sample['image']
if self.to_rgb:
im = np.ascontiguousarray(im[:, :, ::-1])
im = im.transpose((2, 0, 1))
sample['image'] = im
return sample
......@@ -2076,29 +2083,39 @@ class Norm2PixelBbox(BaseOperator):
@register_op
class MOTRandomAffine(BaseOperator):
"""
Affine transform to image and coords to achieve the rotate, scale and
shift effect for training image.
Args:
degrees (list[2]): the rotate range to apply, transform range is [min, max]
translate (list[2]): the translate range to apply, transform range is [min, max]
scale (list[2]): the scale range to apply, transform range is [min, max]
shear (list[2]): the shear range to apply, transform range is [min, max]
borderValue (list[3]): value used in case of a constant border when applying
the perspective transformation
reject_outside (bool): whether to reject warped bounding boxes that fall outside the image
Returns:
records(dict): the image and coords after transformation
"""
def __init__(self,
degrees=(-5, 5),
translate=(0.10, 0.10),
scale=(0.50, 1.20),
shear=(-2, 2),
borderValue=(127.5, 127.5, 127.5)):
"""
Affine transform to image and coords to achieve the rotate, scale and
shift effect for training image.
borderValue=(127.5, 127.5, 127.5),
reject_outside=True):
Args:
degrees (tuple): rotation value
translate (tuple): xy coords translation value
scale (tuple): scale value
shear (tuple): shear value
borderValue (tuple): border color value
"""
super(MOTRandomAffine, self).__init__()
self.degrees = degrees
self.translate = translate
self.scale = scale
self.shear = shear
self.borderValue = borderValue
self.reject_outside = reject_outside
def apply(self, sample, context=None):
# https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
......@@ -2171,10 +2188,11 @@ class MOTRandomAffine(BaseOperator):
(x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T
# reject warped points outside of image
np.clip(xy[:, 0], 0, width, out=xy[:, 0])
np.clip(xy[:, 2], 0, width, out=xy[:, 2])
np.clip(xy[:, 1], 0, height, out=xy[:, 1])
np.clip(xy[:, 3], 0, height, out=xy[:, 3])
if self.reject_outside:
np.clip(xy[:, 0], 0, width, out=xy[:, 0])
np.clip(xy[:, 2], 0, width, out=xy[:, 2])
np.clip(xy[:, 1], 0, height, out=xy[:, 1])
np.clip(xy[:, 3], 0, height, out=xy[:, 3])
w = xy[:, 2] - xy[:, 0]
h = xy[:, 3] - xy[:, 1]
area = w * h
......
......@@ -66,6 +66,10 @@ class Trainer(object):
cfg['JDEEmbeddingHead'][
'num_identifiers'] = self.dataset.total_identities
if cfg.architecture == 'FairMOT' and self.mode == 'train':
cfg['FairMOTEmbeddingHead'][
'num_identifiers'] = self.dataset.total_identities
# build model
if 'model' not in self.cfg:
self.model = create(cfg.architecture)
......@@ -223,7 +227,10 @@ class Trainer(object):
return
self.start_epoch = 0
if hasattr(self.model, 'detector'):
load_pretrain_weight(self.model.detector, weights)
if self.model.__class__.__name__ == 'FairMOT':
load_pretrain_weight(self.model, weights)
else:
load_pretrain_weight(self.model.detector, weights)
else:
load_pretrain_weight(self.model, weights)
logger.debug("Load weights {} to start training".format(weights))
......
......@@ -19,6 +19,8 @@ from . import keypoint_hrhrnet
from . import keypoint_hrnet
from . import jde
from . import deepsort
from . import fairmot
from . import centernet
from .meta_arch import *
from .faster_rcnn import *
......@@ -34,3 +36,5 @@ from .keypoint_hrhrnet import *
from .keypoint_hrnet import *
from .jde import *
from .deepsort import *
from .fairmot import *
from .centernet import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
__all__ = ['CenterNet']
@register
class CenterNet(BaseArch):
"""
CenterNet network, see http://arxiv.org/abs/1904.07850
Args:
backbone (object): backbone instance
neck (object): 'CenterNetDLAFPN' instance
head (object): 'CenterNetHead' instance
post_process (object): 'CenterNetPostProcess' instance
for_mot (bool): whether to return extra features used by the tracking model
"""
__category__ = 'architecture'
__inject__ = ['post_process']
def __init__(self,
backbone='DLA',
neck='CenterNetDLAFPN',
head='CenterNetHead',
post_process='CenterNetPostProcess',
for_mot=False):
super(CenterNet, self).__init__()
self.backbone = backbone
self.neck = neck
self.head = head
self.post_process = post_process
self.for_mot = for_mot
@classmethod
def from_config(cls, cfg, *args, **kwargs):
backbone = create(cfg['backbone'])
kwargs = {'input_shape': backbone.out_shape}
neck = create(cfg['neck'], **kwargs)
kwargs = {'input_shape': neck.out_shape}
head = create(cfg['head'], **kwargs)
return {'backbone': backbone, 'neck': neck, "head": head}
def _forward(self):
body_feats = self.backbone(self.inputs)
neck_feat = self.neck(body_feats)
head_out = self.head(neck_feat, self.inputs)
if self.for_mot:
head_out.update({'neck_feat': neck_feat})
return head_out
def get_pred(self):
head_out = self._forward()
if self.for_mot:
bbox, bbox_inds = self.post_process(
head_out['heatmap'],
head_out['size'],
head_out['offset'],
im_shape=self.inputs['im_shape'],
scale_factor=self.inputs['scale_factor'])
output = {
"bbox": bbox,
"bbox_inds": bbox_inds,
"neck_feat": head_out['neck_feat']
}
else:
bbox, bbox_num = self.post_process(
head_out['heatmap'],
head_out['size'],
head_out['offset'],
im_shape=self.inputs['im_shape'],
scale_factor=self.inputs['scale_factor'])
output = {
"bbox": bbox,
"bbox_num": bbox_num,
}
return output
def get_loss(self):
return self._forward()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
__all__ = ['FairMOT']
@register
class FairMOT(BaseArch):
"""
FairMOT network, see http://arxiv.org/abs/2004.01888
Args:
detector (object): 'CenterNet' instance
reid (object): 'FairMOTEmbeddingHead' instance
tracker (object): 'JDETracker' instance
loss (object): 'FairMOTLoss' instance
"""
__category__ = 'architecture'
__inject__ = ['loss']
def __init__(self,
detector='CenterNet',
reid='FairMOTEmbeddingHead',
tracker='JDETracker',
loss='FairMOTLoss'):
super(FairMOT, self).__init__()
self.detector = detector
self.reid = reid
self.tracker = tracker
self.loss = loss
@classmethod
def from_config(cls, cfg, *args, **kwargs):
detector = create(cfg['detector'])
kwargs = {'input_shape': detector.neck.out_shape}
reid = create(cfg['reid'], **kwargs)
loss = create(cfg['loss'])
tracker = create(cfg['tracker'])
return {
'detector': detector,
'reid': reid,
'loss': loss,
'tracker': tracker
}
def _forward(self):
loss = dict()
# det_outs keys:
# train: det_loss, heatmap_loss, size_loss, offset_loss, neck_feat
# eval/infer: bbox, bbox_inds, neck_feat
det_outs = self.detector(self.inputs)
neck_feat = det_outs['neck_feat']
if self.training:
reid_loss = self.reid(neck_feat, self.inputs)
det_loss = det_outs['det_loss']
loss = self.loss(det_loss, reid_loss)
loss.update({
'heatmap_loss': det_outs['heatmap_loss'],
'size_loss': det_outs['size_loss'],
'offset_loss': det_outs['offset_loss'],
'reid_loss': reid_loss
})
return loss
else:
embedding = self.reid(neck_feat, self.inputs)
bbox_inds = det_outs['bbox_inds']
embedding = paddle.transpose(embedding, [0, 2, 3, 1])
embedding = paddle.reshape(embedding,
[-1, paddle.shape(embedding)[-1]])
id_feature = paddle.gather(embedding, bbox_inds)
dets = det_outs['bbox']
# Note: the tracker only considers batch_size=1 and num_classes=1
online_targets = self.tracker.update(dets, id_feature)
return online_targets
def get_pred(self):
output = self._forward()
return output
def get_loss(self):
loss = self._forward()
return loss
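As a plain-NumPy illustration (not PaddlePaddle code from this PR) of the eval branch above: the [N, C, H, W] re-ID map is flattened to [N*H*W, C] and the rows at `bbox_inds` (the flat heatmap positions of the kept detections) are picked out, one embedding per box.

```python
import numpy as np

# Shapes follow the config above: 128-d embeddings on a 152 x 272 stride-4 map.
n, c, h, w = 1, 128, 152, 272
embedding = np.random.rand(n, c, h, w).astype(np.float32)
bbox_inds = np.array([10, 999, 20000])  # flat heatmap positions of kept boxes

# transpose [N, C, H, W] -> [N, H, W, C], flatten to [N*H*W, C], gather rows
flat = embedding.transpose(0, 2, 3, 1).reshape(-1, c)
id_feature = flat[bbox_inds]
print(id_feature.shape)  # (3, 128): one re-ID vector per detection
```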
......@@ -22,6 +22,7 @@ from . import blazenet
from . import ghostnet
from . import senet
from . import res2net
from . import dla
from .vgg import *
from .resnet import *
......@@ -33,3 +34,4 @@ from .blazenet import *
from .ghostnet import *
from .senet import *
from .res2net import *
from .dla import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import ConvNormLayer
from ..shape_spec import ShapeSpec
DLA_cfg = {34: ([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512])}
class BasicBlock(nn.Layer):
def __init__(self, ch_in, ch_out, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = ConvNormLayer(
ch_in,
ch_out,
filter_size=3,
stride=stride,
bias_on=False,
norm_decay=None)
self.conv2 = ConvNormLayer(
ch_out,
ch_out,
filter_size=3,
stride=1,
bias_on=False,
norm_decay=None)
def forward(self, inputs, residual=None):
if residual is None:
residual = inputs
out = self.conv1(inputs)
out = F.relu(out)
out = self.conv2(out)
out = paddle.add(x=out, y=residual)
out = F.relu(out)
return out
class Root(nn.Layer):
def __init__(self, ch_in, ch_out, kernel_size, residual):
super(Root, self).__init__()
self.conv = ConvNormLayer(
ch_in,
ch_out,
filter_size=1,
stride=1,
bias_on=False,
norm_decay=None)
self.residual = residual
def forward(self, inputs):
children = inputs
out = self.conv(paddle.concat(inputs, axis=1))
if self.residual:
out = paddle.add(x=out, y=children[0])
out = F.relu(out)
return out
class Tree(nn.Layer):
def __init__(self,
level,
block,
ch_in,
ch_out,
stride=1,
level_root=False,
root_dim=0,
root_kernel_size=1,
root_residual=False):
super(Tree, self).__init__()
if root_dim == 0:
root_dim = 2 * ch_out
if level_root:
root_dim += ch_in
if level == 1:
self.tree1 = block(ch_in, ch_out, stride)
self.tree2 = block(ch_out, ch_out, 1)
else:
self.tree1 = Tree(
level - 1,
block,
ch_in,
ch_out,
stride,
root_dim=0,
root_kernel_size=root_kernel_size,
root_residual=root_residual)
self.tree2 = Tree(
level - 1,
block,
ch_out,
ch_out,
1,
root_dim=root_dim + ch_out,
root_kernel_size=root_kernel_size,
root_residual=root_residual)
if level == 1:
self.root = Root(root_dim, ch_out, root_kernel_size, root_residual)
self.level_root = level_root
self.root_dim = root_dim
self.downsample = None
self.project = None
self.level = level
if stride > 1:
self.downsample = nn.MaxPool2D(stride, stride=stride)
if ch_in != ch_out:
self.project = ConvNormLayer(
ch_in,
ch_out,
filter_size=1,
stride=1,
bias_on=False,
norm_decay=None)
def forward(self, x, residual=None, children=None):
children = [] if children is None else children
bottom = self.downsample(x) if self.downsample else x
residual = self.project(bottom) if self.project else bottom
if self.level_root:
children.append(bottom)
x1 = self.tree1(x, residual)
if self.level == 1:
x2 = self.tree2(x1)
x = self.root([x2, x1] + children)
else:
children.append(x1)
x = self.tree2(x1, children=children)
return x
@register
@serializable
class DLA(nn.Layer):
"""
DLA, see https://arxiv.org/pdf/1707.06484.pdf
Args:
depth (int): DLA depth, should be 34.
residual_root (bool): whether to use a residual layer in the root block
"""
def __init__(self, depth=34, residual_root=False):
super(DLA, self).__init__()
levels, channels = DLA_cfg[depth]
if depth == 34:
block = BasicBlock
self.channels = channels
self.base_layer = nn.Sequential(
ConvNormLayer(
3,
channels[0],
filter_size=7,
stride=1,
bias_on=False,
norm_decay=None),
nn.ReLU())
self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
self.level1 = self._make_conv_level(
channels[0], channels[1], levels[1], stride=2)
self.level2 = Tree(
levels[2],
block,
channels[1],
channels[2],
2,
level_root=False,
root_residual=residual_root)
self.level3 = Tree(
levels[3],
block,
channels[2],
channels[3],
2,
level_root=True,
root_residual=residual_root)
self.level4 = Tree(
levels[4],
block,
channels[3],
channels[4],
2,
level_root=True,
root_residual=residual_root)
self.level5 = Tree(
levels[5],
block,
channels[4],
channels[5],
2,
level_root=True,
root_residual=residual_root)
def _make_conv_level(self, ch_in, ch_out, conv_num, stride=1):
modules = []
for i in range(conv_num):
modules.extend([
ConvNormLayer(
ch_in,
ch_out,
filter_size=3,
stride=stride if i == 0 else 1,
bias_on=False,
norm_decay=None), nn.ReLU()
])
ch_in = ch_out
return nn.Sequential(*modules)
@property
def out_shape(self):
return [ShapeSpec(channels=self.channels[i]) for i in range(6)]
def forward(self, inputs):
outs = []
im = inputs['image']
feats = self.base_layer(im)
for i in range(6):
feats = getattr(self, 'level{}'.format(i))(feats)
outs.append(feats)
return outs
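A quick shape walk-through of this backbone for the FairMOT input: the base layer and level0 keep full resolution, level1 downsamples by 2, and each stride-2 Tree halves again, so level i has stride 2**i.

```python
# Feature pyramid DLA-34 emits for the 608x1088 FairMOT input.
# Strides follow the code above: level0 stride 1, level1 stride 2,
# then one stride-2 Tree per level up to level5 (stride 32).
channels = [16, 32, 64, 128, 256, 512]  # DLA_cfg[34]
h, w = 608, 1088
for level, ch in enumerate(channels):
    stride = 2 ** level
    print(f"level{level}: {ch} x {h // stride} x {w // stride}")
# level2 (stride 4, 152 x 272) matches the down_ratio=4 used by CenterNetDLAFPN.
```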
......@@ -24,6 +24,7 @@ from . import cascade_head
from . import face_head
from . import s2anet_head
from . import keypoint_hrhrnet_head
from . import centernet_head
from .bbox_head import *
from .mask_head import *
......@@ -37,3 +38,4 @@ from .cascade_head import *
from .face_head import *
from .s2anet_head import *
from .keypoint_hrhrnet_head import *
from .centernet_head import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import KaimingUniform
from ppdet.core.workspace import register
from ppdet.modeling.losses import CTFocalLoss
class ConvLayer(nn.Layer):
def __init__(self,
ch_in,
ch_out,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=False):
super(ConvLayer, self).__init__()
bias_attr = False
fan_in = ch_in * kernel_size**2
bound = 1 / math.sqrt(fan_in)
param_attr = paddle.ParamAttr(initializer=KaimingUniform())
if bias:
bias_attr = paddle.ParamAttr(
initializer=nn.initializer.Uniform(-bound, bound))
self.conv = nn.Conv2D(
in_channels=ch_in,
out_channels=ch_out,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
weight_attr=param_attr,
bias_attr=bias_attr)
def forward(self, inputs):
out = self.conv(inputs)
return out
@register
class CenterNetHead(nn.Layer):
"""
Args:
in_channels (int): the channel number of input to CenterNetHead.
num_classes (int): the number of classes, 80 by default.
head_planes (int): the channel number used in all heads, 256 by default.
heatmap_weight (float): the weight of heatmap loss, 1 by default.
regress_ltrb (bool): whether to regress left/top/right/bottom or
width/height for a box, true by default
size_weight (float): the weight of box size loss, 0.1 by default.
offset_weight (float): the weight of center offset loss, 1 by default.
"""
__shared__ = ['num_classes']
def __init__(self,
in_channels,
num_classes=80,
head_planes=256,
heatmap_weight=1,
regress_ltrb=True,
size_weight=0.1,
offset_weight=1):
super(CenterNetHead, self).__init__()
self.weights = {
'heatmap': heatmap_weight,
'size': size_weight,
'offset': offset_weight
}
self.heatmap = nn.Sequential(
ConvLayer(
in_channels, head_planes, kernel_size=3, padding=1, bias=True),
nn.ReLU(),
ConvLayer(
head_planes,
num_classes,
kernel_size=1,
stride=1,
padding=0,
bias=True))
self.heatmap[2].conv.bias[:] = -2.19
self.size = nn.Sequential(
ConvLayer(
in_channels, head_planes, kernel_size=3, padding=1, bias=True),
nn.ReLU(),
ConvLayer(
head_planes,
4 if regress_ltrb else 2,
kernel_size=1,
stride=1,
padding=0,
bias=True))
self.offset = nn.Sequential(
ConvLayer(
in_channels, head_planes, kernel_size=3, padding=1, bias=True),
nn.ReLU(),
ConvLayer(
head_planes, 2, kernel_size=1, stride=1, padding=0, bias=True))
self.focal_loss = CTFocalLoss()
@classmethod
def from_config(cls, cfg, input_shape):
if isinstance(input_shape, (list, tuple)):
input_shape = input_shape[0]
return {'in_channels': input_shape.channels}
def forward(self, feat, inputs):
heatmap = self.heatmap(feat)
size = self.size(feat)
offset = self.offset(feat)
if self.training:
loss = self.get_loss(heatmap, size, offset, self.weights, inputs)
return loss
else:
heatmap = F.sigmoid(heatmap)
return {'heatmap': heatmap, 'size': size, 'offset': offset}
def get_loss(self, heatmap, size, offset, weights, inputs):
heatmap_target = inputs['heatmap']
size_target = inputs['size']
offset_target = inputs['offset']
index = inputs['index']
mask = inputs['index_mask']
heatmap = paddle.clip(F.sigmoid(heatmap), 1e-4, 1 - 1e-4)
heatmap_loss = self.focal_loss(heatmap, heatmap_target)
size = paddle.transpose(size, perm=[0, 2, 3, 1])
size_n, size_h, size_w, size_c = size.shape
size = paddle.reshape(size, shape=[size_n, -1, size_c])
index = paddle.unsqueeze(index, 2)
batch_inds = list()
for i in range(size_n):
batch_ind = paddle.full(
shape=[1, index.shape[1], 1], fill_value=i, dtype='int64')
batch_inds.append(batch_ind)
batch_inds = paddle.concat(batch_inds, axis=0)
index = paddle.concat(x=[batch_inds, index], axis=2)
pos_size = paddle.gather_nd(size, index=index)
mask = paddle.unsqueeze(mask, axis=2)
size_mask = paddle.expand_as(mask, pos_size)
size_mask = paddle.cast(size_mask, dtype=pos_size.dtype)
pos_num = size_mask.sum()
size_mask.stop_gradient = True
size_target.stop_gradient = True
size_loss = F.l1_loss(
pos_size * size_mask, size_target * size_mask, reduction='sum')
size_loss = size_loss / (pos_num + 1e-4)
offset = paddle.transpose(offset, perm=[0, 2, 3, 1])
offset_n, offset_h, offset_w, offset_c = offset.shape
offset = paddle.reshape(offset, shape=[offset_n, -1, offset_c])
pos_offset = paddle.gather_nd(offset, index=index)
offset_mask = paddle.expand_as(mask, pos_offset)
offset_mask = paddle.cast(offset_mask, dtype=pos_offset.dtype)
pos_num = offset_mask.sum()
offset_mask.stop_gradient = True
offset_target.stop_gradient = True
offset_loss = F.l1_loss(
pos_offset * offset_mask,
offset_target * offset_mask,
reduction='sum')
offset_loss = offset_loss / (pos_num + 1e-4)
det_loss = weights['heatmap'] * heatmap_loss + weights[
'size'] * size_loss + weights['offset'] * offset_loss
return {
'det_loss': det_loss,
'heatmap_loss': heatmap_loss,
'size_loss': size_loss,
'offset_loss': offset_loss
}
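The batch-index construction in `get_loss` is easy to miss; the NumPy analogue below (illustration only, with hypothetical numbers) shows why `paddle.gather_nd` needs [batch_id, flat_position] pairs to pull each object's prediction out of the [N, H*W, C] tensor.

```python
import numpy as np

n, hw, c, max_objs = 2, 6, 4, 3
size = np.arange(n * hw * c, dtype=np.float32).reshape(n, hw, c)
index = np.array([[1, 3, 0], [2, 2, 5]])  # flat y*W+x positions, [N, max_objs]

# pair every position with its batch id, like the batch_inds loop above
batch_inds = np.repeat(np.arange(n), max_objs).reshape(n, max_objs)
gather_index = np.stack([batch_inds, index], axis=-1)  # [N, max_objs, 2]
pos_size = size[gather_index[..., 0], gather_index[..., 1]]
print(pos_size.shape)  # (2, 3, 4): one 4-d size target slot per object
```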
......@@ -52,7 +52,9 @@ class DeformableConvV2(nn.Layer):
bias_attr=None,
lr_scale=1,
regularizer=None,
skip_quant=False):
skip_quant=False,
dcn_bias_regularizer=L2Decay(0.),
dcn_bias_lr_scale=2.):
super(DeformableConvV2, self).__init__()
self.offset_channel = 2 * kernel_size**2
self.mask_channel = kernel_size**2
......@@ -79,8 +81,8 @@ class DeformableConvV2(nn.Layer):
# in FCOS-DCN head, specifically need learning_rate and regularizer
dcn_bias_attr = ParamAttr(
initializer=Constant(value=0),
regularizer=L2Decay(0.),
learning_rate=2.)
regularizer=dcn_bias_regularizer,
learning_rate=dcn_bias_lr_scale)
else:
# in ResNet backbone, do not need bias
dcn_bias_attr = False
......@@ -122,7 +124,9 @@ class ConvNormLayer(nn.Layer):
freeze_norm=False,
initializer=Normal(
mean=0., std=0.01),
skip_quant=False):
skip_quant=False,
dcn_lr_scale=2.,
dcn_regularizer=L2Decay(0.)):
super(ConvNormLayer, self).__init__()
assert norm_type in ['bn', 'sync_bn', 'gn']
......@@ -157,15 +161,17 @@ class ConvNormLayer(nn.Layer):
weight_attr=ParamAttr(
initializer=initializer, learning_rate=1.),
bias_attr=True,
lr_scale=2.,
regularizer=L2Decay(norm_decay),
lr_scale=dcn_lr_scale,
regularizer=dcn_regularizer,
skip_quant=skip_quant)
norm_lr = 0. if freeze_norm else 1.
param_attr = ParamAttr(
learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
bias_attr = ParamAttr(
learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
if norm_type == 'bn':
self.norm = nn.BatchNorm2D(
ch_out, weight_attr=param_attr, bias_attr=bias_attr)
......
......@@ -21,6 +21,7 @@ from . import solov2_loss
from . import ctfocal_loss
from . import keypoint_loss
from . import jde_loss
from . import fairmot_loss
from .yolo_loss import *
from .iou_aware_loss import *
......@@ -31,3 +32,4 @@ from .solov2_loss import *
from .ctfocal_loss import *
from .keypoint_loss import *
from .jde_loss import *
from .fairmot_loss import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
from paddle.nn.initializer import Constant
from ppdet.core.workspace import register
__all__ = ['FairMOTLoss']
@register
class FairMOTLoss(nn.Layer):
def __init__(self):
super(FairMOTLoss, self).__init__()
self.det_weight = self.create_parameter(
shape=[1], default_initializer=Constant(-1.85))
self.reid_weight = self.create_parameter(
shape=[1], default_initializer=Constant(-1.05))
def forward(self, det_loss, reid_loss):
loss = paddle.exp(-self.det_weight) * det_loss + paddle.exp(
-self.reid_weight) * reid_loss + (self.det_weight + self.reid_weight
)
loss *= 0.5
return {'loss': loss}
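This matches the learnable uncertainty weighting of Kendall et al. ("Multi-Task Learning Using Uncertainty to Weigh Losses"), with the two parameters acting as log-variances. A quick numeric check of the initial balance, assuming unit losses from both branches:

```python
import math

det_w, reid_w = -1.85, -1.05          # the Constant initializers above
det_loss, reid_loss = 1.0, 1.0        # pretend both branches emit unit loss
loss = 0.5 * (math.exp(-det_w) * det_loss
              + math.exp(-reid_w) * reid_loss
              + det_w + reid_w)
print(math.exp(-det_w), math.exp(-reid_w))  # ~6.36 vs ~2.86: detection starts heavier
print(loss)                                 # ~3.16
```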
......@@ -46,6 +46,9 @@ class JDETracker(object):
unconfirmed_thresh (float): linear assignment threshold of
unconfirmed stracks and unmatched detections
motion (object): KalmanFilter instance
conf_thres (float): confidence threshold for tracking
metric_type (str): either "euclidean" or "cosine", the distance metric
used for measurement to track association.
"""
def __init__(self,
......@@ -55,7 +58,9 @@ class JDETracker(object):
tracked_thresh=0.7,
r_tracked_thresh=0.5,
unconfirmed_thresh=0.7,
motion='KalmanFilter'):
motion='KalmanFilter',
conf_thres=0,
metric_type='euclidean'):
self.det_thresh = det_thresh
self.track_buffer = track_buffer
self.min_box_area = min_box_area
......@@ -63,6 +68,8 @@ class JDETracker(object):
self.r_tracked_thresh = r_tracked_thresh
self.unconfirmed_thresh = unconfirmed_thresh
self.motion = motion
self.conf_thres = conf_thres
self.metric_type = metric_type
self.frame_id = 0
self.tracked_stracks = []
......@@ -96,6 +103,14 @@ class JDETracker(object):
# removed. (Lost for some time lesser than the threshold for removing)
removed_stracks = []
remain_inds = paddle.nonzero(pred_dets[:, 4] > self.conf_thres)
if remain_inds.shape[0] == 0:
pred_dets = paddle.zeros([0, 1])
pred_embs = paddle.zeros([0, 1])
else:
pred_dets = paddle.gather(pred_dets, remain_inds)
pred_embs = paddle.gather(pred_embs, remain_inds)
# Filter out images with box_num = 0. pred_dets = [[0.0, 0.0, 0.0, 0.0]]
empty_pred = True if len(pred_dets) == 1 and paddle.sum(
pred_dets) == 0.0 else False
......@@ -125,7 +140,8 @@ class JDETracker(object):
# Predict the current location with KF
STrack.multi_predict(strack_pool, self.motion)
dists = matching.embedding_distance(strack_pool, detections)
dists = matching.embedding_distance(
strack_pool, detections, metric=self.metric_type)
dists = matching.fuse_motion(self.motion, dists, strack_pool,
detections)
# The dists is the list of distances of the detection with the tracks in strack_pool
......
......@@ -16,8 +16,10 @@ from . import fpn
from . import yolo_fpn
from . import hrfpn
from . import ttf_fpn
from . import centernet_fpn
from .fpn import *
from .yolo_fpn import *
from .hrfpn import *
from .ttf_fpn import *
from .centernet_fpn import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import math
import paddle
import paddle.nn.functional as F
from paddle import ParamAttr
import paddle.nn as nn
from paddle.nn.initializer import KaimingUniform
from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import ConvNormLayer
from ..shape_spec import ShapeSpec
def fill_up_weights(up):
weight = up.weight
f = math.ceil(weight.shape[2] / 2)
c = (2 * f - 1 - f % 2) / (2. * f)
for i in range(weight.shape[2]):
for j in range(weight.shape[3]):
weight[0, 0, i, j] = \
(1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
for c in range(1, weight.shape[0]):
weight[c, 0, :, :] = weight[0, 0, :, :]
class IDAUp(nn.Layer):
def __init__(self, ch_ins, ch_out, up_strides, dcn_v2=True):
super(IDAUp, self).__init__()
for i in range(1, len(ch_ins)):
ch_in = ch_ins[i]
up_s = int(up_strides[i])
proj = nn.Sequential(
ConvNormLayer(
ch_in,
ch_out,
filter_size=3,
stride=1,
use_dcn=dcn_v2,
bias_on=dcn_v2,
norm_decay=None,
dcn_lr_scale=1.,
dcn_regularizer=None),
nn.ReLU())
node = nn.Sequential(
ConvNormLayer(
ch_out,
ch_out,
filter_size=3,
stride=1,
use_dcn=dcn_v2,
bias_on=dcn_v2,
norm_decay=None,
dcn_lr_scale=1.,
dcn_regularizer=None),
nn.ReLU())
param_attr = paddle.ParamAttr(initializer=KaimingUniform())
up = nn.Conv2DTranspose(
ch_out,
ch_out,
kernel_size=up_s * 2,
weight_attr=param_attr,
stride=up_s,
padding=up_s // 2,
groups=ch_out,
bias_attr=False)
# TODO: uncomment fill_up_weights
#fill_up_weights(up)
setattr(self, 'proj_' + str(i), proj)
setattr(self, 'up_' + str(i), up)
setattr(self, 'node_' + str(i), node)
def forward(self, inputs, start_level, end_level):
for i in range(start_level + 1, end_level):
upsample = getattr(self, 'up_' + str(i - start_level))
project = getattr(self, 'proj_' + str(i - start_level))
inputs[i] = project(inputs[i])
inputs[i] = upsample(inputs[i])
node = getattr(self, 'node_' + str(i - start_level))
inputs[i] = node(paddle.add(inputs[i], inputs[i - 1]))
class DLAUp(nn.Layer):
def __init__(self, start_level, channels, scales, ch_in=None, dcn_v2=True):
super(DLAUp, self).__init__()
self.start_level = start_level
if ch_in is None:
ch_in = channels
self.channels = channels
channels = list(channels)
scales = np.array(scales, dtype=int)
for i in range(len(channels) - 1):
j = -i - 2
setattr(
self,
'ida_{}'.format(i),
IDAUp(
ch_in[j:],
channels[j],
scales[j:] // scales[j],
dcn_v2=dcn_v2))
scales[j + 1:] = scales[j]
ch_in[j + 1:] = [channels[j] for _ in channels[j + 1:]]
def forward(self, inputs):
out = [inputs[-1]] # start with 32
for i in range(len(inputs) - self.start_level - 1):
ida = getattr(self, 'ida_{}'.format(i))
ida(inputs, len(inputs) - i - 2, len(inputs))
out.insert(0, inputs[-1])
return out
@register
@serializable
class CenterNetDLAFPN(nn.Layer):
"""
Args:
in_channels (list): number of input feature channels from backbone.
[16, 32, 64, 128, 256, 512] by default, means the channels of DLA-34
down_ratio (int): the down ratio from images to heatmap, 4 by default
last_level (int): the last level of input feature fed into the upsampling block
out_channel (int): the channel of the output feature, 0 by default means
the channel of the input feature whose down ratio is `down_ratio`
dcn_v2 (bool): whether use the DCNv2, true by default
"""
def __init__(self,
in_channels,
down_ratio=4,
last_level=5,
out_channel=0,
dcn_v2=True):
super(CenterNetDLAFPN, self).__init__()
self.first_level = int(np.log2(down_ratio))
self.down_ratio = down_ratio
self.last_level = last_level
scales = [2**i for i in range(len(in_channels[self.first_level:]))]
self.dla_up = DLAUp(
self.first_level,
in_channels[self.first_level:],
scales,
dcn_v2=dcn_v2)
self.out_channel = out_channel
if out_channel == 0:
self.out_channel = in_channels[self.first_level]
self.ida_up = IDAUp(
in_channels[self.first_level:self.last_level],
self.out_channel,
[2**i for i in range(self.last_level - self.first_level)],
dcn_v2=dcn_v2)
@classmethod
def from_config(cls, cfg, input_shape):
return {'in_channels': [i.channels for i in input_shape]}
def forward(self, body_feats):
dla_up_feats = self.dla_up(body_feats)
ida_up_feats = []
for i in range(self.last_level - self.first_level):
ida_up_feats.append(dla_up_feats[i].clone())
self.ida_up(ida_up_feats, 0, len(ida_up_feats))
return ida_up_feats[-1]
@property
def out_shape(self):
return [ShapeSpec(channels=self.out_channel, stride=self.down_ratio)]
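For reference, this is the bilinear kernel that fill_up_weights would write into the stride-2 transposed conv; it is currently commented out in IDAUp (see the TODO), so the deconv keeps its KaimingUniform init. A standalone re-run of the same formula:

```python
import math
import numpy as np

k = 4                                # kernel_size = up_s * 2 with up_s = 2
f = math.ceil(k / 2)                 # 2
c = (2 * f - 1 - f % 2) / (2.0 * f)  # 0.75
w = np.zeros((k, k))
for i in range(k):
    for j in range(k):
        w[i, j] = (1 - abs(i / f - c)) * (1 - abs(j / f - c))
print(w)
# [[0.0625 0.1875 0.1875 0.0625]
#  [0.1875 0.5625 0.5625 0.1875]
#  [0.1875 0.5625 0.5625 0.1875]
#  [0.0625 0.1875 0.1875 0.0625]]  <- the classic bilinear upsampling taps
```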
......@@ -18,6 +18,7 @@ import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ppdet.modeling.bbox_utils import nonempty_bbox, rbox2poly
from ppdet.modeling.layers import TTFBox
try:
from collections.abc import Sequence
except Exception:
......@@ -29,6 +30,7 @@ __all__ = [
'FCOSPostProcess',
'S2ANetBBoxPostProcess',
'JDEBBoxPostProcess',
'CenterNetPostProcess',
]
......@@ -352,3 +354,95 @@ class JDEBBoxPostProcess(BBoxPostProcess):
nms_keep_idx = paddle.to_tensor(np.array([[0]], dtype='int32'))
return boxes_idx, bbox_pred, bbox_num, nms_keep_idx
@register
class CenterNetPostProcess(TTFBox):
"""
Postprocess the model outputs to get final prediction:
1. Do NMS for heatmap to get top `max_per_img` bboxes.
2. Decode bboxes using center offset and box size.
3. Rescale the decoded bboxes with reference to the original image shape.
Args:
max_per_img(int): the maximum number of predicted objects in an image,
500 by default.
down_ratio(int): the down ratio from images to heatmap, 4 by default.
regress_ltrb (bool): whether to regress left/top/right/bottom or
width/height for a box, true by default.
for_mot (bool): whether to return extra features used by the tracking model.
"""
__shared__ = ['down_ratio']
def __init__(self,
max_per_img=500,
down_ratio=4,
regress_ltrb=True,
for_mot=False):
super(TTFBox, self).__init__()
self.max_per_img = max_per_img
self.down_ratio = down_ratio
self.regress_ltrb = regress_ltrb
self.for_mot = for_mot
def __call__(self, hm, wh, reg, im_shape, scale_factor):
heat = self._simple_nms(hm)
scores, inds, clses, ys, xs = self._topk(heat)
scores = paddle.tensor.unsqueeze(scores, [1])
clses = paddle.tensor.unsqueeze(clses, [1])
reg_t = paddle.transpose(reg, [0, 2, 3, 1])
# Like TTFBox, batch size is 1.
# TODO: support batch size > 1
reg = paddle.reshape(reg_t, [-1, paddle.shape(reg_t)[-1]])
reg = paddle.gather(reg, inds)
xs = paddle.cast(xs, 'float32')
ys = paddle.cast(ys, 'float32')
xs = xs + reg[:, 0:1]
ys = ys + reg[:, 1:2]
wh_t = paddle.transpose(wh, [0, 2, 3, 1])
wh = paddle.reshape(wh_t, [-1, paddle.shape(wh_t)[-1]])
wh = paddle.gather(wh, inds)
if self.regress_ltrb:
x1 = xs - wh[:, 0:1]
y1 = ys - wh[:, 1:2]
x2 = xs + wh[:, 2:3]
y2 = ys + wh[:, 3:4]
else:
x1 = xs - wh[:, 0:1] / 2
y1 = ys - wh[:, 1:2] / 2
x2 = xs + wh[:, 0:1] / 2
y2 = ys + wh[:, 1:2] / 2
n, c, feat_h, feat_w = paddle.shape(hm)
padw = (feat_w * self.down_ratio - im_shape[0, 1]) / 2
padh = (feat_h * self.down_ratio - im_shape[0, 0]) / 2
x1 = x1 * self.down_ratio
y1 = y1 * self.down_ratio
x2 = x2 * self.down_ratio
y2 = y2 * self.down_ratio
x1 = x1 - padw
y1 = y1 - padh
x2 = x2 - padw
y2 = y2 - padh
bboxes = paddle.concat([x1, y1, x2, y2], axis=1)
scale_y = scale_factor[:, 0:1]
scale_x = scale_factor[:, 1:2]
scale_expand = paddle.concat(
[scale_x, scale_y, scale_x, scale_y], axis=1)
boxes_shape = paddle.shape(bboxes)
boxes_shape.stop_gradient = True
scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
bboxes = paddle.divide(bboxes, scale_expand)
if self.for_mot:
results = paddle.concat([bboxes, scores, clses], axis=1)
return results, inds
else:
results = paddle.concat([clses, scores, bboxes], axis=1)
return results, paddle.shape(results)[0:1]
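Plain arithmetic for the pad/scale removal above, with hypothetical numbers (an image recorded as 600x1080 h x w in im_shape, scale_factor 0.5 in both axes, and a 152x272 heatmap) just to exercise the formulas:

```python
down_ratio = 4
feat_h, feat_w = 152, 272
im_h, im_w = 600.0, 1080.0                    # hypothetical im_shape entry
padw = (feat_w * down_ratio - im_w) / 2       # (1088 - 1080) / 2 = 4.0
padh = (feat_h * down_ratio - im_h) / 2       # (608 - 600) / 2 = 4.0
x1_feat, scale_x = 50.0, 0.5                  # a decoded left edge, hypothetical scale
x1 = (x1_feat * down_ratio - padw) / scale_x  # 392.0 px in the original image
print(padw, padh, x1)
```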
......@@ -15,7 +15,9 @@
from . import jde_embedding_head
from . import pyramidal_embedding
from . import resnet
from . import fairmot_embedding_head
from .jde_embedding_head import *
from .pyramidal_embedding import *
from .resnet import *
from .fairmot_embedding_head import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import KaimingUniform, Uniform
from ppdet.core.workspace import register
from ppdet.modeling.heads.centernet_head import ConvLayer
__all__ = ['FairMOTEmbeddingHead']
@register
class FairMOTEmbeddingHead(nn.Layer):
"""
Args:
in_channels (int): the channel number of input to FairMOTEmbeddingHead.
ch_head (int): the channel of features before being fed into the embedding, 256 by default.
ch_emb (int): the channel of the embedding feature, 128 by default.
num_identifiers (int): the number of identifiers, 14455 by default.
"""
def __init__(self,
in_channels,
ch_head=256,
ch_emb=128,
num_identifiers=14455):
super(FairMOTEmbeddingHead, self).__init__()
self.reid = nn.Sequential(
ConvLayer(
in_channels, ch_head, kernel_size=3, padding=1, bias=True),
nn.ReLU(),
ConvLayer(
ch_head, ch_emb, kernel_size=1, stride=1, padding=0, bias=True))
param_attr = paddle.ParamAttr(initializer=KaimingUniform())
bound = 1 / math.sqrt(ch_emb)
bias_attr = paddle.ParamAttr(initializer=Uniform(-bound, bound))
self.classifier = nn.Linear(
ch_emb,
num_identifiers,
weight_attr=param_attr,
bias_attr=bias_attr)
self.reid_loss = nn.CrossEntropyLoss(ignore_index=-1, reduction='sum')
# When num_identifiers is 1, emb_scale is set as 1
self.emb_scale = math.sqrt(2) * math.log(
num_identifiers - 1) if num_identifiers > 1 else 1
@classmethod
def from_config(cls, cfg, input_shape):
if isinstance(input_shape, (list, tuple)):
input_shape = input_shape[0]
return {'in_channels': input_shape.channels}
def forward(self, feat, inputs):
reid_feat = self.reid(feat)
if self.training:
loss = self.get_loss(reid_feat, inputs)
return loss
else:
reid_feat = F.normalize(reid_feat)
return reid_feat
def get_loss(self, feat, inputs):
index = inputs['index']
mask = inputs['index_mask']
target = inputs['reid']
target = paddle.masked_select(target, mask > 0)
target = paddle.unsqueeze(target, 1)
feat = paddle.transpose(feat, perm=[0, 2, 3, 1])
feat_n, feat_h, feat_w, feat_c = feat.shape
feat = paddle.reshape(feat, shape=[feat_n, -1, feat_c])
index = paddle.unsqueeze(index, 2)
batch_inds = list()
for i in range(feat_n):
batch_ind = paddle.full(
shape=[1, index.shape[1], 1], fill_value=i, dtype='int64')
batch_inds.append(batch_ind)
batch_inds = paddle.concat(batch_inds, axis=0)
index = paddle.concat(x=[batch_inds, index], axis=2)
feat = paddle.gather_nd(feat, index=index)
mask = paddle.unsqueeze(mask, axis=2)
mask = paddle.expand_as(mask, feat)
mask.stop_gradient = True
feat = paddle.masked_select(feat, mask > 0)
feat = paddle.reshape(feat, shape=[-1, feat_c])
feat = F.normalize(feat)
feat = self.emb_scale * feat
logit = self.classifier(feat)
target.stop_gradient = True
loss = self.reid_loss(logit, target)
valid = (target != self.reid_loss.ignore_index)
valid.stop_gradient = True
count = paddle.sum((paddle.cast(valid, dtype=np.int32)))
count.stop_gradient = True
if count > 0:
loss = loss / count
return loss
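A quick check of the emb_scale above: with the default 14455 identities, the L2-normalized embedding is stretched by sqrt(2) * ln(14454) before the classifier, the same scaling trick used by JDE.

```python
import math

num_identifiers = 14455
emb_scale = math.sqrt(2) * math.log(num_identifiers - 1)
print(emb_scale)  # ~13.55: normalized features are scaled before the linear classifier
```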