diff --git a/configs/GhostNet/GhostNet_x0_5.yaml b/configs/GhostNet/GhostNet_x0_5.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e787c6bd381fe874029d53600ac2bd9f4c68f028
--- /dev/null
+++ b/configs/GhostNet/GhostNet_x0_5.yaml
@@ -0,0 +1,74 @@
+mode: 'train'
+ARCHITECTURE:
+    name: 'GhostNet_x0_5'
+
+pretrained_model: ""
+model_save_dir: "./output/"
+classes_num: 1000
+total_images: 1281167
+save_interval: 1
+validate: True
+valid_interval: 1
+epochs: 360
+topk: 5
+image_shape: [3, 224, 224]
+
+use_mix: False
+ls_epsilon: 0.1
+
+LEARNING_RATE:
+    function: 'CosineWarmup'
+    params:
+        lr: 0.8
+
+OPTIMIZER:
+    function: 'Momentum'
+    params:
+        momentum: 0.9
+    regularizer:
+        function: 'L2'
+        factor: 0.0000400
+
+TRAIN:
+    batch_size: 2048
+    num_workers: 4
+    file_list: "./dataset/ILSVRC2012/train_list.txt"
+    data_dir: "./dataset/ILSVRC2012/"
+    shuffle_seed: 0
+    transforms:
+        - DecodeImage:
+            to_rgb: True
+            to_np: False
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - NormalizeImage:
+            scale: 1./255.
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - ToCHWImage:
+
+VALID:
+    batch_size: 64
+    num_workers: 4
+    file_list: "./dataset/ILSVRC2012/val_list.txt"
+    data_dir: "./dataset/ILSVRC2012/"
+    shuffle_seed: 0
+    transforms:
+        - DecodeImage:
+            to_rgb: True
+            to_np: False
+            channel_first: False
+        - ResizeImage:
+            resize_short: 256
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - ToCHWImage:
diff --git a/configs/GhostNet/GhostNet_x1_0.yaml b/configs/GhostNet/GhostNet_x1_0.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c1a61dca4c6e820649fc2c764688d88d0d36cf30
--- /dev/null
+++ b/configs/GhostNet/GhostNet_x1_0.yaml
@@ -0,0 +1,74 @@
+mode: 'train'
+ARCHITECTURE:
+    name: 'GhostNet_x1_0'
+
+pretrained_model: ""
+model_save_dir: "./output/"
+classes_num: 1000
+total_images: 1281167
+save_interval: 1
+validate: True
+valid_interval: 1
+epochs: 360
+topk: 5
+image_shape: [3, 224, 224]
+
+use_mix: False
+ls_epsilon: 0.1
+
+LEARNING_RATE:
+    function: 'CosineWarmup'
+    params:
+        lr: 0.4
+
+OPTIMIZER:
+    function: 'Momentum'
+    params:
+        momentum: 0.9
+    regularizer:
+        function: 'L2'
+        factor: 0.0000400
+
+TRAIN:
+    batch_size: 1024
+    num_workers: 4
+    file_list: "./dataset/ILSVRC2012/train_list.txt"
+    data_dir: "./dataset/ILSVRC2012/"
+    shuffle_seed: 0
+    transforms:
+        - DecodeImage:
+            to_rgb: True
+            to_np: False
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - NormalizeImage:
+            scale: 1./255.
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - ToCHWImage:
+
+VALID:
+    batch_size: 64
+    num_workers: 4
+    file_list: "./dataset/ILSVRC2012/val_list.txt"
+    data_dir: "./dataset/ILSVRC2012/"
+    shuffle_seed: 0
+    transforms:
+        - DecodeImage:
+            to_rgb: True
+            to_np: False
+            channel_first: False
+        - ResizeImage:
+            resize_short: 256
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - ToCHWImage:
diff --git a/configs/GhostNet/GhostNet_x1_3.yaml b/configs/GhostNet/GhostNet_x1_3.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..452c1f49804d39264b4efa1c90cce280a0de3ef2
--- /dev/null
+++ b/configs/GhostNet/GhostNet_x1_3.yaml
@@ -0,0 +1,75 @@
+mode: 'train'
+ARCHITECTURE:
+    name: 'GhostNet_x1_3'
+
+pretrained_model: ""
+model_save_dir: "./output/"
+classes_num: 1000
+total_images: 1281167
+save_interval: 1
+validate: True
+valid_interval: 1
+epochs: 360
+topk: 5
+image_shape: [3, 224, 224]
+
+use_mix: False
+ls_epsilon: 0.1
+
+LEARNING_RATE:
+    function: 'CosineWarmup'
+    params:
+        lr: 0.4
+
+OPTIMIZER:
+    function: 'Momentum'
+    params:
+        momentum: 0.9
+    regularizer:
+        function: 'L2'
+        factor: 0.0000400
+
+TRAIN:
+    batch_size: 1024
+    num_workers: 4
+    file_list: "./dataset/ILSVRC2012/train_list.txt"
+    data_dir: "./dataset/ILSVRC2012/"
+    shuffle_seed: 0
+    transforms:
+        - DecodeImage:
+            to_rgb: True
+            to_np: False
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - AutoAugment:
+        - NormalizeImage:
+            scale: 1./255.
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - ToCHWImage:
+
+VALID:
+    batch_size: 64
+    num_workers: 4
+    file_list: "./dataset/ILSVRC2012/val_list.txt"
+    data_dir: "./dataset/ILSVRC2012/"
+    shuffle_seed: 0
+    transforms:
+        - DecodeImage:
+            to_rgb: True
+            to_np: False
+            channel_first: False
+        - ResizeImage:
+            resize_short: 256
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - ToCHWImage:
diff --git a/configs/ShuffleNet/ShuffleNetV2.yaml b/configs/ShuffleNet/ShuffleNetV2.yaml
index 1ee8787c0fd504b23c858416d643b4724933a6f5..c097afaea5a97bd02cdc5d3eef236a71a3feb4b7 100644
--- a/configs/ShuffleNet/ShuffleNetV2.yaml
+++ b/configs/ShuffleNet/ShuffleNetV2.yaml
@@ -14,7 +14,7 @@ topk: 5
 image_shape: [3, 224, 224]
 
 LEARNING_RATE:
-    function: 'Cosine'
+    function: 'CosineWarmup'
     params:
         lr: 0.5
         warmup_epoch: 5
diff --git a/configs/ShuffleNet/ShuffleNetV2_swish.yaml b/configs/ShuffleNet/ShuffleNetV2_swish.yaml
index 313f626fcd83491151a9496183290889e6e8e7dd..4e64ce8bb0562258c70234a2a7f888b0dbe08f8e 100644
--- a/configs/ShuffleNet/ShuffleNetV2_swish.yaml
+++ b/configs/ShuffleNet/ShuffleNetV2_swish.yaml
@@ -14,7 +14,7 @@ topk: 5
 image_shape: [3, 224, 224]
 
 LEARNING_RATE:
-    function: 'Cosine'
+    function: 'CosineWarmup'
     params:
         lr: 0.5
         warmup_epoch: 5
diff --git a/configs/ShuffleNet/ShuffleNetV2_x0_25.yaml b/configs/ShuffleNet/ShuffleNetV2_x0_25.yaml
index a8e8055e67da82e76fc9e46d30bcb075d832d66a..996f040bbd3fac8bd7d8bce672379cda647f0714 100644
--- a/configs/ShuffleNet/ShuffleNetV2_x0_25.yaml
+++ b/configs/ShuffleNet/ShuffleNetV2_x0_25.yaml
@@ -14,7 +14,7 @@ topk: 5
 image_shape: [3, 224, 224]
 
 LEARNING_RATE:
-    function: 'Cosine'
+    function: 'CosineWarmup'
     params:
         lr: 0.5
         warmup_epoch: 5
diff --git a/configs/ShuffleNet/ShuffleNetV2_x0_33.yaml b/configs/ShuffleNet/ShuffleNetV2_x0_33.yaml
index 9e1814013e99f9b1cc3236e4bf10c53386ec7321..f2941474105a4352ef1294711ecbb865a19b4774 100644
--- a/configs/ShuffleNet/ShuffleNetV2_x0_33.yaml
+++ b/configs/ShuffleNet/ShuffleNetV2_x0_33.yaml
@@ -14,7 +14,7 @@ topk: 5
 image_shape: [3, 224, 224]
 
 LEARNING_RATE:
-    function: 'Cosine'
+    function: 'CosineWarmup'
     params:
         lr: 0.5
         warmup_epoch: 5
diff --git a/configs/ShuffleNet/ShuffleNetV2_x0_5.yaml b/configs/ShuffleNet/ShuffleNetV2_x0_5.yaml
index be8f0be06396257490d4def0d906244a4150b0bd..05a1ad3eb94cb71e16e15eefb1bd195dee4b143b 100644
--- a/configs/ShuffleNet/ShuffleNetV2_x0_5.yaml
+++ b/configs/ShuffleNet/ShuffleNetV2_x0_5.yaml
@@ -14,7 +14,7 @@ topk: 5
 image_shape: [3, 224, 224]
 
 LEARNING_RATE:
-    function: 'Cosine'
+    function: 'CosineWarmup'
     params:
         lr: 0.5
         warmup_epoch: 5
diff --git a/configs/ShuffleNet/ShuffleNetV2_x1_5.yaml b/configs/ShuffleNet/ShuffleNetV2_x1_5.yaml
index a10ec37ca7af1afa830af25d3aa096a03f1febd1..63f50d48404e1bb0b9b9c599418df19465087a99 100644
--- a/configs/ShuffleNet/ShuffleNetV2_x1_5.yaml
+++ b/configs/ShuffleNet/ShuffleNetV2_x1_5.yaml
@@ -14,7 +14,7 @@ topk: 5
 image_shape: [3, 224, 224]
 
 LEARNING_RATE:
-    function: 'Cosine'
+    function: 'CosineWarmup'
     params:
         lr: 0.25
         warmup_epoch: 5
diff --git a/configs/ShuffleNet/ShuffleNetV2_x2_0.yaml b/configs/ShuffleNet/ShuffleNetV2_x2_0.yaml
index d84e29bc7075e73e44eb457cf32497c5b04671e2..5a14cebbf8e789a9e5c03d180f11bb0dc5a42e47 100644
--- a/configs/ShuffleNet/ShuffleNetV2_x2_0.yaml
+++ b/configs/ShuffleNet/ShuffleNetV2_x2_0.yaml
@@ -14,7 +14,7 @@ topk: 5
 image_shape: [3, 224, 224]
 
 LEARNING_RATE:
-    function: 'Cosine'
+    function: 'CosineWarmup'
     params:
         lr: 0.25
         warmup_epoch: 5
diff --git a/ppcls/modeling/architectures/__init__.py b/ppcls/modeling/architectures/__init__.py
index c9eeb5a6a2f47d9bf8fdc8484807ac48efc31f79..88f886e9661ff22fc10b4d5053b454f25b95b04e 100644
--- a/ppcls/modeling/architectures/__init__.py
+++ b/ppcls/modeling/architectures/__init__.py
@@ -28,6 +28,7 @@ from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44
 from .efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7
 from .resnest import ResNeSt50_fast_1s1x64d, ResNeSt50
 from .googlenet import GoogLeNet
+from .ghostnet import GhostNet_x0_5, GhostNet_x1_0, GhostNet_x1_3
 from .mobilenet_v1 import MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x0_75, MobileNetV1
 from .mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2, MobileNetV2_x1_5, MobileNetV2_x2_0
 from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25
diff --git a/ppcls/modeling/architectures/ghostnet.py b/ppcls/modeling/architectures/ghostnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..b37f2230f76873553841b4d12da01b50f2477599
--- /dev/null
+++ b/ppcls/modeling/architectures/ghostnet.py
@@ -0,0 +1,335 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn import Conv2d, BatchNorm, AdaptiveAvgPool2d, Linear
+from paddle.fluid.regularizer import L2DecayRegularizer
+from paddle.nn.initializer import Uniform
+
+
+class ConvBNLayer(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 groups=1,
+                 act="relu",
+                 name=None):
+        super(ConvBNLayer, self).__init__()
+        self._conv = Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=(kernel_size - 1) // 2,
+            groups=groups,
+            weight_attr=ParamAttr(
+                initializer=nn.initializer.MSRA(), name=name + "_weights"),
+            bias_attr=False)
+        bn_name = name + "_bn"
+
+        # In the old version, moving_variance_name was name + "_variance"
+        self._batch_norm = BatchNorm(
+            num_channels=out_channels,
+            act=act,
+            param_attr=ParamAttr(
+                name=bn_name + "_scale",
+                regularizer=L2DecayRegularizer(regularization_coeff=0.0)),
+            bias_attr=ParamAttr(
+                name=bn_name + "_offset",
+                regularizer=L2DecayRegularizer(regularization_coeff=0.0)),
+            moving_mean_name=bn_name + "_mean",
+            moving_variance_name=name +
+            "_variance"  # wrong due to an old typo, will be fixed later.
+        )
+
+    def forward(self, inputs):
+        y = self._conv(inputs)
+        y = self._batch_norm(y)
+        return y
+
+
+class SEBlock(nn.Layer):
+    def __init__(self, num_channels, reduction_ratio=4, name=None):
+        super(SEBlock, self).__init__()
+        self.pool2d_gap = AdaptiveAvgPool2d(1)
+        self._num_channels = num_channels
+        stdv = 1.0 / math.sqrt(num_channels * 1.0)
+        med_ch = num_channels // reduction_ratio
+        self.squeeze = Linear(
+            num_channels,
+            med_ch,
+            weight_attr=ParamAttr(
+                initializer=Uniform(-stdv, stdv), name=name + "_1_weights"),
+            bias_attr=ParamAttr(name=name + "_1_offset"))
+        stdv = 1.0 / math.sqrt(med_ch * 1.0)
+        self.excitation = Linear(
+            med_ch,
+            num_channels,
+            weight_attr=ParamAttr(
+                initializer=Uniform(-stdv, stdv), name=name + "_2_weights"),
+            bias_attr=ParamAttr(name=name + "_2_offset"))
+
+    def forward(self, inputs):
+        pool = self.pool2d_gap(inputs)
+        pool = paddle.reshape(pool, shape=[-1, self._num_channels])
+        squeeze = self.squeeze(pool)
+        squeeze = F.relu(squeeze)
+        excitation = self.excitation(squeeze)
+        excitation = paddle.fluid.layers.clip(x=excitation, min=0, max=1)
+        excitation = paddle.reshape(
+            excitation, shape=[-1, self._num_channels, 1, 1])
+        out = inputs * excitation
+        return out
+
+
+class GhostModule(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 output_channels,
+                 kernel_size=1,
+                 ratio=2,
+                 dw_size=3,
+                 stride=1,
+                 relu=True,
+                 name=None):
+        super(GhostModule, self).__init__()
+        init_channels = int(math.ceil(output_channels / ratio))
+        new_channels = int(init_channels * (ratio - 1))
+        self.primary_conv = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=init_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            groups=1,
+            act="relu" if relu else None,
+            name=name + "_primary_conv")
+        self.cheap_operation = ConvBNLayer(
+            in_channels=init_channels,
+            out_channels=new_channels,
+            kernel_size=dw_size,
+            stride=1,
+            groups=init_channels,
+            act="relu" if relu else None,
+            name=name + "_cheap_operation")
+
+    def forward(self, inputs):
+        x = self.primary_conv(inputs)
+        y = self.cheap_operation(x)
+        out = paddle.concat([x, y], axis=1)
+        return out
+
+
+class GhostBottleneck(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 hidden_dim,
+                 output_channels,
+                 kernel_size,
+                 stride,
+                 use_se,
+                 name=None):
+        super(GhostBottleneck, self).__init__()
+        self._stride = stride
+        self._use_se = use_se
+        self._num_channels = in_channels
+        self._output_channels = output_channels
+        self.ghost_module_1 = GhostModule(
+            in_channels=in_channels,
+            output_channels=hidden_dim,
+            kernel_size=1,
+            stride=1,
+            relu=True,
+            name=name + "_ghost_module_1")
+        if stride == 2:
+            self.depthwise_conv = ConvBNLayer(
+                in_channels=hidden_dim,
+                out_channels=hidden_dim,
+                kernel_size=kernel_size,
+                stride=stride,
+                groups=hidden_dim,
+                act=None,
+                name=name +
+                "_depthwise_depthwise"  # looks strange due to an old typo, will be fixed later.
+            )
+        if use_se:
+            self.se_block = SEBlock(num_channels=hidden_dim, name=name + "_se")
+        self.ghost_module_2 = GhostModule(
+            in_channels=hidden_dim,
+            output_channels=output_channels,
+            kernel_size=1,
+            relu=False,
+            name=name + "_ghost_module_2")
+        if stride != 1 or in_channels != output_channels:
+            self.shortcut_depthwise = ConvBNLayer(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                groups=in_channels,
+                act=None,
+                name=name +
+                "_shortcut_depthwise_depthwise"  # looks strange due to an old typo, will be fixed later.
+            )
+            self.shortcut_conv = ConvBNLayer(
+                in_channels=in_channels,
+                out_channels=output_channels,
+                kernel_size=1,
+                stride=1,
+                groups=1,
+                act=None,
+                name=name + "_shortcut_conv")
+
+    def forward(self, inputs):
+        x = self.ghost_module_1(inputs)
+        if self._stride == 2:
+            x = self.depthwise_conv(x)
+        if self._use_se:
+            x = self.se_block(x)
+        x = self.ghost_module_2(x)
+        if self._stride == 1 and self._num_channels == self._output_channels:
+            shortcut = inputs
+        else:
+            shortcut = self.shortcut_depthwise(inputs)
+            shortcut = self.shortcut_conv(shortcut)
+        return paddle.elementwise_add(x=x, y=shortcut, axis=-1)
+
+
+class GhostNet(nn.Layer):
+    def __init__(self, scale, class_dim=1000):
+        super(GhostNet, self).__init__()
+        self.cfgs = [
+            # k, t, c, SE, s
+            [3, 16, 16, 0, 1],
+            [3, 48, 24, 0, 2],
+            [3, 72, 24, 0, 1],
+            [5, 72, 40, 1, 2],
+            [5, 120, 40, 1, 1],
+            [3, 240, 80, 0, 2],
+            [3, 200, 80, 0, 1],
+            [3, 184, 80, 0, 1],
+            [3, 184, 80, 0, 1],
+            [3, 480, 112, 1, 1],
+            [3, 672, 112, 1, 1],
+            [5, 672, 160, 1, 2],
+            [5, 960, 160, 0, 1],
+            [5, 960, 160, 1, 1],
+            [5, 960, 160, 0, 1],
+            [5, 960, 160, 1, 1]
+        ]
+        self.scale = scale
+        output_channels = int(self._make_divisible(16 * self.scale, 4))
+        self.conv1 = ConvBNLayer(
+            in_channels=3,
+            out_channels=output_channels,
+            kernel_size=3,
+            stride=2,
+            groups=1,
+            act="relu",
+            name="conv1")
+        # build inverted residual blocks
+        idx = 0
+        self.ghost_bottleneck_list = []
+        for k, exp_size, c, use_se, s in self.cfgs:
+            in_channels = output_channels
+            output_channels = int(self._make_divisible(c * self.scale, 4))
+            hidden_dim = int(self._make_divisible(exp_size * self.scale, 4))
+            ghost_bottleneck = self.add_sublayer(
+                name="_ghostbottleneck_" + str(idx),
+                sublayer=GhostBottleneck(
+                    in_channels=in_channels,
+                    hidden_dim=hidden_dim,
+                    output_channels=output_channels,
+                    kernel_size=k,
+                    stride=s,
+                    use_se=use_se,
+                    name="_ghostbottleneck_" + str(idx)))
+            self.ghost_bottleneck_list.append(ghost_bottleneck)
+            idx += 1
+        # build last several layers
+        in_channels = output_channels
+        output_channels = int(self._make_divisible(exp_size * self.scale, 4))
+        self.conv_last = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=output_channels,
+            kernel_size=1,
+            stride=1,
+            groups=1,
+            act="relu",
+            name="conv_last")
+        self.pool2d_gap = AdaptiveAvgPool2d(1)
+        in_channels = output_channels
+        self._fc0_output_channels = 1280
+        self.fc_0 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=self._fc0_output_channels,
+            kernel_size=1,
+            stride=1,
+            act="relu",
+            name="fc_0")
+        self.dropout = nn.Dropout(p=0.2)
+        stdv = 1.0 / math.sqrt(self._fc0_output_channels * 1.0)
+        self.fc_1 = Linear(
+            self._fc0_output_channels,
+            class_dim,
+            weight_attr=ParamAttr(
+                name="fc_1_weights", initializer=Uniform(-stdv, stdv)),
+            bias_attr=ParamAttr(name="fc_1_offset"))
+
+    def forward(self, inputs):
+        x = self.conv1(inputs)
+        for ghost_bottleneck in self.ghost_bottleneck_list:
+            x = ghost_bottleneck(x)
+        x = self.conv_last(x)
+        x = self.pool2d_gap(x)
+        x = self.fc_0(x)
+        x = self.dropout(x)
+        x = paddle.reshape(x, shape=[-1, self._fc0_output_channels])
+        x = self.fc_1(x)
+        return x
+
+    def _make_divisible(self, v, divisor, min_value=None):
+        """
+        This function is taken from the original tf repo.
+        It ensures that all layers have a channel number that is divisible by 8
+        It can be seen here:
+        https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+        """
+        if min_value is None:
+            min_value = divisor
+        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+        # Make sure that round down does not go down by more than 10%.
+        if new_v < 0.9 * v:
+            new_v += divisor
+        return new_v
+
+
+def GhostNet_x0_5(**args):
+    model = GhostNet(scale=0.5)
+    return model
+
+
+def GhostNet_x1_0(**args):
+    model = GhostNet(scale=1.0)
+    return model
+
+
+def GhostNet_x1_3(**args):
+    model = GhostNet(scale=1.3)
+    return model
diff --git a/ppcls/modeling/architectures/shufflenet_v2.py b/ppcls/modeling/architectures/shufflenet_v2.py
index c0dfe2c2c3214214bf922ca868ef61f8d6e8294f..9e06c955c3cfb344fb7a7d1b25e2c7443b7f37bb 100644
--- a/ppcls/modeling/architectures/shufflenet_v2.py
+++ b/ppcls/modeling/architectures/shufflenet_v2.py
@@ -16,15 +16,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import numpy as np
-import paddle
-from paddle import ParamAttr
-import paddle.nn as nn
-import paddle.nn.functional as F
-from paddle.nn import Conv2d, BatchNorm, Linear, Dropout
-from paddle.nn import AdaptiveAvgPool2d, MaxPool2d, AvgPool2d
+from paddle import ParamAttr, reshape, transpose, concat, split
+from paddle.nn import Layer, Conv2d, MaxPool2d, AdaptiveAvgPool2d, BatchNorm, Linear
 from paddle.nn.initializer import MSRA
-import math
+from paddle.nn.functional import swish
 
 __all__ = [
     "ShuffleNetV2_x0_25", "ShuffleNetV2_x0_33", "ShuffleNetV2_x0_5",
@@ -34,188 +29,176 @@ __all__ = [
 
 
 def channel_shuffle(x, groups):
-    batchsize, num_channels, height, width = x.shape[0], x.shape[1], x.shape[
-        2], x.shape[3]
+    batch_size, num_channels, height, width = x.shape[0:4]
     channels_per_group = num_channels // groups
 
     # reshape
-    x = paddle.reshape(
-        x=x, shape=[batchsize, groups, channels_per_group, height, width])
+    x = reshape(
+        x=x, shape=[batch_size, groups, channels_per_group, height, width])
+
+    # transpose
+    x = transpose(x=x, perm=[0, 2, 1, 3, 4])
 
-    x = paddle.transpose(x=x, perm=[0, 2, 1, 3, 4])
     # flatten
-    x = paddle.reshape(x=x, shape=[batchsize, num_channels, height, width])
+    x = reshape(x=x, shape=[batch_size, num_channels, height, width])
     return x
 
 
-class ConvBNLayer(nn.Layer):
-    def __init__(self,
-                 num_channels,
-                 filter_size,
-                 num_filters,
-                 stride,
-                 padding,
-                 channels=None,
-                 num_groups=1,
-                 if_act=True,
-                 act='relu',
-                 name=None):
+class ConvBNLayer(Layer):
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            groups=1,
+            act=None,
+            name=None, ):
         super(ConvBNLayer, self).__init__()
-        self._if_act = if_act
-        assert act in ['relu', 'swish'], \
-            "supported act are {} but your act is {}".format(
-                ['relu', 'swish'], act)
-        self._act = act
         self._conv = Conv2d(
-            in_channels=num_channels,
-            out_channels=num_filters,
-            kernel_size=filter_size,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
             stride=stride,
             padding=padding,
-            groups=num_groups,
+            groups=groups,
             weight_attr=ParamAttr(
                 initializer=MSRA(), name=name + "_weights"),
             bias_attr=False)
 
         self._batch_norm = BatchNorm(
-            num_filters,
+            out_channels,
             param_attr=ParamAttr(name=name + "_bn_scale"),
             bias_attr=ParamAttr(name=name + "_bn_offset"),
+            act=act,
             moving_mean_name=name + "_bn_mean",
             moving_variance_name=name + "_bn_variance")
 
-    def forward(self, inputs, if_act=True):
+    def forward(self, inputs):
         y = self._conv(inputs)
         y = self._batch_norm(y)
-        if self._if_act:
-            y = F.relu(y) if self._act == 'relu' else F.swish(y)
         return y
 
 
-class InvertedResidualUnit(nn.Layer):
+class InvertedResidual(Layer):
     def __init__(self,
-                 num_channels,
-                 num_filters,
+                 in_channels,
+                 out_channels,
                  stride,
-                 benchmodel,
-                 act='relu',
+                 act="relu",
                  name=None):
-        super(InvertedResidualUnit, self).__init__()
-        assert stride in [1, 2], \
-            "supported stride are {} but your stride is {}".format([
-                1, 2], stride)
-        self.benchmodel = benchmodel
-        oup_inc = num_filters // 2
-        inp = num_channels
-        if benchmodel == 1:
-            self._conv_pw = ConvBNLayer(
-                num_channels=num_channels // 2,
-                num_filters=oup_inc,
-                filter_size=1,
-                stride=1,
-                padding=0,
-                num_groups=1,
-                if_act=True,
-                act=act,
-                name='stage_' + name + '_conv1')
-            self._conv_dw = ConvBNLayer(
-                num_channels=oup_inc,
-                num_filters=oup_inc,
-                filter_size=3,
-                stride=stride,
-                padding=1,
-                num_groups=oup_inc,
-                if_act=False,
-                act=act,
-                name='stage_' + name + '_conv2')
-            self._conv_linear = ConvBNLayer(
-                num_channels=oup_inc,
-                num_filters=oup_inc,
-                filter_size=1,
-                stride=1,
-                padding=0,
-                num_groups=1,
-                if_act=True,
-                act=act,
-                name='stage_' + name + '_conv3')
-        else:
-            # branch1
-            self._conv_dw_1 = ConvBNLayer(
-                num_channels=num_channels,
-                num_filters=inp,
-                filter_size=3,
-                stride=stride,
-                padding=1,
-                num_groups=inp,
-                if_act=False,
-                act=act,
-                name='stage_' + name + '_conv4')
-            self._conv_linear_1 = ConvBNLayer(
-                num_channels=inp,
-                num_filters=oup_inc,
-                filter_size=1,
-                stride=1,
-                padding=0,
-                num_groups=1,
-                if_act=True,
-                act=act,
-                name='stage_' + name + '_conv5')
-            # branch2
-            self._conv_pw_2 = ConvBNLayer(
-                num_channels=num_channels,
-                num_filters=oup_inc,
-                filter_size=1,
-                stride=1,
-                padding=0,
-                num_groups=1,
-                if_act=True,
-                act=act,
-                name='stage_' + name + '_conv1')
-            self._conv_dw_2 = ConvBNLayer(
-                num_channels=oup_inc,
-                num_filters=oup_inc,
-                filter_size=3,
-                stride=stride,
-                padding=1,
-                num_groups=oup_inc,
-                if_act=False,
-                act=act,
-                name='stage_' + name + '_conv2')
-            self._conv_linear_2 = ConvBNLayer(
-                num_channels=oup_inc,
-                num_filters=oup_inc,
-                filter_size=1,
-                stride=1,
-                padding=0,
-                num_groups=1,
-                if_act=True,
-                act=act,
-                name='stage_' + name + '_conv3')
+        super(InvertedResidual, self).__init__()
+        self._conv_pw = ConvBNLayer(
+            in_channels=in_channels // 2,
+            out_channels=out_channels // 2,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1,
+            act=act,
+            name='stage_' + name + '_conv1')
+        self._conv_dw = ConvBNLayer(
+            in_channels=out_channels // 2,
+            out_channels=out_channels // 2,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            groups=out_channels // 2,
+            act=None,
+            name='stage_' + name + '_conv2')
+        self._conv_linear = ConvBNLayer(
+            in_channels=out_channels // 2,
+            out_channels=out_channels // 2,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1,
+            act=act,
+            name='stage_' + name + '_conv3')
 
     def forward(self, inputs):
-        if self.benchmodel == 1:
-            x1, x2 = paddle.split(
-                inputs,
-                num_or_sections=[inputs.shape[1] // 2, inputs.shape[1] // 2],
-                axis=1)
-            x2 = self._conv_pw(x2)
-            x2 = self._conv_dw(x2)
-            x2 = self._conv_linear(x2)
-            out = paddle.concat([x1, x2], axis=1)
-        else:
-            x1 = self._conv_dw_1(inputs)
-            x1 = self._conv_linear_1(x1)
+        x1, x2 = split(
+            inputs,
+            num_or_sections=[inputs.shape[1] // 2, inputs.shape[1] // 2],
+            axis=1)
+        x2 = self._conv_pw(x2)
+        x2 = self._conv_dw(x2)
+        x2 = self._conv_linear(x2)
+        out = concat([x1, x2], axis=1)
+        return channel_shuffle(out, 2)
 
-            x2 = self._conv_pw_2(inputs)
-            x2 = self._conv_dw_2(x2)
-            x2 = self._conv_linear_2(x2)
-            out = paddle.concat([x1, x2], axis=1)
+
+class InvertedResidualDS(Layer):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride,
+                 act="relu",
+                 name=None):
+        super(InvertedResidualDS, self).__init__()
+
+        # branch1
+        self._conv_dw_1 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            groups=in_channels,
+            act=None,
+            name='stage_' + name + '_conv4')
+        self._conv_linear_1 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=out_channels // 2,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1,
+            act=act,
+            name='stage_' + name + '_conv5')
+        # branch2
+        self._conv_pw_2 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=out_channels // 2,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1,
+            act=act,
+            name='stage_' + name + '_conv1')
+        self._conv_dw_2 = ConvBNLayer(
+            in_channels=out_channels // 2,
+            out_channels=out_channels // 2,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            groups=out_channels // 2,
+            act=None,
+            name='stage_' + name + '_conv2')
+        self._conv_linear_2 = ConvBNLayer(
+            in_channels=out_channels // 2,
+            out_channels=out_channels // 2,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1,
+            act=act,
+            name='stage_' + name + '_conv3')
+
+    def forward(self, inputs):
+        x1 = self._conv_dw_1(inputs)
+        x1 = self._conv_linear_1(x1)
+        x2 = self._conv_pw_2(inputs)
+        x2 = self._conv_dw_2(x2)
+        x2 = self._conv_linear_2(x2)
+        out = concat([x1, x2], axis=1)
 
         return channel_shuffle(out, 2)
 
 
-class ShuffleNet(nn.Layer):
-    def __init__(self, class_dim=1000, scale=1.0, act='relu'):
+class ShuffleNet(Layer):
+    def __init__(self, class_dim=1000, scale=1.0, act="relu"):
         super(ShuffleNet, self).__init__()
         self.scale = scale
         self.class_dim = class_dim
@@ -238,58 +221,47 @@ class ShuffleNet(nn.Layer):
                                       "] is not implemented!")
         # 1. conv1
         self._conv1 = ConvBNLayer(
-            num_channels=3,
-            num_filters=stage_out_channels[1],
-            filter_size=3,
+            in_channels=3,
+            out_channels=stage_out_channels[1],
+            kernel_size=3,
             stride=2,
             padding=1,
-            if_act=True,
             act=act,
             name='stage1_conv')
         self._max_pool = MaxPool2d(kernel_size=3, stride=2, padding=1)
 
         # 2. bottleneck sequences
         self._block_list = []
-        i = 1
-        in_c = int(32 * scale)
-        for idxstage in range(len(stage_repeats)):
-            numrepeat = stage_repeats[idxstage]
-            output_channel = stage_out_channels[idxstage + 2]
-            for i in range(numrepeat):
+        for stage_id, num_repeat in enumerate(stage_repeats):
+            for i in range(num_repeat):
                 if i == 0:
                     block = self.add_sublayer(
-                        str(idxstage + 2) + '_' + str(i + 1),
-                        InvertedResidualUnit(
-                            num_channels=stage_out_channels[idxstage + 1],
-                            num_filters=output_channel,
+                        name=str(stage_id + 2) + '_' + str(i + 1),
+                        sublayer=InvertedResidualDS(
+                            in_channels=stage_out_channels[stage_id + 1],
+                            out_channels=stage_out_channels[stage_id + 2],
                             stride=2,
-                            benchmodel=2,
                             act=act,
-                            name=str(idxstage + 2) + '_' + str(i + 1)))
-                    self._block_list.append(block)
+                            name=str(stage_id + 2) + '_' + str(i + 1)))
                 else:
                     block = self.add_sublayer(
-                        str(idxstage + 2) + '_' + str(i + 1),
-                        InvertedResidualUnit(
-                            num_channels=output_channel,
-                            num_filters=output_channel,
+                        name=str(stage_id + 2) + '_' + str(i + 1),
+                        sublayer=InvertedResidual(
+                            in_channels=stage_out_channels[stage_id + 2],
+                            out_channels=stage_out_channels[stage_id + 2],
                             stride=1,
-                            benchmodel=1,
                             act=act,
-                            name=str(idxstage + 2) + '_' + str(i + 1)))
-                    self._block_list.append(block)
-
+                            name=str(stage_id + 2) + '_' + str(i + 1)))
+                self._block_list.append(block)
         # 3. last_conv
         self._last_conv = ConvBNLayer(
-            num_channels=stage_out_channels[-2],
-            num_filters=stage_out_channels[-1],
-            filter_size=1,
+            in_channels=stage_out_channels[-2],
+            out_channels=stage_out_channels[-1],
+            kernel_size=1,
             stride=1,
             padding=0,
-            if_act=True,
             act=act,
             name='conv5')
-
         # 4. pool
         self._pool2d_avg = AdaptiveAvgPool2d(1)
         self._out_c = stage_out_channels[-1]
@@ -307,13 +279,13 @@ class ShuffleNet(nn.Layer):
             y = inv(y)
         y = self._last_conv(y)
         y = self._pool2d_avg(y)
-        y = paddle.reshape(y, shape=[-1, self._out_c])
+        y = reshape(y, shape=[-1, self._out_c])
         y = self._fc(y)
         return y
 
 
 def ShuffleNetV2_x0_25(**args):
-    model = ShuffleNetV2(scale=0.25, **args)
+    model = ShuffleNet(scale=0.25, **args)
     return model
 
 
@@ -343,5 +315,5 @@ def ShuffleNetV2_x2_0(**args):
 
 
 def ShuffleNetV2_swish(**args):
-    model = ShuffleNet(scale=1.0, act='swish', **args)
+    model = ShuffleNet(scale=1.0, act="swish", **args)
     return model