Unverified commit 340da51b. Author: Feng Ni. Committer: GitHub

[MOT] add FairMOT-HarDNet85 (#4176)

* add hardnet85 fairmot

* fix hardnet85

* add centernet hardnet fpn, fix config

* update modelzoo

* update modelzoo readme

* remove comments

* add hardnet num_layers assert

* fix hardnet85 config
Parent 13c1fef1
......@@ -136,6 +136,21 @@ If you use a stronger detection model, you can get better results. Each txt is t
FairMOT DLA-34 was trained on 2 GPUs with a mini-batch size of 6 per GPU for 30 epochs.
### FairMOT enhance model
### Results on MOT-16 Test Set
| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
| :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: |
| HarDNet-85 | 1088x608 | 75.0 | 70.0 | 1050 | 11837 | 32774 | - |[model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml) |
### Results on MOT-17 Test Set
| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
| :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: |
| HarDNet-85 | 1088x608 | 74.7 | 70.7 | 3210 | 29790 | 109914 | - |[model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml) |
**Notes:**
FairMOT enhance HarDNet-85 was trained on 8 GPUs with a mini-batch size of 10 per GPU for 30 epochs. The CrowdHuman dataset was added to the training set.
### FairMOT light model
### Results on MOT-16 Test Set
| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
......
......@@ -136,6 +136,21 @@ wget https://dataset.bj.bcebos.com/mot/det_results_dir.zip
FairMOT DLA-34 was trained on 2 GPUs with a batch size of 6 per GPU for 30 epochs.
### FairMOT enhance model
### Results on MOT-16 Test Set
| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
| :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: |
| HarDNet-85 | 1088x608 | 75.0 | 70.0 | 1050 | 11837 | 32774 | - |[model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml) |
### Results on MOT-17 Test Set
| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
| :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: |
| HarDNet-85 | 1088x608 | 74.7 | 70.7 | 3210 | 29790 | 109914 | - |[model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml) |
**Notes:**
FairMOT enhance HarDNet-85 was trained on 8 GPUs with a batch size of 10 per GPU for 30 epochs, with the CrowdHuman dataset added to the training set.
### FairMOT light model
### Results on MOT-16 Test Set
| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
......
......@@ -36,6 +36,22 @@ English | [简体中文](README_cn.md)
**Notes:**
FairMOT DLA-34 was trained on 2 GPUs with a mini-batch size of 6 per GPU for 30 epochs.
### FairMOT enhance model
### Results on MOT-16 Test Set
| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
| :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: |
| HarDNet-85 | 1088x608 | 75.0 | 70.0 | 1050 | 11837 | 32774 | - |[model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot_enhance_hardnet85_30e_1088x608.yml) |
### Results on MOT-17 Test Set
| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
| :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: |
| HarDNet-85 | 1088x608 | 74.7 | 70.7 | 3210 | 29790 | 109914 | - |[model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot_enhance_hardnet85_30e_1088x608.yml) |
**Notes:**
FairMOT enhance HarDNet-85 was trained on 8 GPUs with a mini-batch size of 10 per GPU for 30 epochs. The CrowdHuman dataset was added to the training set.
### FairMOT light model
### Results on MOT-16 Test Set
| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
......
......@@ -36,6 +36,21 @@
FairMOT DLA-34 was trained on 2 GPUs with a batch size of 6 per GPU for 30 epochs.
### FairMOT enhance model
### Results on MOT-16 Test Set
| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
| :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: |
| HarDNet-85 | 1088x608 | 75.0 | 70.0 | 1050 | 11837 | 32774 | - |[model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot_enhance_hardnet85_30e_1088x608.yml) |
### Results on MOT-17 Test Set
| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
| :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: |
| HarDNet-85 | 1088x608 | 74.7 | 70.7 | 3210 | 29790 | 109914 | - |[model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot_enhance_hardnet85_30e_1088x608.yml) |
**Notes:**
FairMOT enhance HarDNet-85 was trained on 8 GPUs with a batch size of 10 per GPU for 30 epochs, with the CrowdHuman dataset added to the training set.
### FairMOT light model
### Results on MOT-16 Test Set
| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
......
architecture: FairMOT
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/centernet_hardnet85_coco.pdparams
FairMOT:
  detector: CenterNet
  reid: FairMOTEmbeddingHead
  loss: FairMOTLoss
  tracker: JDETracker
CenterNet:
  backbone: HarDNet
  neck: CenterNetHarDNetFPN
  head: CenterNetHead
  post_process: CenterNetPostProcess
  for_mot: True
HarDNet:
  depth_wise: False
  return_idx: [1,3,8,13]
  arch: 85
CenterNetHarDNetFPN:
  num_layers: 85
  down_ratio: 4
  last_level: 4
  out_channel: 0
CenterNetHead:
  head_planes: 128
FairMOTEmbeddingHead:
  ch_head: 512
CenterNetPostProcess:
  for_mot: True
JDETracker:
  conf_thres: 0.4
  tracked_thresh: 0.4
  metric_type: cosine
_BASE_: [
  '../../datasets/mot.yml',
  '../../runtime.yml',
  '_base_/optimizer_30e.yml',
  '_base_/fairmot_enhance_hardnet85.yml',
  '_base_/fairmot_reader_1088x608.yml',
]
norm_type: sync_bn
use_ema: true
ema_decay: 0.9998
worker_num: 4
TrainReader:
  inputs_def:
    image_shape: [3, 608, 1088]
  sample_transforms:
    - Decode: {}
    - RGBReverse: {}
    - AugmentHSV: {}
    - LetterBoxResize: {target_size: [608, 1088]}
    - MOTRandomAffine: {reject_outside: False}
    - RandomFlip: {}
    - BboxXYXY2XYWH: {}
    - NormalizeBox: {}
    - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1]}
    - RGBReverse: {}
    - Permute: {}
  batch_transforms:
    - Gt2FairMOTTarget: {}
  batch_size: 10
  shuffle: True
  drop_last: True
  use_shared_memory: True
epoch: 30
LearningRate:
  base_lr: 0.0001
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [20,]
    use_warmup: False
OptimizerBuilder:
  optimizer:
    type: Adam
  regularizer: NULL
weights: output/fairmot_enhance_hardnet85_30e_1088x608/model_final
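To make the schedule in this config concrete, the following is a minimal sketch of the per-epoch learning rate it implies: `base_lr` of 0.0001, one decay by `gamma` 0.1 at epoch 20, no warmup, 30 epochs in total. The helper `piecewise_lr` is hypothetical and works at epoch granularity, which is an assumption; it is not PaddleDetection's `PiecewiseDecay` implementation.

```python
def piecewise_lr(epoch, base_lr=0.0001, gamma=0.1, milestones=(20,)):
    """Learning rate at a 0-based epoch index (hypothetical helper).

    The rate is multiplied by `gamma` once for every milestone
    that the epoch index has reached.
    """
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed


# 30 epochs, as set by `epoch: 30` in the config.
schedule = [piecewise_lr(e) for e in range(30)]

# Epochs 0..19 keep the base rate; epochs 20..29 use a rate ten times smaller.
print(schedule[0], schedule[19], schedule[20], schedule[29])
```

Under these assumptions, the last 10 of the 30 epochs run at one tenth of the base rate.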
......@@ -29,8 +29,8 @@ class CenterNet(BaseArch):
     Args:
         backbone (object): backbone instance
-        neck (object): FPN instance, default None, use 'CenterDLAFPN' in FairMOT
-        head (object): 'CenterHead' instance
+        neck (object): FPN instance, default use 'CenterNetDLAFPN'
+        head (object): 'CenterNetHead' instance
         post_process (object): 'CenterNetPostProcess' instance
         for_mot (bool): whether return other features used in tracking model
......@@ -40,8 +40,8 @@ class CenterNet(BaseArch):
     def __init__(self,
                  backbone,
-                 neck='CenterDLAFPN',
-                 head='CenterHead',
+                 neck='CenterNetDLAFPN',
+                 head='CenterNetHead',
                  post_process='CenterNetPostProcess',
                  for_mot=False):
         super(CenterNet, self).__init__()
......
......@@ -27,6 +27,7 @@ from . import dla
from . import shufflenet_v2
from . import swin_transformer
from . import lcnet
from . import hardnet
from .vgg import *
from .resnet import *
......@@ -43,3 +44,4 @@ from .dla import *
from .shufflenet_v2 import *
from .swin_transformer import *
from .lcnet import *
from .hardnet import *
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from ppdet.core.workspace import register
from ..shape_spec import ShapeSpec
__all__ = ['HarDNet']
def ConvLayer(in_channels,
              out_channels,
              kernel_size=3,
              stride=1,
              bias_attr=False):
    """Conv2D + BatchNorm2D + ReLU6 with 'same'-style padding."""
    layer = nn.Sequential(
        ('conv', nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=1,
            bias_attr=bias_attr)), ('norm', nn.BatchNorm2D(out_channels)),
        ('relu', nn.ReLU6()))
    return layer
def DWConvLayer(in_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                bias_attr=False):
    """Depthwise Conv2D (groups == out_channels) + BatchNorm2D, no activation."""
    layer = nn.Sequential(
        ('dwconv', nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=1,
            groups=out_channels,
            bias_attr=bias_attr)), ('norm', nn.BatchNorm2D(out_channels)))
    return layer
def CombConvLayer(in_channels, out_channels, kernel_size=1, stride=1):
    """ConvLayer followed by a depthwise DWConvLayer."""
    layer = nn.Sequential(
        ('layer1', ConvLayer(
            in_channels, out_channels, kernel_size=kernel_size)),
        ('layer2', DWConvLayer(
            out_channels, out_channels, stride=stride)))
    return layer
class HarDBlock(nn.Layer):
    def __init__(self,
                 in_channels,
                 growth_rate,
                 grmul,
                 n_layers,
                 keepBase=False,
                 residual_out=False,
                 dwconv=False):
        super().__init__()
        self.keepBase = keepBase
        self.links = []
        layers_ = []
        self.out_channels = 0
        for i in range(n_layers):
            outch, inch, link = self.get_link(i + 1, in_channels, growth_rate,
                                              grmul)
            self.links.append(link)
            if dwconv:
                layers_.append(CombConvLayer(inch, outch))
            else:
                layers_.append(ConvLayer(inch, outch))
            # Only odd-numbered layers and the last layer reach the output.
            if (i % 2 == 0) or (i == n_layers - 1):
                self.out_channels += outch
        self.layers = nn.LayerList(layers_)

    def get_out_ch(self):
        return self.out_channels

    def get_link(self, layer, base_ch, growth_rate, grmul):
        # Layer 0 is the block input itself.
        if layer == 0:
            return base_ch, 0, []
        out_channels = growth_rate
        link = []
        # Layer k links to layer k - 2**i whenever 2**i divides k.
        for i in range(10):
            dv = 2**i
            if layer % dv == 0:
                k = layer - dv
                link.append(k)
                if i > 0:
                    out_channels *= grmul
        # Round the channel count to an even number.
        out_channels = int(int(out_channels + 1) / 2) * 2
        in_channels = 0
        for i in link:
            ch, _, _ = self.get_link(i, base_ch, growth_rate, grmul)
            in_channels += ch
        return out_channels, in_channels, link

    def forward(self, x):
        layers_ = [x]
        for layer in range(len(self.layers)):
            link = self.links[layer]
            tin = []
            for i in link:
                tin.append(layers_[i])
            if len(tin) > 1:
                x = paddle.concat(tin, 1)
            else:
                x = tin[0]
            out = self.layers[layer](x)
            layers_.append(out)

        t = len(layers_)
        out_ = []
        for i in range(t):
            if (i == 0 and self.keepBase) or (i == t - 1) or (i % 2 == 1):
                out_.append(layers_[i])
        out = paddle.concat(out_, 1)
        return out
@register
class HarDNet(nn.Layer):
    def __init__(self, depth_wise=False, return_idx=[1, 3, 8, 13], arch=85):
        super(HarDNet, self).__init__()
        # Only the 68 and 85 variants are configured below.
        assert arch in [68, 85], "HarDNet-{} is not supported.".format(arch)
        if arch == 85:
            first_ch = [48, 96]
            second_kernel = 3
            ch_list = [192, 256, 320, 480, 720]
            grmul = 1.7
            gr = [24, 24, 28, 36, 48]
            n_layers = [8, 16, 16, 16, 16]
        elif arch == 68:
            first_ch = [32, 64]
            second_kernel = 3
            ch_list = [128, 256, 320, 640]
            grmul = 1.7
            gr = [14, 16, 20, 40]
            n_layers = [8, 16, 16, 16]

        self.return_idx = return_idx
        # Channels of the HarDNet-85 outputs at the default return_idx.
        self._out_channels = [96, 214, 458, 784]

        avg_pool = True
        if depth_wise:
            second_kernel = 1
            avg_pool = False

        blks = len(n_layers)
        self.base = nn.LayerList([])

        # First Layer: Standard Conv3x3, Stride=2
        self.base.append(
            ConvLayer(
                in_channels=3,
                out_channels=first_ch[0],
                kernel_size=3,
                stride=2,
                bias_attr=False))

        # Second Layer
        self.base.append(
            ConvLayer(
                first_ch[0], first_ch[1], kernel_size=second_kernel))

        # Avgpooling or DWConv3x3 downsampling
        if avg_pool:
            self.base.append(nn.AvgPool2D(kernel_size=3, stride=2, padding=1))
        else:
            self.base.append(DWConvLayer(first_ch[1], first_ch[1], stride=2))

        # Build all HarDNet blocks
        ch = first_ch[1]
        for i in range(blks):
            blk = HarDBlock(ch, gr[i], grmul, n_layers[i], dwconv=depth_wise)
            ch = blk.out_channels
            self.base.append(blk)

            if i != blks - 1:
                self.base.append(ConvLayer(ch, ch_list[i], kernel_size=1))
            ch = ch_list[i]
            if i == 0:
                self.base.append(
                    nn.AvgPool2D(
                        kernel_size=2, stride=2, ceil_mode=True))
            elif i != blks - 1 and i != 1 and i != 3:
                self.base.append(nn.AvgPool2D(kernel_size=2, stride=2))

    def forward(self, inputs):
        x = inputs['image']
        outs = []
        for i, layer in enumerate(self.base):
            x = layer(x)
            if i in self.return_idx:
                outs.append(x)
        return outs

    @property
    def out_shape(self):
        return [ShapeSpec(channels=self._out_channels[i]) for i in range(4)]
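The link pattern in `HarDBlock.get_link` is easier to follow with numbers. The sketch below re-implements that method as a standalone, framework-free function (the names `get_link` and `block_out_channels` mirror the class but are written here for illustration only) and reproduces the backbone channel counts 214, 458 and 784 from the HarDNet-85 growth rates listed above.

```python
def get_link(layer, base_ch, growth_rate, grmul):
    """Return (out_channels, in_channels, link) for a 1-based layer index."""
    if layer == 0:
        # Layer 0 stands for the block input.
        return base_ch, 0, []
    out_channels = growth_rate
    link = []
    for i in range(10):
        dv = 2 ** i
        if layer % dv == 0:
            link.append(layer - dv)
            if i > 0:
                out_channels *= grmul
    # Round to an even channel count, as in the class method.
    out_channels = int(int(out_channels + 1) / 2) * 2
    in_channels = sum(
        get_link(k, base_ch, growth_rate, grmul)[0] for k in link)
    return out_channels, in_channels, link


def block_out_channels(base_ch, growth_rate, grmul, n_layers):
    """Sum the channels of the layers that HarDBlock.forward concatenates."""
    total = 0
    for i in range(n_layers):
        outch, _, _ = get_link(i + 1, base_ch, growth_rate, grmul)
        if i % 2 == 0 or i == n_layers - 1:
            total += outch
    return total


# Links of the first HarDNet-85 block (input 96 channels, growth rate 24).
links = {layer: get_link(layer, 96, 24, 1.7)[2] for layer in range(1, 9)}
print(links)

# Blocks found at base indices 3, 8 and 13 for arch == 85.
returned = [
    block_out_channels(96, 24, 1.7, 8),
    block_out_channels(256, 28, 1.7, 16),
    block_out_channels(480, 48, 1.7, 16),
]
print(returned)
```

Together with the 96-channel second stem layer at index 1, these values match `_out_channels = [96, 214, 458, 784]` for the default `return_idx`.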
......@@ -16,11 +16,15 @@ import numpy as np
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import KaimingUniform
from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import ConvNormLayer
from ppdet.modeling.backbones.hardnet import ConvLayer, HarDBlock
from ..shape_spec import ShapeSpec
__all__ = ['CenterNetDLAFPN', 'CenterNetHarDNetFPN']
def fill_up_weights(up):
    weight = up.weight
......@@ -171,6 +175,7 @@ class CenterNetDLAFPN(nn.Layer):
        return {'in_channels': [i.channels for i in input_shape]}

    def forward(self, body_feats):
        dla_up_feats = self.dla_up(body_feats)
        ida_up_feats = []
......@@ -184,3 +189,140 @@ class CenterNetDLAFPN(nn.Layer):
    @property
    def out_shape(self):
        return [ShapeSpec(channels=self.out_channel, stride=self.down_ratio)]
class TransitionUp(nn.Layer):
    def __init__(self, in_channels, out_channels):
        super().__init__()

    def forward(self, x, skip, concat=True):
        # NCHW layout: dim 2 is height, dim 3 is width.
        h, w = skip.shape[2], skip.shape[3]
        out = F.interpolate(x, size=(h, w), mode="bilinear", align_corners=True)
        if concat:
            out = paddle.concat([out, skip], 1)
        return out
@register
@serializable
class CenterNetHarDNetFPN(nn.Layer):
    """
    Args:
        in_channels (list): number of input feature channels from backbone.
            [96, 214, 458, 784] by default, means the channels of HarDNet85
        num_layers (int): HarDNet layers, 85 by default
        down_ratio (int): the down ratio from images to heatmap, 4 by default
        first_level (int): the first level of input feature fed into the
            upsampling block
        last_level (int): the last level of input feature fed into the upsampling block
        out_channel (int): the channel of the output feature, 0 by default means
            the channel of the input feature whose down ratio is `down_ratio`
    """

    def __init__(self,
                 in_channels,
                 num_layers=85,
                 down_ratio=4,
                 first_level=-1,
                 last_level=4,
                 out_channel=0):
        super(CenterNetHarDNetFPN, self).__init__()
        self.first_level = int(np.log2(
            down_ratio)) - 1 if first_level == -1 else first_level
        self.down_ratio = down_ratio
        self.last_level = last_level
        self.last_pool = nn.AvgPool2D(kernel_size=2, stride=2)

        assert num_layers in [68, 85], "HarDNet-{} is not supported.".format(num_layers)
        if num_layers == 85:
            self.last_proj = ConvLayer(784, 256, kernel_size=1)
            self.last_blk = HarDBlock(768, 80, 1.7, 8)
            self.skip_nodes = [1, 3, 8, 13]
            self.SC = [32, 32, 0]
            gr = [64, 48, 28]
            layers = [8, 8, 4]
            ch_list2 = [224 + self.SC[0], 160 + self.SC[1], 96 + self.SC[2]]
            channels = [96, 214, 458, 784]
            self.skip_lv = 3

        elif num_layers == 68:
            self.last_proj = ConvLayer(654, 192, kernel_size=1)
            self.last_blk = HarDBlock(576, 72, 1.7, 8)
            self.skip_nodes = [1, 3, 8, 11]
            self.SC = [32, 32, 0]
            gr = [48, 32, 20]
            layers = [8, 8, 4]
            ch_list2 = [224 + self.SC[0], 96 + self.SC[1], 64 + self.SC[2]]
            channels = [64, 124, 328, 654]
            self.skip_lv = 2

        self.transUpBlocks = nn.LayerList([])
        self.denseBlocksUp = nn.LayerList([])
        self.conv1x1_up = nn.LayerList([])
        self.avg9x9 = nn.AvgPool2D(kernel_size=(9, 9), stride=1, padding=(4, 4))
        prev_ch = self.last_blk.get_out_ch()

        for i in range(3):
            skip_ch = channels[3 - i]
            self.transUpBlocks.append(TransitionUp(prev_ch, prev_ch))
            if i < self.skip_lv:
                cur_ch = prev_ch + skip_ch
            else:
                cur_ch = prev_ch
            self.conv1x1_up.append(
                ConvLayer(
                    cur_ch, ch_list2[i], kernel_size=1))
            cur_ch = ch_list2[i]
            # Shortcut channels are split off before the dense block.
            cur_ch -= self.SC[i]
            # forward() concatenates [x, x2, x3], tripling the channels.
            cur_ch *= 3

            blk = HarDBlock(cur_ch, gr[i], 1.7, layers[i])
            self.denseBlocksUp.append(blk)
            prev_ch = blk.get_out_ch()

        prev_ch += self.SC[0] + self.SC[1] + self.SC[2]
        self.out_channel = prev_ch

    @classmethod
    def from_config(cls, cfg, input_shape):
        return {'in_channels': [i.channels for i in input_shape]}

    def forward(self, body_feats):
        x = body_feats[-1]
        x_sc = []
        x = self.last_proj(x)
        x = self.last_pool(x)
        x2 = self.avg9x9(x)
        x3 = x / (x.sum((2, 3), keepdim=True) + 0.1)
        x = paddle.concat([x, x2, x3], 1)
        x = self.last_blk(x)

        for i in range(3):
            skip_x = body_feats[3 - i]
            x = self.transUpBlocks[i](x, skip_x, (i < self.skip_lv))
            x = self.conv1x1_up[i](x)
            if self.SC[i] > 0:
                end = x.shape[1]
                x_sc.append(x[:, end - self.SC[i]:, :, :])
                x = x[:, :end - self.SC[i], :, :]
            x2 = self.avg9x9(x)
            x3 = x / (x.sum((2, 3), keepdim=True) + 0.1)
            x = paddle.concat([x, x2, x3], 1)
            x = self.denseBlocksUp[i](x)

        scs = [x]
        for i in range(3):
            if self.SC[i] > 0:
                scs.insert(
                    0,
                    F.interpolate(
                        x_sc[i],
                        size=(x.shape[2], x.shape[3]),
                        mode="bilinear",
                        align_corners=True))
        neck_feat = paddle.concat(scs, 1)
        return neck_feat

    @property
    def out_shape(self):
        return [ShapeSpec(channels=self.out_channel, stride=self.down_ratio)]
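The channel bookkeeping in `CenterNetHarDNetFPN.__init__` can be traced without building any layers. The following is a rough sketch for the `num_layers == 85` branch only; `layer_out_ch` and `block_out_ch` are hypothetical helpers that condense the `HarDBlock` channel rule, and the constants are copied from the constructor above.

```python
def layer_out_ch(layer, growth_rate, grmul=1.7):
    """Output channels of one HarDBlock layer (1-based index)."""
    ch = growth_rate
    for i in range(1, 10):
        if layer % (2 ** i) == 0:
            ch *= grmul
    return int(int(ch + 1) / 2) * 2


def block_out_ch(growth_rate, n_layers, grmul=1.7):
    """Channels a HarDBlock emits: odd-numbered layers plus the last one."""
    return sum(
        layer_out_ch(i + 1, growth_rate, grmul)
        for i in range(n_layers)
        if i % 2 == 0 or i == n_layers - 1)


# Constants copied from the num_layers == 85 branch.
SC = [32, 32, 0]
gr = [64, 48, 28]
layers = [8, 8, 4]
ch_list2 = [224 + SC[0], 160 + SC[1], 96 + SC[2]]
channels = [96, 214, 458, 784]
skip_lv = 3

prev_ch = block_out_ch(80, 8)  # last_blk = HarDBlock(768, 80, 1.7, 8)
conv1x1_in = []   # input channels of each 1x1 conv in conv1x1_up
block_in = []     # input channels of each block in denseBlocksUp
for i in range(3):
    # Upsampled features are concatenated with the skip feature.
    cur_ch = prev_ch + channels[3 - i] if i < skip_lv else prev_ch
    conv1x1_in.append(cur_ch)
    # Shortcut channels are split off, then [x, x2, x3] triples the rest.
    cur_ch = (ch_list2[i] - SC[i]) * 3
    block_in.append(cur_ch)
    prev_ch = block_out_ch(gr[i], layers[i])

out_channel = prev_ch + sum(SC)
print(conv1x1_in, block_in, out_channel)
```

Under these assumptions the neck ends with 136 channels from the last dense block plus 64 shortcut channels, which is the value stored in `self.out_channel` regardless of the `out_channel` argument.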