From b1186bc8227a94ddc696ad4620ef7c6f37685083 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Mon, 6 Sep 2021 13:45:43 +0800 Subject: [PATCH] update picodet model zoo (#4117) --- configs/picodet/README.md | 4 +- ...{optimizer_280e.yml => optimizer_300e.yml} | 6 +- configs/picodet/_base_/picodet_320_reader.yml | 13 +- configs/picodet/_base_/picodet_416_reader.yml | 13 +- configs/picodet/picodet_l_r18_320_coco.yml | 4 +- configs/picodet/picodet_m_mbv3_320_coco.yml | 2 +- configs/picodet/picodet_m_mbv3_416_coco.yml | 2 +- .../picodet_m_shufflenetv2_320_coco.yml | 2 +- .../picodet_m_shufflenetv2_416_coco.yml | 2 +- configs/picodet/picodet_s_lcnet_320_coco.yml | 23 ++ configs/picodet/picodet_s_lcnet_416_coco.yml | 23 ++ .../picodet_s_shufflenetv2_320_coco.yml | 2 +- .../picodet_s_shufflenetv2_416_coco.yml | 2 +- configs/picodet/picodet_xs_lcnet_320_coco.yml | 23 ++ ppdet/data/transform/atss_assigner.py | 6 +- ppdet/modeling/backbones/__init__.py | 2 + ppdet/modeling/backbones/lcnet.py | 258 ++++++++++++++++++ ppdet/modeling/backbones/shufflenet_v2.py | 30 +- ppdet/utils/checkpoint.py | 7 + 19 files changed, 371 insertions(+), 53 deletions(-) rename configs/picodet/_base_/{optimizer_280e.yml => optimizer_300e.yml} (80%) create mode 100644 configs/picodet/picodet_s_lcnet_320_coco.yml create mode 100644 configs/picodet/picodet_s_lcnet_416_coco.yml create mode 100644 configs/picodet/picodet_xs_lcnet_320_coco.yml create mode 100644 ppdet/modeling/backbones/lcnet.py diff --git a/configs/picodet/README.md b/configs/picodet/README.md index 601eef6e4..85cd197a0 100644 --- a/configs/picodet/README.md +++ b/configs/picodet/README.md @@ -28,8 +28,8 @@ Optimizing method of we use: | Backbone | Input size | lr schedule | Box AP(0.5:0.95) | Box AP(0.5) | FLOPS | Model Size | Inference Time | download | config | | :------------------------ | :-------: | :-------: | :------: | :---: | :---: | :---: | :------------: | 
:-------------------------------------------------: | :-----: | -| ShuffleNetv2-1x | 320*320 | 280e | 22.3 | 36.8 | -- | 3.8M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_shufflenetv2_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_shufflenetv2_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_shufflenetv2_320_coco.yml) | -| ShuffleNetv2-1x | 416*416 | 280e | 24.6 | 44.3 | -- | 3.8M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_shufflenetv2_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_shufflenetv2_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_shufflenetv2_416_coco.yml) | +| ShuffleNetv2-1x | 320*320 | 280e | 22.8 | 37.7 | -- | 3.8M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_shufflenetv2_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_shufflenetv2_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_shufflenetv2_320_coco.yml) | +| ShuffleNetv2-1x | 416*416 | 280e | 25.3 | 41.1 | -- | 3.8M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_shufflenetv2_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_shufflenetv2_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_shufflenetv2_416_coco.yml) | ### PicoDet-M diff --git a/configs/picodet/_base_/optimizer_280e.yml b/configs/picodet/_base_/optimizer_300e.yml similarity index 80% rename from configs/picodet/_base_/optimizer_280e.yml rename to configs/picodet/_base_/optimizer_300e.yml index a72bc7d24..5a89bbbce 100644 --- a/configs/picodet/_base_/optimizer_280e.yml +++ b/configs/picodet/_base_/optimizer_300e.yml @@ -1,10 +1,10 @@ -epoch: 280 +epoch: 300 LearningRate: base_lr: 0.4 
schedulers: - !CosineDecay - max_epochs: 280 + max_epochs: 300 - !LinearWarmup start_factor: 0.1 steps: 300 @@ -14,5 +14,5 @@ OptimizerBuilder: momentum: 0.9 type: Momentum regularizer: - factor: 0.0001 + factor: 0.00004 type: L2 diff --git a/configs/picodet/_base_/picodet_320_reader.yml b/configs/picodet/_base_/picodet_320_reader.yml index 99b504ece..469184529 100644 --- a/configs/picodet/_base_/picodet_320_reader.yml +++ b/configs/picodet/_base_/picodet_320_reader.yml @@ -1,15 +1,14 @@ -worker_num: 6 +worker_num: 8 TrainReader: sample_transforms: - Decode: {} - RandomCrop: {} - RandomFlip: {prob: 0.5} - - Resize: {target_size: [320, 320], keep_ratio: False, interp: 1} - - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - RandomDistort: {} - - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: 32} + - BatchRandomResize: {target_size: [256, 288, 320, 352, 384], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} - Gt2GFLTarget: downsample_ratios: [8, 16, 32] grid_cell_scale: 5 @@ -22,7 +21,7 @@ TrainReader: EvalReader: sample_transforms: - Decode: {} - - Resize: {interp: 1, target_size: [320, 320], keep_ratio: False} + - Resize: {interp: 2, target_size: [320, 320], keep_ratio: False} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: @@ -36,7 +35,7 @@ TestReader: image_shape: [3, 320, 320] sample_transforms: - Decode: {} - - Resize: {interp: 1, target_size: [320, 320], keep_ratio: False} + - Resize: {interp: 2, target_size: [320, 320], keep_ratio: False} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: diff --git a/configs/picodet/_base_/picodet_416_reader.yml b/configs/picodet/_base_/picodet_416_reader.yml index b5e950964..58b6607dc 100644 --- 
a/configs/picodet/_base_/picodet_416_reader.yml +++ b/configs/picodet/_base_/picodet_416_reader.yml @@ -4,17 +4,16 @@ TrainReader: - Decode: {} - RandomCrop: {} - RandomFlip: {prob: 0.5} - - Resize: {target_size: [416, 416], keep_ratio: False, interp: 1} - - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - RandomDistort: {} - - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: 32} + - BatchRandomResize: {target_size: [352, 384, 416, 448, 480], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} - Gt2GFLTarget: downsample_ratios: [8, 16, 32] grid_cell_scale: 5 cell_offset: 0.5 - batch_size: 96 + batch_size: 80 shuffle: true drop_last: true @@ -22,7 +21,7 @@ TrainReader: EvalReader: sample_transforms: - Decode: {} - - Resize: {interp: 1, target_size: [416, 416], keep_ratio: False} + - Resize: {interp: 2, target_size: [416, 416], keep_ratio: False} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: @@ -36,7 +35,7 @@ TestReader: image_shape: [3, 416, 416] sample_transforms: - Decode: {} - - Resize: {interp: 1, target_size: [416, 416], keep_ratio: False} + - Resize: {interp: 2, target_size: [416, 416], keep_ratio: False} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: diff --git a/configs/picodet/picodet_l_r18_320_coco.yml b/configs/picodet/picodet_l_r18_320_coco.yml index 9499e2812..d627a975b 100644 --- a/configs/picodet/picodet_l_r18_320_coco.yml +++ b/configs/picodet/picodet_l_r18_320_coco.yml @@ -2,11 +2,11 @@ _BASE_: [ '../datasets/coco_detection.yml', '../runtime.yml', '_base_/picodet_mbv3_0_5x.yml', - '_base_/optimizer_280e.yml', + '_base_/optimizer_300e.yml', '_base_/picodet_320_reader.yml', ] -weights: output/picodet_m_r18_320_coco/model_final +weights: 
output/picodet_l_r18_320_coco/model_final pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet18_vd_pretrained.pdparams find_unused_parameters: True use_ema: true diff --git a/configs/picodet/picodet_m_mbv3_320_coco.yml b/configs/picodet/picodet_m_mbv3_320_coco.yml index 1755b49bd..9e4055b8a 100644 --- a/configs/picodet/picodet_m_mbv3_320_coco.yml +++ b/configs/picodet/picodet_m_mbv3_320_coco.yml @@ -2,7 +2,7 @@ _BASE_: [ '../datasets/coco_detection.yml', '../runtime.yml', '_base_/picodet_mobilenetv3.yml', - '_base_/optimizer_280e.yml', + '_base_/optimizer_300e.yml', '_base_/picodet_320_reader.yml', ] diff --git a/configs/picodet/picodet_m_mbv3_416_coco.yml b/configs/picodet/picodet_m_mbv3_416_coco.yml index e07d5300b..f2e9653dd 100644 --- a/configs/picodet/picodet_m_mbv3_416_coco.yml +++ b/configs/picodet/picodet_m_mbv3_416_coco.yml @@ -2,7 +2,7 @@ _BASE_: [ '../datasets/coco_detection.yml', '../runtime.yml', '_base_/picodet_mobilenetv3.yml', - '_base_/optimizer_280e.yml', + '_base_/optimizer_300e.yml', '_base_/picodet_416_reader.yml', ] diff --git a/configs/picodet/picodet_m_shufflenetv2_320_coco.yml b/configs/picodet/picodet_m_shufflenetv2_320_coco.yml index 6b30d8660..168b36dfc 100644 --- a/configs/picodet/picodet_m_shufflenetv2_320_coco.yml +++ b/configs/picodet/picodet_m_shufflenetv2_320_coco.yml @@ -2,7 +2,7 @@ _BASE_: [ '../datasets/coco_detection.yml', '../runtime.yml', '_base_/picodet_shufflenetv2_1x.yml', - '_base_/optimizer_280e.yml', + '_base_/optimizer_300e.yml', '_base_/picodet_320_reader.yml', ] diff --git a/configs/picodet/picodet_m_shufflenetv2_416_coco.yml b/configs/picodet/picodet_m_shufflenetv2_416_coco.yml index c2af99d76..0726ab8e2 100644 --- a/configs/picodet/picodet_m_shufflenetv2_416_coco.yml +++ b/configs/picodet/picodet_m_shufflenetv2_416_coco.yml @@ -2,7 +2,7 @@ _BASE_: [ '../datasets/coco_detection.yml', '../runtime.yml', '_base_/picodet_shufflenetv2_1x.yml', - '_base_/optimizer_280e.yml', + 
'_base_/optimizer_300e.yml', '_base_/picodet_416_reader.yml', ] diff --git a/configs/picodet/picodet_s_lcnet_320_coco.yml b/configs/picodet/picodet_s_lcnet_320_coco.yml new file mode 100644 index 000000000..762ae1d90 --- /dev/null +++ b/configs/picodet/picodet_s_lcnet_320_coco.yml @@ -0,0 +1,23 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/picodet_shufflenetv2_1x.yml', + '_base_/optimizer_300e.yml', + '_base_/picodet_320_reader.yml', +] + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/LCNet_x1_0_pretrained.pdparams +weights: output/picodet_s_lcnet_320_coco/model_final +find_unused_parameters: True +use_ema: true +cycle_epoch: 40 +snapshot_epoch: 10 + +PicoDet: + backbone: LCNet + neck: PAN + head: PicoHead + +LCNet: + scale: 1.0 + feature_maps: [3, 4, 5] diff --git a/configs/picodet/picodet_s_lcnet_416_coco.yml b/configs/picodet/picodet_s_lcnet_416_coco.yml new file mode 100644 index 000000000..f638b2a48 --- /dev/null +++ b/configs/picodet/picodet_s_lcnet_416_coco.yml @@ -0,0 +1,23 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/picodet_shufflenetv2_1x.yml', + '_base_/optimizer_300e.yml', + '_base_/picodet_416_reader.yml', +] + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/LCNet_x1_0_pretrained.pdparams +weights: output/picodet_s_lcnet_416_coco/model_final +find_unused_parameters: True +use_ema: true +cycle_epoch: 40 +snapshot_epoch: 10 + +PicoDet: + backbone: LCNet + neck: PAN + head: PicoHead + +LCNet: + scale: 1.0 + feature_maps: [3, 4, 5] diff --git a/configs/picodet/picodet_s_shufflenetv2_320_coco.yml b/configs/picodet/picodet_s_shufflenetv2_320_coco.yml index 2ac87e15d..009e994aa 100644 --- a/configs/picodet/picodet_s_shufflenetv2_320_coco.yml +++ b/configs/picodet/picodet_s_shufflenetv2_320_coco.yml @@ -2,7 +2,7 @@ _BASE_: [ '../datasets/coco_detection.yml', '../runtime.yml', '_base_/picodet_shufflenetv2_1x.yml', - '_base_/optimizer_280e.yml', + 
'_base_/optimizer_300e.yml', '_base_/picodet_320_reader.yml', ] diff --git a/configs/picodet/picodet_s_shufflenetv2_416_coco.yml b/configs/picodet/picodet_s_shufflenetv2_416_coco.yml index 94871143b..9c551f7a6 100644 --- a/configs/picodet/picodet_s_shufflenetv2_416_coco.yml +++ b/configs/picodet/picodet_s_shufflenetv2_416_coco.yml @@ -2,7 +2,7 @@ _BASE_: [ '../datasets/coco_detection.yml', '../runtime.yml', '_base_/picodet_shufflenetv2_1x.yml', - '_base_/optimizer_280e.yml', + '_base_/optimizer_300e.yml', '_base_/picodet_416_reader.yml', ] diff --git a/configs/picodet/picodet_xs_lcnet_320_coco.yml b/configs/picodet/picodet_xs_lcnet_320_coco.yml new file mode 100644 index 000000000..ab286963d --- /dev/null +++ b/configs/picodet/picodet_xs_lcnet_320_coco.yml @@ -0,0 +1,23 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/picodet_shufflenetv2_1x.yml', + '_base_/optimizer_300e.yml', + '_base_/picodet_320_reader.yml', +] + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/LCNet_x0_25_pretrained.pdparams +weights: output/picodet_xs_lcnet_320_coco/model_final +find_unused_parameters: True +use_ema: true +cycle_epoch: 40 +snapshot_epoch: 10 + +PicoDet: + backbone: LCNet + neck: PAN + head: PicoHead + +LCNet: + scale: 0.25 + feature_maps: [3, 4, 5] diff --git a/ppdet/data/transform/atss_assigner.py b/ppdet/data/transform/atss_assigner.py index 967e39889..d41c85a7e 100644 --- a/ppdet/data/transform/atss_assigner.py +++ b/ppdet/data/transform/atss_assigner.py @@ -178,8 +178,6 @@ class ATSSAssigner(object): """ bboxes = bboxes[:, :4] num_gt, num_bboxes = gt_bboxes.shape[0], bboxes.shape[0] - # compute iou between all bbox and gt - overlaps = bbox_overlaps(bboxes, gt_bboxes) # assign 0 by default assigned_gt_inds = np.zeros((num_bboxes, ), dtype=np.int64) @@ -194,8 +192,10 @@ class ATSSAssigner(object): assigned_labels = None else: assigned_labels = -np.ones((num_bboxes, ), dtype=np.int64) - return assigned_gt_inds,
max_overlaps, assigned_labels + return assigned_gt_inds, max_overlaps + # compute iou between all bbox and gt + overlaps = bbox_overlaps(bboxes, gt_bboxes) # compute center distance between all bbox and gt gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0 gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0 diff --git a/ppdet/modeling/backbones/__init__.py b/ppdet/modeling/backbones/__init__.py index 4b4ce27f8..f57fd9025 100644 --- a/ppdet/modeling/backbones/__init__.py +++ b/ppdet/modeling/backbones/__init__.py @@ -26,6 +26,7 @@ from . import res2net from . import dla from . import shufflenet_v2 from . import swin_transformer +from . import lcnet from .vgg import * from .resnet import * @@ -41,3 +42,4 @@ from .res2net import * from .dla import * from .shufflenet_v2 import * from .swin_transformer import * +from .lcnet import * diff --git a/ppdet/modeling/backbones/lcnet.py b/ppdet/modeling/backbones/lcnet.py new file mode 100644 index 000000000..fd8ad4e46 --- /dev/null +++ b/ppdet/modeling/backbones/lcnet.py @@ -0,0 +1,258 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn +from paddle import ParamAttr +from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Dropout, Linear +from paddle.regularizer import L2Decay +from paddle.nn.initializer import KaimingNormal + +from ppdet.core.workspace import register, serializable +from numbers import Integral +from ..shape_spec import ShapeSpec + +__all__ = ['LCNet'] + +NET_CONFIG = { + "blocks2": + #k, in_c, out_c, s, use_se + [[3, 16, 32, 1, False], ], + "blocks3": [ + [3, 32, 64, 2, False], + [3, 64, 64, 1, False], + ], + "blocks4": [ + [3, 64, 128, 2, False], + [3, 128, 128, 1, False], + ], + "blocks5": [ + [3, 128, 256, 2, False], + [5, 256, 256, 1, False], + [5, 256, 256, 1, False], + [5, 256, 256, 1, False], + [5, 256, 256, 1, False], + [5, 256, 256, 1, False], + ], + "blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]] +} + + +def make_divisible(v, divisor=8, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNLayer(nn.Layer): + def __init__(self, + num_channels, + filter_size, + num_filters, + stride, + num_groups=1): + super().__init__() + + self.conv = Conv2D( + in_channels=num_channels, + out_channels=num_filters, + kernel_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=num_groups, + weight_attr=ParamAttr(initializer=KaimingNormal()), + bias_attr=False) + + self.bn = BatchNorm( + num_filters, + param_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + self.hardswish = nn.Hardswish() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.hardswish(x) + return x + + +class DepthwiseSeparable(nn.Layer): + def __init__(self, + num_channels, + num_filters, + stride, + dw_size=3, + 
use_se=False): + super().__init__() + self.use_se = use_se + self.dw_conv = ConvBNLayer( + num_channels=num_channels, + num_filters=num_channels, + filter_size=dw_size, + stride=stride, + num_groups=num_channels) + if use_se: + self.se = SEModule(num_channels) + self.pw_conv = ConvBNLayer( + num_channels=num_channels, + filter_size=1, + num_filters=num_filters, + stride=1) + + def forward(self, x): + x = self.dw_conv(x) + if self.use_se: + x = self.se(x) + x = self.pw_conv(x) + return x + + +class SEModule(nn.Layer): + def __init__(self, channel, reduction=4): + super().__init__() + self.avg_pool = AdaptiveAvgPool2D(1) + self.conv1 = Conv2D( + in_channels=channel, + out_channels=channel // reduction, + kernel_size=1, + stride=1, + padding=0) + self.relu = nn.ReLU() + self.conv2 = Conv2D( + in_channels=channel // reduction, + out_channels=channel, + kernel_size=1, + stride=1, + padding=0) + self.hardsigmoid = nn.Hardsigmoid() + + def forward(self, x): + identity = x + x = self.avg_pool(x) + x = self.conv1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.hardsigmoid(x) + x = paddle.multiply(x=identity, y=x) + return x + + +@register +@serializable +class LCNet(nn.Layer): + def __init__(self, scale=1.0, feature_maps=[3, 4, 5]): + super().__init__() + self.scale = scale + self.feature_maps = feature_maps + + out_channels = [] + + self.conv1 = ConvBNLayer( + num_channels=3, + filter_size=3, + num_filters=make_divisible(16 * scale), + stride=2) + + self.blocks2 = nn.Sequential(* [ + DepthwiseSeparable( + num_channels=make_divisible(in_c * scale), + num_filters=make_divisible(out_c * scale), + dw_size=k, + stride=s, + use_se=se) + for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks2"]) + ]) + + self.blocks3 = nn.Sequential(* [ + DepthwiseSeparable( + num_channels=make_divisible(in_c * scale), + num_filters=make_divisible(out_c * scale), + dw_size=k, + stride=s, + use_se=se) + for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks3"]) + ]) + + 
out_channels.append( + make_divisible(NET_CONFIG["blocks3"][-1][2] * scale)) + + self.blocks4 = nn.Sequential(* [ + DepthwiseSeparable( + num_channels=make_divisible(in_c * scale), + num_filters=make_divisible(out_c * scale), + dw_size=k, + stride=s, + use_se=se) + for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks4"]) + ]) + + out_channels.append( + make_divisible(NET_CONFIG["blocks4"][-1][2] * scale)) + + self.blocks5 = nn.Sequential(* [ + DepthwiseSeparable( + num_channels=make_divisible(in_c * scale), + num_filters=make_divisible(out_c * scale), + dw_size=k, + stride=s, + use_se=se) + for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks5"]) + ]) + + out_channels.append( + make_divisible(NET_CONFIG["blocks5"][-1][2] * scale)) + + self.blocks6 = nn.Sequential(* [ + DepthwiseSeparable( + num_channels=make_divisible(in_c * scale), + num_filters=make_divisible(out_c * scale), + dw_size=k, + stride=s, + use_se=se) + for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks6"]) + ]) + + out_channels.append( + make_divisible(NET_CONFIG["blocks6"][-1][2] * scale)) + self._out_channels = [ + ch for idx, ch in enumerate(out_channels) if idx + 2 in feature_maps + ] + + def forward(self, inputs): + x = inputs['image'] + outs = [] + + x = self.conv1(x) + x = self.blocks2(x) + x = self.blocks3(x) + outs.append(x) + x = self.blocks4(x) + outs.append(x) + x = self.blocks5(x) + outs.append(x) + x = self.blocks6(x) + outs.append(x) + outs = [o for i, o in enumerate(outs) if i + 2 in self.feature_maps] + return outs + + @property + def out_shape(self): + return [ShapeSpec(channels=c) for c in self._out_channels] diff --git a/ppdet/modeling/backbones/shufflenet_v2.py b/ppdet/modeling/backbones/shufflenet_v2.py index cd7ea2d9b..59b0502a1 100644 --- a/ppdet/modeling/backbones/shufflenet_v2.py +++ b/ppdet/modeling/backbones/shufflenet_v2.py @@ -21,6 +21,7 @@ import paddle.nn as nn from paddle import ParamAttr from paddle.nn import Conv2D, MaxPool2D, 
AdaptiveAvgPool2D, BatchNorm from paddle.nn.initializer import KaimingNormal +from paddle.regularizer import L2Decay from ppdet.core.workspace import register, serializable from numbers import Integral @@ -50,7 +51,11 @@ class ConvBNLayer(nn.Layer): weight_attr=ParamAttr(initializer=KaimingNormal()), bias_attr=False) - self._batch_norm = BatchNorm(out_channels, act=act) + self._batch_norm = BatchNorm( + out_channels, + param_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)), + act=act) def forward(self, inputs): y = self._conv(inputs) @@ -159,14 +164,9 @@ class InvertedResidualDS(nn.Layer): @register @serializable class ShuffleNetV2(nn.Layer): - def __init__(self, - scale=1.0, - act="relu", - feature_maps=[5, 13, 17], - with_last_conv=False): + def __init__(self, scale=1.0, act="relu", feature_maps=[5, 13, 17]): super(ShuffleNetV2, self).__init__() self.scale = scale - self.with_last_conv = with_last_conv if isinstance(feature_maps, Integral): feature_maps = [feature_maps] self.feature_maps = feature_maps @@ -226,19 +226,6 @@ class ShuffleNetV2(nn.Layer): self._update_out_channels(stage_out_channels[stage_id + 2], self._feature_idx, self.feature_maps) - if self.with_last_conv: - # last_conv - self._last_conv = ConvBNLayer( - in_channels=stage_out_channels[-2], - out_channels=stage_out_channels[-1], - kernel_size=1, - stride=1, - padding=0, - act=act) - self._feature_idx += 1 - self._update_out_channels(stage_out_channels[-1], self._feature_idx, - self.feature_maps) - def _update_out_channels(self, channel, feature_idx, feature_maps): if feature_idx in feature_maps: self._out_channels.append(channel) @@ -252,9 +239,6 @@ class ShuffleNetV2(nn.Layer): if i + 2 in self.feature_maps: outs.append(y) - if self.with_last_conv: - y = self._last_conv(y) - outs.append(y) return outs @property diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py index 28b4608ac..b5aa84697 100644 --- a/ppdet/utils/checkpoint.py +++ 
b/ppdet/utils/checkpoint.py @@ -139,6 +139,13 @@ def match_state_dict(model_state_dict, weight_state_dict): max_id = match_matrix.argmax(1) max_len = match_matrix.max(1) max_id[max_len == 0] = -1 + # max_id is indexed by model key and holds matched weight ids (-1 = none) + load_id = set(max_id) + not_load_weight_name = [ + k for i, k in enumerate(weight_keys) if i not in load_id] + if not_load_weight_name: + logger.info('{} in pretrained weight is not used in the model, ' + 'and its will not be loaded'.format(not_load_weight_name)) matched_keys = {} result_state_dict = {} for model_id, weight_id in enumerate(max_id): -- GitLab