From f6219dda46e920efa2c37323961a8927f39a54d8 Mon Sep 17 00:00:00 2001 From: Nyakku Shigure Date: Sat, 23 Apr 2022 07:27:54 +0800 Subject: [PATCH] reuse ConvNormActivation in some vision models (#40431) * reuse ConvNormActivation in some vision models --- python/paddle/vision/models/inceptionv3.py | 477 ++++++++++---------- python/paddle/vision/models/mobilenetv1.py | 56 +-- python/paddle/vision/models/mobilenetv2.py | 89 ++-- python/paddle/vision/models/shufflenetv2.py | 124 +++-- python/paddle/vision/ops.py | 8 +- 5 files changed, 372 insertions(+), 382 deletions(-) diff --git a/python/paddle/vision/models/inceptionv3.py b/python/paddle/vision/models/inceptionv3.py index 9e8a8b8146..27650dbe09 100644 --- a/python/paddle/vision/models/inceptionv3.py +++ b/python/paddle/vision/models/inceptionv3.py @@ -19,75 +19,60 @@ from __future__ import print_function import math import paddle import paddle.nn as nn -from paddle.nn import Conv2D, BatchNorm, Linear, Dropout +from paddle.nn import Linear, Dropout from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D from paddle.nn.initializer import Uniform from paddle.fluid.param_attr import ParamAttr from paddle.utils.download import get_weights_path_from_url +from ..ops import ConvNormActivation __all__ = [] model_urls = { "inception_v3": - ("https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/InceptionV3_pretrained.pdparams", - "e4d0905a818f6bb7946e881777a8a935") + ("https://paddle-hapi.bj.bcebos.com/models/inception_v3.pdparams", + "649a4547c3243e8b59c656f41fe330b8") } -class ConvBNLayer(nn.Layer): - def __init__(self, - num_channels, - num_filters, - filter_size, - stride=1, - padding=0, - groups=1, - act="relu"): - super().__init__() - self.act = act - self.conv = Conv2D( - in_channels=num_channels, - out_channels=num_filters, - kernel_size=filter_size, - stride=stride, - padding=padding, - groups=groups, - bias_attr=False) - self.bn = BatchNorm(num_filters) - self.relu = nn.ReLU() - - def forward(self, x): - x = self.conv(x) - x = self.bn(x) - if self.act: - x = self.relu(x) - return x - - class InceptionStem(nn.Layer): def __init__(self): super().__init__() - self.conv_1a_3x3 = ConvBNLayer( - num_channels=3, num_filters=32, filter_size=3, stride=2, act="relu") - self.conv_2a_3x3 = ConvBNLayer( - num_channels=32, - num_filters=32, - filter_size=3, + self.conv_1a_3x3 = ConvNormActivation( + in_channels=3, + out_channels=32, + kernel_size=3, + stride=2, + padding=0, + activation_layer=nn.ReLU) + self.conv_2a_3x3 = ConvNormActivation( + in_channels=32, + out_channels=32, + kernel_size=3, stride=1, - act="relu") - self.conv_2b_3x3 = ConvBNLayer( - num_channels=32, - num_filters=64, - filter_size=3, + padding=0, + activation_layer=nn.ReLU) + self.conv_2b_3x3 = ConvNormActivation( + in_channels=32, + out_channels=64, + kernel_size=3, padding=1, - act="relu") + activation_layer=nn.ReLU) self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=0) - self.conv_3b_1x1 = ConvBNLayer( - num_channels=64, num_filters=80, filter_size=1, act="relu") - self.conv_4a_3x3 = ConvBNLayer( - num_channels=80, num_filters=192, filter_size=3, act="relu") + self.conv_3b_1x1 = ConvNormActivation( + in_channels=64, + out_channels=80, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.conv_4a_3x3 = ConvNormActivation( + in_channels=80, + out_channels=192, + kernel_size=3, + padding=0, + activation_layer=nn.ReLU) def forward(self, x): x = self.conv_1a_3x3(x) @@ -103,47 +88,53 @@ class InceptionStem(nn.Layer): class InceptionA(nn.Layer): def __init__(self, num_channels, pool_features): super().__init__() - self.branch1x1 = ConvBNLayer( - num_channels=num_channels, - num_filters=64, - filter_size=1, - act="relu") - self.branch5x5_1 = ConvBNLayer( - num_channels=num_channels, - num_filters=48, - filter_size=1, - act="relu") - self.branch5x5_2 = ConvBNLayer( - num_channels=48, - num_filters=64, - filter_size=5, + self.branch1x1 = ConvNormActivation( + in_channels=num_channels, + out_channels=64, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + + self.branch5x5_1 = ConvNormActivation( + in_channels=num_channels, + out_channels=48, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.branch5x5_2 = ConvNormActivation( + in_channels=48, + out_channels=64, + kernel_size=5, padding=2, - act="relu") - - self.branch3x3dbl_1 = ConvBNLayer( - num_channels=num_channels, - num_filters=64, - filter_size=1, - act="relu") - self.branch3x3dbl_2 = ConvBNLayer( - num_channels=64, - num_filters=96, - filter_size=3, + activation_layer=nn.ReLU) + + self.branch3x3dbl_1 = ConvNormActivation( + in_channels=num_channels, + out_channels=64, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.branch3x3dbl_2 = ConvNormActivation( + in_channels=64, + out_channels=96, + kernel_size=3, padding=1, - act="relu") - self.branch3x3dbl_3 = ConvBNLayer( - num_channels=96, - num_filters=96, - filter_size=3, + activation_layer=nn.ReLU) + self.branch3x3dbl_3 = ConvNormActivation( + in_channels=96, + out_channels=96, + kernel_size=3, padding=1, - act="relu") + activation_layer=nn.ReLU) + self.branch_pool = AvgPool2D( kernel_size=3, stride=1, padding=1, exclusive=False) - self.branch_pool_conv = ConvBNLayer( - num_channels=num_channels, - num_filters=pool_features, - filter_size=1, - act="relu") + self.branch_pool_conv = ConvNormActivation( + in_channels=num_channels, + out_channels=pool_features, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) def forward(self, x): branch1x1 = self.branch1x1(x) @@ -164,29 +155,34 @@ class InceptionA(nn.Layer): class InceptionB(nn.Layer): def __init__(self, num_channels): super().__init__() - self.branch3x3 = ConvBNLayer( - num_channels=num_channels, - num_filters=384, - filter_size=3, + self.branch3x3 = ConvNormActivation( + in_channels=num_channels, + out_channels=384, + kernel_size=3, stride=2, - act="relu") - self.branch3x3dbl_1 = ConvBNLayer( - num_channels=num_channels, - num_filters=64, - filter_size=1, - act="relu") - self.branch3x3dbl_2 = ConvBNLayer( - num_channels=64, - num_filters=96, - filter_size=3, + padding=0, + activation_layer=nn.ReLU) + + self.branch3x3dbl_1 = ConvNormActivation( + in_channels=num_channels, + out_channels=64, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.branch3x3dbl_2 = ConvNormActivation( + in_channels=64, + out_channels=96, + kernel_size=3, padding=1, - act="relu") - self.branch3x3dbl_3 = ConvBNLayer( - num_channels=96, - num_filters=96, - filter_size=3, + activation_layer=nn.ReLU) + self.branch3x3dbl_3 = ConvNormActivation( + in_channels=96, + out_channels=96, + kernel_size=3, stride=2, - act="relu") + padding=0, + activation_layer=nn.ReLU) + self.branch_pool = MaxPool2D(kernel_size=3, stride=2) def forward(self, x): @@ -206,70 +202,74 @@ class InceptionB(nn.Layer): class InceptionC(nn.Layer): def __init__(self, num_channels, channels_7x7): super().__init__() - self.branch1x1 = ConvBNLayer( - num_channels=num_channels, - num_filters=192, - filter_size=1, - act="relu") - - self.branch7x7_1 = ConvBNLayer( - num_channels=num_channels, - num_filters=channels_7x7, - filter_size=1, + self.branch1x1 = ConvNormActivation( + in_channels=num_channels, + out_channels=192, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + + self.branch7x7_1 = ConvNormActivation( + in_channels=num_channels, + out_channels=channels_7x7, + kernel_size=1, stride=1, - act="relu") - self.branch7x7_2 = ConvBNLayer( - num_channels=channels_7x7, - num_filters=channels_7x7, - filter_size=(1, 7), + padding=0, + activation_layer=nn.ReLU) + self.branch7x7_2 = ConvNormActivation( + in_channels=channels_7x7, + out_channels=channels_7x7, + kernel_size=(1, 7), stride=1, padding=(0, 3), - act="relu") - self.branch7x7_3 = ConvBNLayer( - num_channels=channels_7x7, - num_filters=192, - filter_size=(7, 1), + activation_layer=nn.ReLU) + self.branch7x7_3 = ConvNormActivation( + in_channels=channels_7x7, + out_channels=192, + kernel_size=(7, 1), stride=1, padding=(3, 0), - act="relu") - - self.branch7x7dbl_1 = ConvBNLayer( - num_channels=num_channels, - num_filters=channels_7x7, - filter_size=1, - act="relu") - self.branch7x7dbl_2 = ConvBNLayer( - num_channels=channels_7x7, - num_filters=channels_7x7, - filter_size=(7, 1), + activation_layer=nn.ReLU) + + self.branch7x7dbl_1 = ConvNormActivation( + in_channels=num_channels, + out_channels=channels_7x7, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.branch7x7dbl_2 = ConvNormActivation( + in_channels=channels_7x7, + out_channels=channels_7x7, + kernel_size=(7, 1), padding=(3, 0), - act="relu") - self.branch7x7dbl_3 = ConvBNLayer( - num_channels=channels_7x7, - num_filters=channels_7x7, - filter_size=(1, 7), + activation_layer=nn.ReLU) + self.branch7x7dbl_3 = ConvNormActivation( + in_channels=channels_7x7, + out_channels=channels_7x7, + kernel_size=(1, 7), padding=(0, 3), - act="relu") - self.branch7x7dbl_4 = ConvBNLayer( - num_channels=channels_7x7, - num_filters=channels_7x7, - filter_size=(7, 1), + activation_layer=nn.ReLU) + self.branch7x7dbl_4 = ConvNormActivation( + in_channels=channels_7x7, + out_channels=channels_7x7, + kernel_size=(7, 1), padding=(3, 0), - act="relu") - self.branch7x7dbl_5 = ConvBNLayer( - num_channels=channels_7x7, - num_filters=192, - filter_size=(1, 7), + activation_layer=nn.ReLU) + self.branch7x7dbl_5 = ConvNormActivation( + in_channels=channels_7x7, + out_channels=192, + kernel_size=(1, 7), padding=(0, 3), - act="relu") + activation_layer=nn.ReLU) self.branch_pool = AvgPool2D( kernel_size=3, stride=1, padding=1, exclusive=False) - self.branch_pool_conv = ConvBNLayer( - num_channels=num_channels, - num_filters=192, - filter_size=1, - act="relu") + self.branch_pool_conv = ConvNormActivation( + in_channels=num_channels, + out_channels=192, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) def forward(self, x): branch1x1 = self.branch1x1(x) @@ -296,40 +296,46 @@ class InceptionC(nn.Layer): class InceptionD(nn.Layer): def __init__(self, num_channels): super().__init__() - self.branch3x3_1 = ConvBNLayer( - num_channels=num_channels, - num_filters=192, - filter_size=1, - act="relu") - self.branch3x3_2 = ConvBNLayer( - num_channels=192, - num_filters=320, - filter_size=3, + self.branch3x3_1 = ConvNormActivation( + in_channels=num_channels, + out_channels=192, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.branch3x3_2 = ConvNormActivation( + in_channels=192, + out_channels=320, + kernel_size=3, stride=2, - act="relu") - self.branch7x7x3_1 = ConvBNLayer( - num_channels=num_channels, - num_filters=192, - filter_size=1, - act="relu") - self.branch7x7x3_2 = ConvBNLayer( - num_channels=192, - num_filters=192, - filter_size=(1, 7), + padding=0, + activation_layer=nn.ReLU) + + self.branch7x7x3_1 = ConvNormActivation( + in_channels=num_channels, + out_channels=192, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.branch7x7x3_2 = ConvNormActivation( + in_channels=192, + out_channels=192, + kernel_size=(1, 7), padding=(0, 3), - act="relu") - self.branch7x7x3_3 = ConvBNLayer( - num_channels=192, - num_filters=192, - filter_size=(7, 1), + activation_layer=nn.ReLU) + self.branch7x7x3_3 = ConvNormActivation( + in_channels=192, + out_channels=192, + kernel_size=(7, 1), padding=(3, 0), - act="relu") - self.branch7x7x3_4 = ConvBNLayer( - num_channels=192, - num_filters=192, - filter_size=3, + activation_layer=nn.ReLU) + self.branch7x7x3_4 = ConvNormActivation( + in_channels=192, + out_channels=192, + kernel_size=3, stride=2, - act="relu") + padding=0, + activation_layer=nn.ReLU) + self.branch_pool = MaxPool2D(kernel_size=3, stride=2) def forward(self, x): @@ -350,59 +356,64 @@ class InceptionD(nn.Layer): class InceptionE(nn.Layer): def __init__(self, num_channels): super().__init__() - self.branch1x1 = ConvBNLayer( - num_channels=num_channels, - num_filters=320, - filter_size=1, - act="relu") - self.branch3x3_1 = ConvBNLayer( - num_channels=num_channels, - num_filters=384, - filter_size=1, - act="relu") - self.branch3x3_2a = ConvBNLayer( - num_channels=384, - num_filters=384, - filter_size=(1, 3), + self.branch1x1 = ConvNormActivation( + in_channels=num_channels, + out_channels=320, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.branch3x3_1 = ConvNormActivation( + in_channels=num_channels, + out_channels=384, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.branch3x3_2a = ConvNormActivation( + in_channels=384, + out_channels=384, + kernel_size=(1, 3), padding=(0, 1), - act="relu") - self.branch3x3_2b = ConvBNLayer( - num_channels=384, - num_filters=384, - filter_size=(3, 1), + activation_layer=nn.ReLU) + self.branch3x3_2b = ConvNormActivation( + in_channels=384, + out_channels=384, + kernel_size=(3, 1), padding=(1, 0), - act="relu") - - self.branch3x3dbl_1 = ConvBNLayer( - num_channels=num_channels, - num_filters=448, - filter_size=1, - act="relu") - self.branch3x3dbl_2 = ConvBNLayer( - num_channels=448, - num_filters=384, - filter_size=3, + activation_layer=nn.ReLU) + + self.branch3x3dbl_1 = ConvNormActivation( + in_channels=num_channels, + out_channels=448, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.branch3x3dbl_2 = ConvNormActivation( + in_channels=448, + out_channels=384, + kernel_size=3, padding=1, - act="relu") - self.branch3x3dbl_3a = ConvBNLayer( - num_channels=384, - num_filters=384, - filter_size=(1, 3), + activation_layer=nn.ReLU) + self.branch3x3dbl_3a = ConvNormActivation( + in_channels=384, + out_channels=384, + kernel_size=(1, 3), padding=(0, 1), - act="relu") - self.branch3x3dbl_3b = ConvBNLayer( - num_channels=384, - num_filters=384, - filter_size=(3, 1), + activation_layer=nn.ReLU) + self.branch3x3dbl_3b = ConvNormActivation( + in_channels=384, + out_channels=384, + kernel_size=(3, 1), padding=(1, 0), - act="relu") + activation_layer=nn.ReLU) + self.branch_pool = AvgPool2D( kernel_size=3, stride=1, padding=1, exclusive=False) - self.branch_pool_conv = ConvBNLayer( - num_channels=num_channels, - num_filters=192, - filter_size=1, - act="relu") + self.branch_pool_conv = ConvNormActivation( + in_channels=num_channels, + out_channels=192, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) def forward(self, x): branch1x1 = self.branch1x1(x) diff --git a/python/paddle/vision/models/mobilenetv1.py b/python/paddle/vision/models/mobilenetv1.py index 671a2cd8df..6d8d96952f 100644 --- a/python/paddle/vision/models/mobilenetv1.py +++ b/python/paddle/vision/models/mobilenetv1.py @@ -16,59 +16,31 @@ import paddle import paddle.nn as nn from paddle.utils.download import get_weights_path_from_url +from ..ops import ConvNormActivation __all__ = [] model_urls = { 'mobilenetv1_1.0': - ('https://paddle-hapi.bj.bcebos.com/models/mobilenet_v1_x1.0.pdparams', - '42a154c2f26f86e7457d6daded114e8c') + ('https://paddle-hapi.bj.bcebos.com/models/mobilenetv1_1.0.pdparams', + '3033ab1975b1670bef51545feb65fc45') } -class ConvBNLayer(nn.Layer): - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - num_groups=1): - super(ConvBNLayer, self).__init__() - - self._conv = nn.Conv2D( - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=padding, - groups=num_groups, - bias_attr=False) - - self._norm_layer = nn.BatchNorm2D(out_channels) - self._act = nn.ReLU() - - def forward(self, x): - x = self._conv(x) - x = self._norm_layer(x) - x = self._act(x) - return x - - class DepthwiseSeparable(nn.Layer): def __init__(self, in_channels, out_channels1, out_channels2, num_groups, stride, scale): super(DepthwiseSeparable, self).__init__() - self._depthwise_conv = ConvBNLayer( + self._depthwise_conv = ConvNormActivation( in_channels, int(out_channels1 * scale), kernel_size=3, stride=stride, padding=1, - num_groups=int(num_groups * scale)) + groups=int(num_groups * scale)) - self._pointwise_conv = ConvBNLayer( + self._pointwise_conv = ConvNormActivation( int(out_channels1 * scale), int(out_channels2 * scale), kernel_size=1, @@ -94,9 +66,15 @@ class MobileNetV1(nn.Layer): Examples: .. code-block:: python + import paddle from paddle.vision.models import MobileNetV1 model = MobileNetV1() + + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + + print(out.shape) """ def __init__(self, scale=1.0, num_classes=1000, with_pool=True): @@ -106,7 +84,7 @@ class MobileNetV1(nn.Layer): self.num_classes = num_classes self.with_pool = with_pool - self.conv1 = ConvBNLayer( + self.conv1 = ConvNormActivation( in_channels=3, out_channels=int(32 * scale), kernel_size=3, @@ -257,6 +235,7 @@ def mobilenet_v1(pretrained=False, scale=1.0, **kwargs): Examples: .. code-block:: python + import paddle from paddle.vision.models import mobilenet_v1 # build model @@ -266,7 +245,12 @@ def mobilenet_v1(pretrained=False, scale=1.0, **kwargs): # model = mobilenet_v1(pretrained=True) # build mobilenet v1 with scale=0.5 - model = mobilenet_v1(scale=0.5) + model_scale = mobilenet_v1(scale=0.5) + + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + + print(out.shape) """ model = _mobilenet( 'mobilenetv1_' + str(scale), pretrained, scale=scale, **kwargs) diff --git a/python/paddle/vision/models/mobilenetv2.py b/python/paddle/vision/models/mobilenetv2.py index 6c486037c7..9791462610 100644 --- a/python/paddle/vision/models/mobilenetv2.py +++ b/python/paddle/vision/models/mobilenetv2.py @@ -17,6 +17,7 @@ import paddle.nn as nn from paddle.utils.download import get_weights_path_from_url from .utils import _make_divisible +from ..ops import ConvNormActivation __all__ = [] @@ -27,29 +28,6 @@ model_urls = { } -class ConvBNReLU(nn.Sequential): - def __init__(self, - in_planes, - out_planes, - kernel_size=3, - stride=1, - groups=1, - norm_layer=nn.BatchNorm2D): - padding = (kernel_size - 1) // 2 - - super(ConvBNReLU, self).__init__( - nn.Conv2D( - in_planes, - out_planes, - kernel_size, - stride, - padding, - groups=groups, - bias_attr=False), - norm_layer(out_planes), - nn.ReLU6()) - - class InvertedResidual(nn.Layer): def __init__(self, inp, @@ -67,15 +45,20 @@ class InvertedResidual(nn.Layer): layers = [] if expand_ratio != 1: layers.append( - ConvBNReLU( - inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) + ConvNormActivation( + inp, + hidden_dim, + kernel_size=1, + norm_layer=norm_layer, + activation_layer=nn.ReLU6)) layers.extend([ - ConvBNReLU( + ConvNormActivation( hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, - norm_layer=norm_layer), + norm_layer=norm_layer, + activation_layer=nn.ReLU6), nn.Conv2D( hidden_dim, oup, 1, 1, 0, bias_attr=False), norm_layer(oup), @@ -90,23 +73,30 @@ class InvertedResidual(nn.Layer): class MobileNetV2(nn.Layer): - def __init__(self, scale=1.0, num_classes=1000, with_pool=True): - """MobileNetV2 model from - `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. + """MobileNetV2 model from + `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. + + Args: + scale (float): scale of channels in each layer. Default: 1.0. + num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer + will not be defined. Default: 1000. + with_pool (bool): use pool before the last fc layer or not. Default: True. - Args: - scale (float): scale of channels in each layer. Default: 1.0. - num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer - will not be defined. Default: 1000. - with_pool (bool): use pool before the last fc layer or not. Default: True. + Examples: + .. code-block:: python + + import paddle + from paddle.vision.models import MobileNetV2 - Examples: - .. code-block:: python + model = MobileNetV2() - from paddle.vision.models import MobileNetV2 + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + + print(out.shape) + """ - model = MobileNetV2() - """ + def __init__(self, scale=1.0, num_classes=1000, with_pool=True): super(MobileNetV2, self).__init__() self.num_classes = num_classes self.with_pool = with_pool @@ -130,8 +120,12 @@ class MobileNetV2(nn.Layer): self.last_channel = _make_divisible(last_channel * max(1.0, scale), round_nearest) features = [ - ConvBNReLU( - 3, input_channel, stride=2, norm_layer=norm_layer) + ConvNormActivation( + 3, + input_channel, + stride=2, + norm_layer=norm_layer, + activation_layer=nn.ReLU6) ] for t, c, n, s in inverted_residual_setting: @@ -148,11 +142,12 @@ class MobileNetV2(nn.Layer): input_channel = output_channel features.append( - ConvBNReLU( + ConvNormActivation( input_channel, self.last_channel, kernel_size=1, - norm_layer=norm_layer)) + norm_layer=norm_layer, + activation_layer=nn.ReLU6)) self.features = nn.Sequential(*features) @@ -199,6 +194,7 @@ def mobilenet_v2(pretrained=False, scale=1.0, **kwargs): Examples: .. code-block:: python + import paddle from paddle.vision.models import mobilenet_v2 # build model @@ -209,6 +205,11 @@ def mobilenet_v2(pretrained=False, scale=1.0, **kwargs): # build mobilenet v2 with scale=0.5 model = mobilenet_v2(scale=0.5) + + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + + print(out.shape) """ model = _mobilenet( 'mobilenetv2_' + str(scale), pretrained, scale=scale, **kwargs) diff --git a/python/paddle/vision/models/shufflenetv2.py b/python/paddle/vision/models/shufflenetv2.py index 041f3fc749..90e967ee22 100644 --- a/python/paddle/vision/models/shufflenetv2.py +++ b/python/paddle/vision/models/shufflenetv2.py @@ -18,37 +18,50 @@ from __future__ import print_function import paddle import paddle.nn as nn -from paddle.fluid.param_attr import ParamAttr -from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Linear, MaxPool2D +from paddle.nn import AdaptiveAvgPool2D, Linear, MaxPool2D from paddle.utils.download import get_weights_path_from_url +from ..ops import ConvNormActivation + __all__ = [] model_urls = { "shufflenet_v2_x0_25": ( - "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x0_25_pretrained.pdparams", - "e753404cbd95027759c5f56ecd6c9c4b", ), + "https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_x0_25.pdparams", + "1e509b4c140eeb096bb16e214796d03b", ), "shufflenet_v2_x0_33": ( - "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x0_33_pretrained.pdparams", - "776e3cf9a4923abdfce789c45b8fe1f2", ), + "https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_x0_33.pdparams", + "3d7b3ab0eaa5c0927ff1026d31b729bd", ), "shufflenet_v2_x0_5": ( - "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x0_5_pretrained.pdparams", - "e3649cf531566917e2969487d2bc6b60", ), + "https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_x0_5.pdparams", + "5e5cee182a7793c4e4c73949b1a71bd4", ), "shufflenet_v2_x1_0": ( - "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x1_0_pretrained.pdparams", - "7821c348ea34e58847c43a08a4ac0bdf", ), + "https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_x1_0.pdparams", + "122d42478b9e81eb49f8a9ede327b1a4", ), "shufflenet_v2_x1_5": ( - "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x1_5_pretrained.pdparams", - "93a07fa557ab2d8803550f39e5b6c391", ), + "https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_x1_5.pdparams", + "faced5827380d73531d0ee027c67826d", ), "shufflenet_v2_x2_0": ( - "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x2_0_pretrained.pdparams", - "4ab1f622fd0d341e0f84b4e057797563", ), + "https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_x2_0.pdparams", + "cd3dddcd8305e7bcd8ad14d1c69a5784", ), "shufflenet_v2_swish": ( - "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_swish_pretrained.pdparams", - "daff38b3df1b3748fccbb13cfdf02519", ), + "https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_swish.pdparams", + "adde0aa3b023e5b0c94a68be1c394b84", ), } +def create_activation_layer(act): + if act == "swish": + return nn.Swish + elif act == "relu": + return nn.ReLU + elif act is None: + return None + else: + raise RuntimeError( + "The activation function is not supported: {}".format(act)) + + def channel_shuffle(x, groups): batch_size, num_channels, height, width = x.shape[0:4] channels_per_group = num_channels // groups @@ -65,61 +78,37 @@ def channel_shuffle(x, groups): return x -class ConvBNLayer(nn.Layer): +class InvertedResidual(nn.Layer): def __init__(self, in_channels, out_channels, - kernel_size, stride, - padding, - groups=1, - act=None): - super(ConvBNLayer, self).__init__() - self._conv = Conv2D( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - groups=groups, - weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()), - bias_attr=False, ) - - self._batch_norm = BatchNorm(out_channels, act=act) - - def forward(self, inputs): - x = self._conv(inputs) - x = self._batch_norm(x) - return x - - -class InvertedResidual(nn.Layer): - def __init__(self, in_channels, out_channels, stride, act="relu"): + activation_layer=nn.ReLU): super(InvertedResidual, self).__init__() - self._conv_pw = ConvBNLayer( + self._conv_pw = ConvNormActivation( in_channels=in_channels // 2, out_channels=out_channels // 2, kernel_size=1, stride=1, padding=0, groups=1, - act=act) - self._conv_dw = ConvBNLayer( + activation_layer=activation_layer) + self._conv_dw = ConvNormActivation( in_channels=out_channels // 2, out_channels=out_channels // 2, kernel_size=3, stride=stride, padding=1, groups=out_channels // 2, - act=None) - self._conv_linear = ConvBNLayer( + activation_layer=None) + self._conv_linear = ConvNormActivation( in_channels=out_channels // 2, out_channels=out_channels // 2, kernel_size=1, stride=1, padding=0, groups=1, - act=act) + activation_layer=activation_layer) def forward(self, inputs): x1, x2 = paddle.split( @@ -134,51 +123,55 @@ class InvertedResidual(nn.Layer): class InvertedResidualDS(nn.Layer): - def __init__(self, in_channels, out_channels, stride, act="relu"): + def __init__(self, + in_channels, + out_channels, + stride, + activation_layer=nn.ReLU): super(InvertedResidualDS, self).__init__() # branch1 - self._conv_dw_1 = ConvBNLayer( + self._conv_dw_1 = ConvNormActivation( in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels, - act=None) - self._conv_linear_1 = ConvBNLayer( + activation_layer=None) + self._conv_linear_1 = ConvNormActivation( in_channels=in_channels, out_channels=out_channels // 2, kernel_size=1, stride=1, padding=0, groups=1, - act=act) + activation_layer=activation_layer) # branch2 - self._conv_pw_2 = ConvBNLayer( + self._conv_pw_2 = ConvNormActivation( in_channels=in_channels, out_channels=out_channels // 2, kernel_size=1, stride=1, padding=0, groups=1, - act=act) - self._conv_dw_2 = ConvBNLayer( + activation_layer=activation_layer) + self._conv_dw_2 = ConvNormActivation( in_channels=out_channels // 2, out_channels=out_channels // 2, kernel_size=3, stride=stride, padding=1, groups=out_channels // 2, - act=None) - self._conv_linear_2 = ConvBNLayer( + activation_layer=None) + self._conv_linear_2 = ConvNormActivation( in_channels=out_channels // 2, out_channels=out_channels // 2, kernel_size=1, stride=1, padding=0, groups=1, - act=act) + activation_layer=activation_layer) def forward(self, inputs): x1 = self._conv_dw_1(inputs) @@ -221,6 +214,7 @@ class ShuffleNetV2(nn.Layer): self.num_classes = num_classes self.with_pool = with_pool stage_repeats = [4, 8, 4] + activation_layer = create_activation_layer(act) if scale == 0.25: stage_out_channels = [-1, 24, 24, 48, 96, 512] @@ -238,13 +232,13 @@ class ShuffleNetV2(nn.Layer): raise NotImplementedError("This scale size:[" + str(scale) + "] is not implemented!") # 1. conv1 - self._conv1 = ConvBNLayer( + self._conv1 = ConvNormActivation( in_channels=3, out_channels=stage_out_channels[1], kernel_size=3, stride=2, padding=1, - act=act) + activation_layer=activation_layer) self._max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1) # 2. bottleneck sequences @@ -257,7 +251,7 @@ class ShuffleNetV2(nn.Layer): in_channels=stage_out_channels[stage_id + 1], out_channels=stage_out_channels[stage_id + 2], stride=2, - act=act), + activation_layer=activation_layer), name=str(stage_id + 2) + "_" + str(i + 1)) else: block = self.add_sublayer( @@ -265,17 +259,17 @@ class ShuffleNetV2(nn.Layer): in_channels=stage_out_channels[stage_id + 2], out_channels=stage_out_channels[stage_id + 2], stride=1, - act=act), + activation_layer=activation_layer), name=str(stage_id + 2) + "_" + str(i + 1)) self._block_list.append(block) # 3. last_conv - self._last_conv = ConvBNLayer( + self._last_conv = ConvNormActivation( in_channels=stage_out_channels[-2], out_channels=stage_out_channels[-1], kernel_size=1, stride=1, padding=0, - act=act) + activation_layer=activation_layer) # 4. pool if with_pool: self._pool2d_avg = AdaptiveAvgPool2D(1) diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index 2d60fd4561..e4dd4c797f 100644 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -1335,13 +1335,13 @@ class ConvNormActivation(Sequential): Args: in_channels (int): Number of channels in the input image out_channels (int): Number of channels produced by the Convolution-Normalzation-Activation block - kernel_size: (int, optional): Size of the convolving kernel. Default: 3 - stride (int, optional): Stride of the convolution. Default: 1 - padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, + kernel_size: (int|list|tuple, optional): Size of the convolving kernel. Default: 3 + stride (int|list|tuple, optional): Stride of the convolution. Default: 1 + padding (int|str|tuple|list, optional): Padding added to all four sides of the input. Default: None, in wich case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation`` groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 norm_layer (Callable[..., paddle.nn.Layer], optional): Norm layer that will be stacked on top of the convolutiuon layer. - If ``None`` this layer wont be used. Default: ``paddle.nn.BatchNorm2d`` + If ``None`` this layer wont be used. Default: ``paddle.nn.BatchNorm2D`` activation_layer (Callable[..., paddle.nn.Layer], optional): Activation function which will be stacked on top of the normalization layer (if not ``None``), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``paddle.nn.ReLU`` dilation (int): Spacing between kernel elements. Default: 1 -- GitLab