From 340da51be9227167f3673902533417bde19aae96 Mon Sep 17 00:00:00 2001 From: Feng Ni Date: Mon, 4 Oct 2021 23:21:51 +0800 Subject: [PATCH] [MOT] add FairMOT-HarDNet85 (#4176) * add hardnet85 fairmot * fix hardnet85 * add centernet hardnet fpn, fix config * update modelzoo * update modelzoo readme * remove comments * add hardnet num_layers assert * fix hardnet85 config --- configs/mot/README.md | 15 ++ configs/mot/README_cn.md | 15 ++ configs/mot/fairmot/README.md | 16 ++ configs/mot/fairmot/README_cn.md | 15 ++ .../_base_/fairmot_enhance_hardnet85.yml | 40 ++++ ...fairmot_enhance_hardnet85_30e_1088x608.yml | 49 ++++ ppdet/modeling/architectures/centernet.py | 8 +- ppdet/modeling/backbones/__init__.py | 2 + ppdet/modeling/backbones/hardnet.py | 224 ++++++++++++++++++ ppdet/modeling/necks/centernet_fpn.py | 142 +++++++++++ 10 files changed, 522 insertions(+), 4 deletions(-) create mode 100644 configs/mot/fairmot/_base_/fairmot_enhance_hardnet85.yml create mode 100644 configs/mot/fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml create mode 100644 ppdet/modeling/backbones/hardnet.py diff --git a/configs/mot/README.md b/configs/mot/README.md index 709f864e5..459a1ac89 100644 --- a/configs/mot/README.md +++ b/configs/mot/README.md @@ -136,6 +136,21 @@ If you use a stronger detection model, you can get better results. Each txt is t FairMOT DLA-34 used 2 GPUs for training and mini-batch size as 6 on each GPU, and trained for 30 epoches. +### FairMOT enhance model +### Results on MOT-16 Test Set +| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config | +| :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: | +| HarDNet-85 | 1088x608 | 75.0 | 70.0 | 1050 | 11837 | 32774 | - |[model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml) | + +### Results on MOT-17 Test Set +| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config | +| :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: | +| HarDNet-85 | 1088x608 | 74.7 | 70.7 | 3210 | 29790 | 109914 | - |[model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml) | + +**注意:** + FairMOT enhance HarDNet-85 used 8 GPUs for training and mini-batch size as 10 on each GPU,and trained for 30 epoches. The crowdhuman dataset is added to the train-set during training. + + ### FairMOT light model ### Results on MOT-16 Test Set | backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config | diff --git a/configs/mot/README_cn.md b/configs/mot/README_cn.md index f58c2a77f..05edf549d 100644 --- a/configs/mot/README_cn.md +++ b/configs/mot/README_cn.md @@ -136,6 +136,21 @@ wget https://dataset.bj.bcebos.com/mot/det_results_dir.zip FairMOT DLA-34均使用2个GPU进行训练,每个GPU上batch size为6,训练30个epoch。 +### FairMOT enhance模型 +### 在MOT-16 Test Set上结果 +| 骨干网络 | 输入尺寸 | MOTA | IDF1 | IDS | FP | FN | FPS | 下载链接 | 配置文件 | +| :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: | +| HarDNet-85 | 1088x608 | 75.0 | 70.0 | 1050 | 11837 | 32774 | - |[下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [配置文件](./fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml) | + +### 在MOT-17 Test Set上结果 +| 骨干网络 | 输入尺寸 | MOTA | IDF1 | IDS | FP | FN | FPS | 下载链接 | 配置文件 | +| :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: | +| HarDNet-85 | 1088x608 | 74.7 | 70.7 | 3210 | 29790 | 109914 | - |[下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [配置文件](./fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml) | + +**注意:** + FairMOT enhance HarDNet-85 使用8个GPU进行训练,每个GPU上batch size为10,训练30个epoch,并且训练集中加入了crowdhuman数据集一起参与训练。 + + ### FairMOT轻量级模型 ### 在MOT-16 Test Set上结果 | 骨干网络 | 输入尺寸 | MOTA | IDF1 | IDS | FP | FN | FPS | 下载链接 | 配置文件 | diff --git a/configs/mot/fairmot/README.md b/configs/mot/fairmot/README.md index d542c87ac..cbef7a441 100644 --- a/configs/mot/fairmot/README.md +++ b/configs/mot/fairmot/README.md @@ -36,6 +36,22 @@ English | [简体中文](README_cn.md) **Notes:** FairMOT DLA-34 used 2 GPUs for training and mini-batch size as 6 on each GPU, and trained for 30 epoches. + +### FairMOT enhance model +### Results on MOT-16 Test Set +| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config | +| :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: | +| HarDNet-85 | 1088x608 | 75.0 | 70.0 | 1050 | 11837 | 32774 | - |[model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot_enhance_hardnet85_30e_1088x608.yml) | + +### Results on MOT-17 Test Set +| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config | +| :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: | +| HarDNet-85 | 1088x608 | 74.7 | 70.7 | 3210 | 29790 | 109914 | - |[model](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [config](./fairmot_enhance_hardnet85_30e_1088x608.yml) | + +**注意:** + FairMOT enhance HarDNet-85 used 8 GPUs for training and mini-batch size as 10 on each GPU,and trained for 30 epoches. The crowdhuman dataset is added to the train-set during training. + + ### FairMOT light model ### Results on MOT-16 Test Set | backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config | diff --git a/configs/mot/fairmot/README_cn.md b/configs/mot/fairmot/README_cn.md index 8c9444c38..c7a5a856a 100644 --- a/configs/mot/fairmot/README_cn.md +++ b/configs/mot/fairmot/README_cn.md @@ -36,6 +36,21 @@ FairMOT DLA-34均使用2个GPU进行训练,每个GPU上batch size为6,训练30个epoch。 +### FairMOT enhance模型 +### 在MOT-16 Test Set上结果 +| 骨干网络 | 输入尺寸 | MOTA | IDF1 | IDS | FP | FN | FPS | 下载链接 | 配置文件 | +| :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: | +| HarDNet-85 | 1088x608 | 75.0 | 70.0 | 1050 | 11837 | 32774 | - |[下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [配置文件](./fairmot_enhance_hardnet85_30e_1088x608.yml) | + +### 在MOT-17 Test Set上结果 +| 骨干网络 | 输入尺寸 | MOTA | IDF1 | IDS | FP | FN | FPS | 下载链接 | 配置文件 | +| :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: | +| HarDNet-85 | 1088x608 | 74.7 | 70.7 | 3210 | 29790 | 109914 | - |[下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_enhance_hardnet85_30e_1088x608.pdparams) | [配置文件](./fairmot_enhance_hardnet85_30e_1088x608.yml) | + +**注意:** + FairMOT enhance HarDNet-85 使用8个GPU进行训练,每个GPU上batch size为10,训练30个epoch,并且训练集中加入了crowdhuman数据集一起参与训练。 + + ### FairMOT轻量级模型 ### 在MOT-16 Test Set上结果 | 骨干网络 | 输入尺寸 | MOTA | IDF1 | IDS | FP | FN | FPS | 下载链接 | 配置文件 | diff --git a/configs/mot/fairmot/_base_/fairmot_enhance_hardnet85.yml b/configs/mot/fairmot/_base_/fairmot_enhance_hardnet85.yml new file mode 100644 index 000000000..58d88f3ac --- /dev/null +++ b/configs/mot/fairmot/_base_/fairmot_enhance_hardnet85.yml @@ -0,0 +1,40 @@ +architecture: FairMOT +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/centernet_hardnet85_coco.pdparams + +FairMOT: + detector: CenterNet + reid: FairMOTEmbeddingHead + loss: FairMOTLoss + tracker: JDETracker + +CenterNet: + backbone: HarDNet + neck: CenterNetHarDNetFPN + head: CenterNetHead + post_process: CenterNetPostProcess + for_mot: True + +HarDNet: + depth_wise: False + return_idx: [1,3,8,13] + arch: 85 + +CenterNetHarDNetFPN: + num_layers: 85 + down_ratio: 4 + last_level: 4 + out_channel: 0 + +CenterNetHead: + head_planes: 128 + +FairMOTEmbeddingHead: + ch_head: 512 + +CenterNetPostProcess: + for_mot: True + +JDETracker: + conf_thres: 0.4 + tracked_thresh: 0.4 + metric_type: cosine diff --git a/configs/mot/fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml b/configs/mot/fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml new file mode 100644 index 000000000..77936b3c3 --- /dev/null +++ b/configs/mot/fairmot/fairmot_enhance_hardnet85_30e_1088x608.yml @@ -0,0 +1,49 @@ +_BASE_: [ + '../../datasets/mot.yml', + '../../runtime.yml', + '_base_/optimizer_30e.yml', + '_base_/fairmot_enhance_hardnet85.yml', + '_base_/fairmot_reader_1088x608.yml', +] +norm_type: sync_bn +use_ema: true +ema_decay: 0.9998 + +worker_num: 4 +TrainReader: + inputs_def: + image_shape: [3, 608, 1088] + sample_transforms: + - Decode: {} + - RGBReverse: {} + - AugmentHSV: {} + - LetterBoxResize: {target_size: [608, 1088]} + - MOTRandomAffine: {reject_outside: False} + - RandomFlip: {} + - BboxXYXY2XYWH: {} + - NormalizeBox: {} + - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1]} + - RGBReverse: {} + - Permute: {} + batch_transforms: + - Gt2FairMOTTarget: {} + batch_size: 10 + shuffle: True + drop_last: True + use_shared_memory: True + +epoch: 30 +LearningRate: + base_lr: 0.0001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [20,] + use_warmup: False + +OptimizerBuilder: + optimizer: + type: Adam + regularizer: NULL + +weights: output/fairmot_enhance_hardnet85_30e_1088x608/model_final diff --git a/ppdet/modeling/architectures/centernet.py b/ppdet/modeling/architectures/centernet.py index 080d5e886..434c01834 100755 --- a/ppdet/modeling/architectures/centernet.py +++ b/ppdet/modeling/architectures/centernet.py @@ -29,8 +29,8 @@ class CenterNet(BaseArch): Args: backbone (object): backbone instance - neck (object): FPN instance, default None, use 'CenterDLAFPN' in FairMOT - head (object): 'CenterHead' instance + neck (object): FPN instance, default use 'CenterNetDLAFPN' + head (object): 'CenterNetHead' instance post_process (object): 'CenterNetPostProcess' instance for_mot (bool): whether return other features used in tracking model @@ -40,8 +40,8 @@ class CenterNet(BaseArch): def __init__(self, backbone, - neck='CenterDLAFPN', - head='CenterHead', + neck='CenterNetDLAFPN', + head='CenterNetHead', post_process='CenterNetPostProcess', for_mot=False): super(CenterNet, self).__init__() diff --git a/ppdet/modeling/backbones/__init__.py b/ppdet/modeling/backbones/__init__.py index f57fd9025..138b64935 100644 --- a/ppdet/modeling/backbones/__init__.py +++ b/ppdet/modeling/backbones/__init__.py @@ -27,6 +27,7 @@ from . import dla from . import shufflenet_v2 from . import swin_transformer from . import lcnet +from . import hardnet from .vgg import * from .resnet import * @@ -43,3 +44,4 @@ from .dla import * from .shufflenet_v2 import * from .swin_transformer import * from .lcnet import * +from .hardnet import * diff --git a/ppdet/modeling/backbones/hardnet.py b/ppdet/modeling/backbones/hardnet.py new file mode 100644 index 000000000..14a1599df --- /dev/null +++ b/ppdet/modeling/backbones/hardnet.py @@ -0,0 +1,224 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +from ppdet.core.workspace import register +from ..shape_spec import ShapeSpec + +__all__ = ['HarDNet'] + + +def ConvLayer(in_channels, + out_channels, + kernel_size=3, + stride=1, + bias_attr=False): + layer = nn.Sequential( + ('conv', nn.Conv2D( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=1, + bias_attr=bias_attr)), ('norm', nn.BatchNorm2D(out_channels)), + ('relu', nn.ReLU6())) + return layer + + +def DWConvLayer(in_channels, + out_channels, + kernel_size=3, + stride=1, + bias_attr=False): + layer = nn.Sequential( + ('dwconv', nn.Conv2D( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=1, + groups=out_channels, + bias_attr=bias_attr)), ('norm', nn.BatchNorm2D(out_channels))) + return layer + + +def CombConvLayer(in_channels, out_channels, kernel_size=1, stride=1): + layer = nn.Sequential( + ('layer1', ConvLayer( + in_channels, out_channels, kernel_size=kernel_size)), + ('layer2', DWConvLayer( + out_channels, out_channels, stride=stride))) + return layer + + +class HarDBlock(nn.Layer): + def __init__(self, + in_channels, + growth_rate, + grmul, + n_layers, + keepBase=False, + residual_out=False, + dwconv=False): + super().__init__() + self.keepBase = keepBase + self.links = [] + layers_ = [] + self.out_channels = 0 + for i in range(n_layers): + outch, inch, link = self.get_link(i + 1, in_channels, growth_rate, + grmul) + self.links.append(link) + if dwconv: + layers_.append(CombConvLayer(inch, outch)) + else: + layers_.append(ConvLayer(inch, outch)) + + if (i % 2 == 0) or (i == n_layers - 1): + self.out_channels += outch + self.layers = nn.LayerList(layers_) + + def get_out_ch(self): + return self.out_channels + + def get_link(self, layer, base_ch, growth_rate, grmul): + if layer == 0: + return base_ch, 0, [] + out_channels = growth_rate + + link = [] + for i in range(10): + dv = 2**i + if layer % dv == 0: + k = layer - dv + link.append(k) + if i > 0: + out_channels *= grmul + + out_channels = int(int(out_channels + 1) / 2) * 2 + in_channels = 0 + + for i in link: + ch, _, _ = self.get_link(i, base_ch, growth_rate, grmul) + in_channels += ch + + return out_channels, in_channels, link + + def forward(self, x): + layers_ = [x] + + for layer in range(len(self.layers)): + link = self.links[layer] + tin = [] + for i in link: + tin.append(layers_[i]) + if len(tin) > 1: + x = paddle.concat(tin, 1) + else: + x = tin[0] + out = self.layers[layer](x) + layers_.append(out) + + t = len(layers_) + out_ = [] + for i in range(t): + if (i == 0 and self.keepBase) or (i == t - 1) or (i % 2 == 1): + out_.append(layers_[i]) + out = paddle.concat(out_, 1) + + return out + + +@register +class HarDNet(nn.Layer): + def __init__(self, depth_wise=False, return_idx=[1, 3, 8, 13], arch=85): + super(HarDNet, self).__init__() + assert arch in [39, 68, 85], "HarDNet-{} not support.".format(arch) + if arch == 85: + first_ch = [48, 96] + second_kernel = 3 + ch_list = [192, 256, 320, 480, 720] + grmul = 1.7 + gr = [24, 24, 28, 36, 48] + n_layers = [8, 16, 16, 16, 16] + elif arch == 68: + first_ch = [32, 64] + second_kernel = 3 + ch_list = [128, 256, 320, 640] + grmul = 1.7 + gr = [14, 16, 20, 40] + n_layers = [8, 16, 16, 16] + + self.return_idx = return_idx + self._out_channels = [96, 214, 458, 784] + + avg_pool = True + if depth_wise: + second_kernel = 1 + avg_pool = False + + blks = len(n_layers) + self.base = nn.LayerList([]) + + # First Layer: Standard Conv3x3, Stride=2 + self.base.append( + ConvLayer( + in_channels=3, + out_channels=first_ch[0], + kernel_size=3, + stride=2, + bias_attr=False)) + + # Second Layer + self.base.append( + ConvLayer( + first_ch[0], first_ch[1], kernel_size=second_kernel)) + + # Avgpooling or DWConv3x3 downsampling + if avg_pool: + self.base.append(nn.AvgPool2D(kernel_size=3, stride=2, padding=1)) + else: + self.base.append(DWConvLayer(first_ch[1], first_ch[1], stride=2)) + + # Build all HarDNet blocks + ch = first_ch[1] + for i in range(blks): + blk = HarDBlock(ch, gr[i], grmul, n_layers[i], dwconv=depth_wise) + ch = blk.out_channels + self.base.append(blk) + + if i != blks - 1: + self.base.append(ConvLayer(ch, ch_list[i], kernel_size=1)) + ch = ch_list[i] + if i == 0: + self.base.append( + nn.AvgPool2D( + kernel_size=2, stride=2, ceil_mode=True)) + elif i != blks - 1 and i != 1 and i != 3: + self.base.append(nn.AvgPool2D(kernel_size=2, stride=2)) + + def forward(self, inputs): + x = inputs['image'] + outs = [] + for i, layer in enumerate(self.base): + x = layer(x) + if i in self.return_idx: + outs.append(x) + return outs + + @property + def out_shape(self): + return [ShapeSpec(channels=self._out_channels[i]) for i in range(4)] diff --git a/ppdet/modeling/necks/centernet_fpn.py b/ppdet/modeling/necks/centernet_fpn.py index 262959051..a25c270be 100755 --- a/ppdet/modeling/necks/centernet_fpn.py +++ b/ppdet/modeling/necks/centernet_fpn.py @@ -16,11 +16,15 @@ import numpy as np import math import paddle import paddle.nn as nn +import paddle.nn.functional as F from paddle.nn.initializer import KaimingUniform from ppdet.core.workspace import register, serializable from ppdet.modeling.layers import ConvNormLayer +from ppdet.modeling.backbones.hardnet import ConvLayer, HarDBlock from ..shape_spec import ShapeSpec +__all__ = ['CenterNetDLAFPN', 'CenterNetHarDNetFPN'] + def fill_up_weights(up): weight = up.weight @@ -171,6 +175,7 @@ class CenterNetDLAFPN(nn.Layer): return {'in_channels': [i.channels for i in input_shape]} def forward(self, body_feats): + dla_up_feats = self.dla_up(body_feats) ida_up_feats = [] @@ -184,3 +189,140 @@ class CenterNetDLAFPN(nn.Layer): @property def out_shape(self): return [ShapeSpec(channels=self.out_channel, stride=self.down_ratio)] + + +class TransitionUp(nn.Layer): + def __init__(self, in_channels, out_channels): + super().__init__() + + def forward(self, x, skip, concat=True): + w, h = skip.shape[2], skip.shape[3] + out = F.interpolate(x, size=(w, h), mode="bilinear", align_corners=True) + if concat: + out = paddle.concat([out, skip], 1) + return out + + +@register +@serializable +class CenterNetHarDNetFPN(nn.Layer): + """ + Args: + in_channels (list): number of input feature channels from backbone. + [96, 214, 458, 784] by default, means the channels of HarDNet85 + num_layers (int): HarDNet laters, 85 by default + down_ratio (int): the down ratio from images to heatmap, 4 by default + first_level (int): the first level of input feature fed into the + upsamplng block + last_level (int): the last level of input feature fed into the upsamplng block + out_channel (int): the channel of the output feature, 0 by default means + the channel of the input feature whose down ratio is `down_ratio` + """ + + def __init__(self, + in_channels, + num_layers=85, + down_ratio=4, + first_level=-1, + last_level=4, + out_channel=0): + super(CenterNetHarDNetFPN, self).__init__() + self.first_level = int(np.log2( + down_ratio)) - 1 if first_level == -1 else first_level + self.down_ratio = down_ratio + self.last_level = last_level + self.last_pool = nn.AvgPool2D(kernel_size=2, stride=2) + + assert num_layers in [68, 85], "HarDNet-{} not support.".format(num_layers) + if num_layers == 85: + self.last_proj = ConvLayer(784, 256, kernel_size=1) + self.last_blk = HarDBlock(768, 80, 1.7, 8) + self.skip_nodes = [1, 3, 8, 13] + self.SC = [32, 32, 0] + gr = [64, 48, 28] + layers = [8, 8, 4] + ch_list2 = [224 + self.SC[0], 160 + self.SC[1], 96 + self.SC[2]] + channels = [96, 214, 458, 784] + self.skip_lv = 3 + + elif num_layers == 68: + self.last_proj = ConvLayer(654, 192, kernel_size=1) + self.last_blk = HarDBlock(576, 72, 1.7, 8) + self.skip_nodes = [1, 3, 8, 11] + self.SC = [32, 32, 0] + gr = [48, 32, 20] + layers = [8, 8, 4] + ch_list2 = [224 + self.SC[0], 96 + self.SC[1], 64 + self.SC[2]] + channels = [64, 124, 328, 654] + self.skip_lv = 2 + + self.transUpBlocks = nn.LayerList([]) + self.denseBlocksUp = nn.LayerList([]) + self.conv1x1_up = nn.LayerList([]) + self.avg9x9 = nn.AvgPool2D(kernel_size=(9, 9), stride=1, padding=(4, 4)) + prev_ch = self.last_blk.get_out_ch() + + for i in range(3): + skip_ch = channels[3 - i] + self.transUpBlocks.append(TransitionUp(prev_ch, prev_ch)) + if i < self.skip_lv: + cur_ch = prev_ch + skip_ch + else: + cur_ch = prev_ch + self.conv1x1_up.append( + ConvLayer( + cur_ch, ch_list2[i], kernel_size=1)) + cur_ch = ch_list2[i] + cur_ch -= self.SC[i] + cur_ch *= 3 + + blk = HarDBlock(cur_ch, gr[i], 1.7, layers[i]) + self.denseBlocksUp.append(blk) + prev_ch = blk.get_out_ch() + + prev_ch += self.SC[0] + self.SC[1] + self.SC[2] + self.out_channel = prev_ch + + @classmethod + def from_config(cls, cfg, input_shape): + return {'in_channels': [i.channels for i in input_shape]} + + def forward(self, body_feats): + x = body_feats[-1] + x_sc = [] + x = self.last_proj(x) + x = self.last_pool(x) + x2 = self.avg9x9(x) + x3 = x / (x.sum((2, 3), keepdim=True) + 0.1) + x = paddle.concat([x, x2, x3], 1) + x = self.last_blk(x) + + for i in range(3): + skip_x = body_feats[3 - i] + x = self.transUpBlocks[i](x, skip_x, (i < self.skip_lv)) + x = self.conv1x1_up[i](x) + if self.SC[i] > 0: + end = x.shape[1] + x_sc.append(x[:, end - self.SC[i]:, :, :]) + x = x[:, :end - self.SC[i], :, :] + x2 = self.avg9x9(x) + x3 = x / (x.sum((2, 3), keepdim=True) + 0.1) + x = paddle.concat([x, x2, x3], 1) + x = self.denseBlocksUp[i](x) + + scs = [x] + for i in range(3): + if self.SC[i] > 0: + scs.insert( + 0, + F.interpolate( + x_sc[i], + size=(x.shape[2], x.shape[3]), + mode="bilinear", + align_corners=True)) + neck_feat = paddle.concat(scs, 1) + return neck_feat + + @property + def out_shape(self): + return [ShapeSpec(channels=self.out_channel, stride=self.down_ratio)] -- GitLab