From 40ac92b2d4b72ac61dcae2278e2d119e9ccab284 Mon Sep 17 00:00:00 2001 From: Nyakku Shigure Date: Tue, 26 Apr 2022 17:05:39 +0800 Subject: [PATCH] [cherry-pick] refactor vision models (#42252) * reuse ConvNormActivation in some vision models (#40431) * reuse ConvNormActivation in some vision models * reimplement ResNeXt based on ResNet (#40588) * refactor resnext --- python/paddle/vision/__init__.py | 13 +- python/paddle/vision/models/__init__.py | 26 +- python/paddle/vision/models/inceptionv3.py | 477 ++++++++++---------- python/paddle/vision/models/mobilenetv1.py | 56 +-- python/paddle/vision/models/mobilenetv2.py | 89 ++-- python/paddle/vision/models/resnet.py | 258 ++++++++++- python/paddle/vision/models/resnext.py | 364 --------------- python/paddle/vision/models/shufflenetv2.py | 124 +++-- python/paddle/vision/ops.py | 8 +- 9 files changed, 628 insertions(+), 787 deletions(-) delete mode 100644 python/paddle/vision/models/resnext.py diff --git a/python/paddle/vision/__init__.py b/python/paddle/vision/__init__.py index 3749e0f64f..2f0052537e 100644 --- a/python/paddle/vision/__init__.py +++ b/python/paddle/vision/__init__.py @@ -34,6 +34,12 @@ from .models import resnet34 # noqa: F401 from .models import resnet50 # noqa: F401 from .models import resnet101 # noqa: F401 from .models import resnet152 # noqa: F401 +from .models import resnext50_32x4d # noqa: F401 +from .models import resnext50_64x4d # noqa: F401 +from .models import resnext101_32x4d # noqa: F401 +from .models import resnext101_64x4d # noqa: F401 +from .models import resnext152_32x4d # noqa: F401 +from .models import resnext152_64x4d # noqa: F401 from .models import wide_resnet50_2 # noqa: F401 from .models import wide_resnet101_2 # noqa: F401 from .models import MobileNetV1 # noqa: F401 @@ -61,13 +67,6 @@ from .models import densenet201 # noqa: F401 from .models import densenet264 # noqa: F401 from .models import AlexNet # noqa: F401 from .models import alexnet # noqa: F401 -from .models import ResNeXt # noqa: F401 -from .models import resnext50_32x4d # noqa: F401 -from .models import resnext50_64x4d # noqa: F401 -from .models import resnext101_32x4d # noqa: F401 -from .models import resnext101_64x4d # noqa: F401 -from .models import resnext152_32x4d # noqa: F401 -from .models import resnext152_64x4d # noqa: F401 from .models import InceptionV3 # noqa: F401 from .models import inception_v3 # noqa: F401 from .models import GoogLeNet # noqa: F401 diff --git a/python/paddle/vision/models/__init__.py b/python/paddle/vision/models/__init__.py index 5ff3562e56..85ff5f85df 100644 --- a/python/paddle/vision/models/__init__.py +++ b/python/paddle/vision/models/__init__.py @@ -18,6 +18,12 @@ from .resnet import resnet34 # noqa: F401 from .resnet import resnet50 # noqa: F401 from .resnet import resnet101 # noqa: F401 from .resnet import resnet152 # noqa: F401 +from .resnet import resnext50_32x4d # noqa: F401 +from .resnet import resnext50_64x4d # noqa: F401 +from .resnet import resnext101_32x4d # noqa: F401 +from .resnet import resnext101_64x4d # noqa: F401 +from .resnet import resnext152_32x4d # noqa: F401 +from .resnet import resnext152_64x4d # noqa: F401 from .resnet import wide_resnet50_2 # noqa: F401 from .resnet import wide_resnet101_2 # noqa: F401 from .mobilenetv1 import MobileNetV1 # noqa: F401 @@ -42,13 +48,6 @@ from .densenet import densenet201 # noqa: F401 from .densenet import densenet264 # noqa: F401 from .alexnet import AlexNet # noqa: F401 from .alexnet import alexnet # noqa: F401 -from .resnext import ResNeXt # 
noqa: F401 -from .resnext import resnext50_32x4d # noqa: F401 -from .resnext import resnext50_64x4d # noqa: F401 -from .resnext import resnext101_32x4d # noqa: F401 -from .resnext import resnext101_64x4d # noqa: F401 -from .resnext import resnext152_32x4d # noqa: F401 -from .resnext import resnext152_64x4d # noqa: F401 from .inceptionv3 import InceptionV3 # noqa: F401 from .inceptionv3 import inception_v3 # noqa: F401 from .squeezenet import SqueezeNet # noqa: F401 @@ -72,6 +71,12 @@ __all__ = [ #noqa 'resnet50', 'resnet101', 'resnet152', + 'resnext50_32x4d', + 'resnext50_64x4d', + 'resnext101_32x4d', + 'resnext101_64x4d', + 'resnext152_32x4d', + 'resnext152_64x4d', 'wide_resnet50_2', 'wide_resnet101_2', 'VGG', @@ -96,13 +101,6 @@ __all__ = [ #noqa 'densenet264', 'AlexNet', 'alexnet', - 'ResNeXt', - 'resnext50_32x4d', - 'resnext50_64x4d', - 'resnext101_32x4d', - 'resnext101_64x4d', - 'resnext152_32x4d', - 'resnext152_64x4d', 'InceptionV3', 'inception_v3', 'SqueezeNet', diff --git a/python/paddle/vision/models/inceptionv3.py b/python/paddle/vision/models/inceptionv3.py index 9e8a8b8146..27650dbe09 100644 --- a/python/paddle/vision/models/inceptionv3.py +++ b/python/paddle/vision/models/inceptionv3.py @@ -19,75 +19,60 @@ from __future__ import print_function import math import paddle import paddle.nn as nn -from paddle.nn import Conv2D, BatchNorm, Linear, Dropout +from paddle.nn import Linear, Dropout from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D from paddle.nn.initializer import Uniform from paddle.fluid.param_attr import ParamAttr from paddle.utils.download import get_weights_path_from_url +from ..ops import ConvNormActivation __all__ = [] model_urls = { "inception_v3": - ("https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/InceptionV3_pretrained.pdparams", - "e4d0905a818f6bb7946e881777a8a935") + ("https://paddle-hapi.bj.bcebos.com/models/inception_v3.pdparams", + "649a4547c3243e8b59c656f41fe330b8") } -class ConvBNLayer(nn.Layer): - def __init__(self, - num_channels, - num_filters, - filter_size, - stride=1, - padding=0, - groups=1, - act="relu"): - super().__init__() - self.act = act - self.conv = Conv2D( - in_channels=num_channels, - out_channels=num_filters, - kernel_size=filter_size, - stride=stride, - padding=padding, - groups=groups, - bias_attr=False) - self.bn = BatchNorm(num_filters) - self.relu = nn.ReLU() - - def forward(self, x): - x = self.conv(x) - x = self.bn(x) - if self.act: - x = self.relu(x) - return x - - class InceptionStem(nn.Layer): def __init__(self): super().__init__() - self.conv_1a_3x3 = ConvBNLayer( - num_channels=3, num_filters=32, filter_size=3, stride=2, act="relu") - self.conv_2a_3x3 = ConvBNLayer( - num_channels=32, - num_filters=32, - filter_size=3, + self.conv_1a_3x3 = ConvNormActivation( + in_channels=3, + out_channels=32, + kernel_size=3, + stride=2, + padding=0, + activation_layer=nn.ReLU) + self.conv_2a_3x3 = ConvNormActivation( + in_channels=32, + out_channels=32, + kernel_size=3, stride=1, - act="relu") - self.conv_2b_3x3 = ConvBNLayer( - num_channels=32, - num_filters=64, - filter_size=3, + padding=0, + activation_layer=nn.ReLU) + self.conv_2b_3x3 = ConvNormActivation( + in_channels=32, + out_channels=64, + kernel_size=3, padding=1, - act="relu") + activation_layer=nn.ReLU) self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=0) - self.conv_3b_1x1 = ConvBNLayer( - num_channels=64, num_filters=80, filter_size=1, act="relu") - self.conv_4a_3x3 = ConvBNLayer( - num_channels=80, num_filters=192, 
filter_size=3, act="relu") + self.conv_3b_1x1 = ConvNormActivation( + in_channels=64, + out_channels=80, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.conv_4a_3x3 = ConvNormActivation( + in_channels=80, + out_channels=192, + kernel_size=3, + padding=0, + activation_layer=nn.ReLU) def forward(self, x): x = self.conv_1a_3x3(x) @@ -103,47 +88,53 @@ class InceptionStem(nn.Layer): class InceptionA(nn.Layer): def __init__(self, num_channels, pool_features): super().__init__() - self.branch1x1 = ConvBNLayer( - num_channels=num_channels, - num_filters=64, - filter_size=1, - act="relu") - self.branch5x5_1 = ConvBNLayer( - num_channels=num_channels, - num_filters=48, - filter_size=1, - act="relu") - self.branch5x5_2 = ConvBNLayer( - num_channels=48, - num_filters=64, - filter_size=5, + self.branch1x1 = ConvNormActivation( + in_channels=num_channels, + out_channels=64, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + + self.branch5x5_1 = ConvNormActivation( + in_channels=num_channels, + out_channels=48, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.branch5x5_2 = ConvNormActivation( + in_channels=48, + out_channels=64, + kernel_size=5, padding=2, - act="relu") - - self.branch3x3dbl_1 = ConvBNLayer( - num_channels=num_channels, - num_filters=64, - filter_size=1, - act="relu") - self.branch3x3dbl_2 = ConvBNLayer( - num_channels=64, - num_filters=96, - filter_size=3, + activation_layer=nn.ReLU) + + self.branch3x3dbl_1 = ConvNormActivation( + in_channels=num_channels, + out_channels=64, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.branch3x3dbl_2 = ConvNormActivation( + in_channels=64, + out_channels=96, + kernel_size=3, padding=1, - act="relu") - self.branch3x3dbl_3 = ConvBNLayer( - num_channels=96, - num_filters=96, - filter_size=3, + activation_layer=nn.ReLU) + self.branch3x3dbl_3 = ConvNormActivation( + in_channels=96, + out_channels=96, + kernel_size=3, padding=1, - act="relu") + activation_layer=nn.ReLU) + self.branch_pool = AvgPool2D( kernel_size=3, stride=1, padding=1, exclusive=False) - self.branch_pool_conv = ConvBNLayer( - num_channels=num_channels, - num_filters=pool_features, - filter_size=1, - act="relu") + self.branch_pool_conv = ConvNormActivation( + in_channels=num_channels, + out_channels=pool_features, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) def forward(self, x): branch1x1 = self.branch1x1(x) @@ -164,29 +155,34 @@ class InceptionA(nn.Layer): class InceptionB(nn.Layer): def __init__(self, num_channels): super().__init__() - self.branch3x3 = ConvBNLayer( - num_channels=num_channels, - num_filters=384, - filter_size=3, + self.branch3x3 = ConvNormActivation( + in_channels=num_channels, + out_channels=384, + kernel_size=3, stride=2, - act="relu") - self.branch3x3dbl_1 = ConvBNLayer( - num_channels=num_channels, - num_filters=64, - filter_size=1, - act="relu") - self.branch3x3dbl_2 = ConvBNLayer( - num_channels=64, - num_filters=96, - filter_size=3, + padding=0, + activation_layer=nn.ReLU) + + self.branch3x3dbl_1 = ConvNormActivation( + in_channels=num_channels, + out_channels=64, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.branch3x3dbl_2 = ConvNormActivation( + in_channels=64, + out_channels=96, + kernel_size=3, padding=1, - act="relu") - self.branch3x3dbl_3 = ConvBNLayer( - num_channels=96, - num_filters=96, - filter_size=3, + activation_layer=nn.ReLU) + self.branch3x3dbl_3 = ConvNormActivation( + in_channels=96, + out_channels=96, + kernel_size=3, stride=2, - act="relu") + padding=0, 
+ activation_layer=nn.ReLU) + self.branch_pool = MaxPool2D(kernel_size=3, stride=2) def forward(self, x): @@ -206,70 +202,74 @@ class InceptionB(nn.Layer): class InceptionC(nn.Layer): def __init__(self, num_channels, channels_7x7): super().__init__() - self.branch1x1 = ConvBNLayer( - num_channels=num_channels, - num_filters=192, - filter_size=1, - act="relu") - - self.branch7x7_1 = ConvBNLayer( - num_channels=num_channels, - num_filters=channels_7x7, - filter_size=1, + self.branch1x1 = ConvNormActivation( + in_channels=num_channels, + out_channels=192, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + + self.branch7x7_1 = ConvNormActivation( + in_channels=num_channels, + out_channels=channels_7x7, + kernel_size=1, stride=1, - act="relu") - self.branch7x7_2 = ConvBNLayer( - num_channels=channels_7x7, - num_filters=channels_7x7, - filter_size=(1, 7), + padding=0, + activation_layer=nn.ReLU) + self.branch7x7_2 = ConvNormActivation( + in_channels=channels_7x7, + out_channels=channels_7x7, + kernel_size=(1, 7), stride=1, padding=(0, 3), - act="relu") - self.branch7x7_3 = ConvBNLayer( - num_channels=channels_7x7, - num_filters=192, - filter_size=(7, 1), + activation_layer=nn.ReLU) + self.branch7x7_3 = ConvNormActivation( + in_channels=channels_7x7, + out_channels=192, + kernel_size=(7, 1), stride=1, padding=(3, 0), - act="relu") - - self.branch7x7dbl_1 = ConvBNLayer( - num_channels=num_channels, - num_filters=channels_7x7, - filter_size=1, - act="relu") - self.branch7x7dbl_2 = ConvBNLayer( - num_channels=channels_7x7, - num_filters=channels_7x7, - filter_size=(7, 1), + activation_layer=nn.ReLU) + + self.branch7x7dbl_1 = ConvNormActivation( + in_channels=num_channels, + out_channels=channels_7x7, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.branch7x7dbl_2 = ConvNormActivation( + in_channels=channels_7x7, + out_channels=channels_7x7, + kernel_size=(7, 1), padding=(3, 0), - act="relu") - self.branch7x7dbl_3 = ConvBNLayer( - num_channels=channels_7x7, - num_filters=channels_7x7, - filter_size=(1, 7), + activation_layer=nn.ReLU) + self.branch7x7dbl_3 = ConvNormActivation( + in_channels=channels_7x7, + out_channels=channels_7x7, + kernel_size=(1, 7), padding=(0, 3), - act="relu") - self.branch7x7dbl_4 = ConvBNLayer( - num_channels=channels_7x7, - num_filters=channels_7x7, - filter_size=(7, 1), + activation_layer=nn.ReLU) + self.branch7x7dbl_4 = ConvNormActivation( + in_channels=channels_7x7, + out_channels=channels_7x7, + kernel_size=(7, 1), padding=(3, 0), - act="relu") - self.branch7x7dbl_5 = ConvBNLayer( - num_channels=channels_7x7, - num_filters=192, - filter_size=(1, 7), + activation_layer=nn.ReLU) + self.branch7x7dbl_5 = ConvNormActivation( + in_channels=channels_7x7, + out_channels=192, + kernel_size=(1, 7), padding=(0, 3), - act="relu") + activation_layer=nn.ReLU) self.branch_pool = AvgPool2D( kernel_size=3, stride=1, padding=1, exclusive=False) - self.branch_pool_conv = ConvBNLayer( - num_channels=num_channels, - num_filters=192, - filter_size=1, - act="relu") + self.branch_pool_conv = ConvNormActivation( + in_channels=num_channels, + out_channels=192, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) def forward(self, x): branch1x1 = self.branch1x1(x) @@ -296,40 +296,46 @@ class InceptionC(nn.Layer): class InceptionD(nn.Layer): def __init__(self, num_channels): super().__init__() - self.branch3x3_1 = ConvBNLayer( - num_channels=num_channels, - num_filters=192, - filter_size=1, - act="relu") - self.branch3x3_2 = ConvBNLayer( - num_channels=192, - 
num_filters=320, - filter_size=3, + self.branch3x3_1 = ConvNormActivation( + in_channels=num_channels, + out_channels=192, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.branch3x3_2 = ConvNormActivation( + in_channels=192, + out_channels=320, + kernel_size=3, stride=2, - act="relu") - self.branch7x7x3_1 = ConvBNLayer( - num_channels=num_channels, - num_filters=192, - filter_size=1, - act="relu") - self.branch7x7x3_2 = ConvBNLayer( - num_channels=192, - num_filters=192, - filter_size=(1, 7), + padding=0, + activation_layer=nn.ReLU) + + self.branch7x7x3_1 = ConvNormActivation( + in_channels=num_channels, + out_channels=192, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.branch7x7x3_2 = ConvNormActivation( + in_channels=192, + out_channels=192, + kernel_size=(1, 7), padding=(0, 3), - act="relu") - self.branch7x7x3_3 = ConvBNLayer( - num_channels=192, - num_filters=192, - filter_size=(7, 1), + activation_layer=nn.ReLU) + self.branch7x7x3_3 = ConvNormActivation( + in_channels=192, + out_channels=192, + kernel_size=(7, 1), padding=(3, 0), - act="relu") - self.branch7x7x3_4 = ConvBNLayer( - num_channels=192, - num_filters=192, - filter_size=3, + activation_layer=nn.ReLU) + self.branch7x7x3_4 = ConvNormActivation( + in_channels=192, + out_channels=192, + kernel_size=3, stride=2, - act="relu") + padding=0, + activation_layer=nn.ReLU) + self.branch_pool = MaxPool2D(kernel_size=3, stride=2) def forward(self, x): @@ -350,59 +356,64 @@ class InceptionD(nn.Layer): class InceptionE(nn.Layer): def __init__(self, num_channels): super().__init__() - self.branch1x1 = ConvBNLayer( - num_channels=num_channels, - num_filters=320, - filter_size=1, - act="relu") - self.branch3x3_1 = ConvBNLayer( - num_channels=num_channels, - num_filters=384, - filter_size=1, - act="relu") - self.branch3x3_2a = ConvBNLayer( - num_channels=384, - num_filters=384, - filter_size=(1, 3), + self.branch1x1 = ConvNormActivation( + in_channels=num_channels, + out_channels=320, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.branch3x3_1 = ConvNormActivation( + in_channels=num_channels, + out_channels=384, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.branch3x3_2a = ConvNormActivation( + in_channels=384, + out_channels=384, + kernel_size=(1, 3), padding=(0, 1), - act="relu") - self.branch3x3_2b = ConvBNLayer( - num_channels=384, - num_filters=384, - filter_size=(3, 1), + activation_layer=nn.ReLU) + self.branch3x3_2b = ConvNormActivation( + in_channels=384, + out_channels=384, + kernel_size=(3, 1), padding=(1, 0), - act="relu") - - self.branch3x3dbl_1 = ConvBNLayer( - num_channels=num_channels, - num_filters=448, - filter_size=1, - act="relu") - self.branch3x3dbl_2 = ConvBNLayer( - num_channels=448, - num_filters=384, - filter_size=3, + activation_layer=nn.ReLU) + + self.branch3x3dbl_1 = ConvNormActivation( + in_channels=num_channels, + out_channels=448, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) + self.branch3x3dbl_2 = ConvNormActivation( + in_channels=448, + out_channels=384, + kernel_size=3, padding=1, - act="relu") - self.branch3x3dbl_3a = ConvBNLayer( - num_channels=384, - num_filters=384, - filter_size=(1, 3), + activation_layer=nn.ReLU) + self.branch3x3dbl_3a = ConvNormActivation( + in_channels=384, + out_channels=384, + kernel_size=(1, 3), padding=(0, 1), - act="relu") - self.branch3x3dbl_3b = ConvBNLayer( - num_channels=384, - num_filters=384, - filter_size=(3, 1), + activation_layer=nn.ReLU) + self.branch3x3dbl_3b = ConvNormActivation( + 
in_channels=384, + out_channels=384, + kernel_size=(3, 1), padding=(1, 0), - act="relu") + activation_layer=nn.ReLU) + self.branch_pool = AvgPool2D( kernel_size=3, stride=1, padding=1, exclusive=False) - self.branch_pool_conv = ConvBNLayer( - num_channels=num_channels, - num_filters=192, - filter_size=1, - act="relu") + self.branch_pool_conv = ConvNormActivation( + in_channels=num_channels, + out_channels=192, + kernel_size=1, + padding=0, + activation_layer=nn.ReLU) def forward(self, x): branch1x1 = self.branch1x1(x) diff --git a/python/paddle/vision/models/mobilenetv1.py b/python/paddle/vision/models/mobilenetv1.py index 671a2cd8df..6d8d96952f 100644 --- a/python/paddle/vision/models/mobilenetv1.py +++ b/python/paddle/vision/models/mobilenetv1.py @@ -16,59 +16,31 @@ import paddle import paddle.nn as nn from paddle.utils.download import get_weights_path_from_url +from ..ops import ConvNormActivation __all__ = [] model_urls = { 'mobilenetv1_1.0': - ('https://paddle-hapi.bj.bcebos.com/models/mobilenet_v1_x1.0.pdparams', - '42a154c2f26f86e7457d6daded114e8c') + ('https://paddle-hapi.bj.bcebos.com/models/mobilenetv1_1.0.pdparams', + '3033ab1975b1670bef51545feb65fc45') } -class ConvBNLayer(nn.Layer): - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - num_groups=1): - super(ConvBNLayer, self).__init__() - - self._conv = nn.Conv2D( - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=padding, - groups=num_groups, - bias_attr=False) - - self._norm_layer = nn.BatchNorm2D(out_channels) - self._act = nn.ReLU() - - def forward(self, x): - x = self._conv(x) - x = self._norm_layer(x) - x = self._act(x) - return x - - class DepthwiseSeparable(nn.Layer): def __init__(self, in_channels, out_channels1, out_channels2, num_groups, stride, scale): super(DepthwiseSeparable, self).__init__() - self._depthwise_conv = ConvBNLayer( + self._depthwise_conv = ConvNormActivation( in_channels, int(out_channels1 * scale), kernel_size=3, stride=stride, padding=1, - num_groups=int(num_groups * scale)) + groups=int(num_groups * scale)) - self._pointwise_conv = ConvBNLayer( + self._pointwise_conv = ConvNormActivation( int(out_channels1 * scale), int(out_channels2 * scale), kernel_size=1, @@ -94,9 +66,15 @@ class MobileNetV1(nn.Layer): Examples: .. code-block:: python + import paddle from paddle.vision.models import MobileNetV1 model = MobileNetV1() + + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + + print(out.shape) """ def __init__(self, scale=1.0, num_classes=1000, with_pool=True): @@ -106,7 +84,7 @@ class MobileNetV1(nn.Layer): self.num_classes = num_classes self.with_pool = with_pool - self.conv1 = ConvBNLayer( + self.conv1 = ConvNormActivation( in_channels=3, out_channels=int(32 * scale), kernel_size=3, @@ -257,6 +235,7 @@ def mobilenet_v1(pretrained=False, scale=1.0, **kwargs): Examples: .. 
code-block:: python + import paddle from paddle.vision.models import mobilenet_v1 # build model @@ -266,7 +245,12 @@ def mobilenet_v1(pretrained=False, scale=1.0, **kwargs): # model = mobilenet_v1(pretrained=True) # build mobilenet v1 with scale=0.5 - model = mobilenet_v1(scale=0.5) + model_scale = mobilenet_v1(scale=0.5) + + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + + print(out.shape) """ model = _mobilenet( 'mobilenetv1_' + str(scale), pretrained, scale=scale, **kwargs) diff --git a/python/paddle/vision/models/mobilenetv2.py b/python/paddle/vision/models/mobilenetv2.py index 6c486037c7..9791462610 100644 --- a/python/paddle/vision/models/mobilenetv2.py +++ b/python/paddle/vision/models/mobilenetv2.py @@ -17,6 +17,7 @@ import paddle.nn as nn from paddle.utils.download import get_weights_path_from_url from .utils import _make_divisible +from ..ops import ConvNormActivation __all__ = [] @@ -27,29 +28,6 @@ model_urls = { } -class ConvBNReLU(nn.Sequential): - def __init__(self, - in_planes, - out_planes, - kernel_size=3, - stride=1, - groups=1, - norm_layer=nn.BatchNorm2D): - padding = (kernel_size - 1) // 2 - - super(ConvBNReLU, self).__init__( - nn.Conv2D( - in_planes, - out_planes, - kernel_size, - stride, - padding, - groups=groups, - bias_attr=False), - norm_layer(out_planes), - nn.ReLU6()) - - class InvertedResidual(nn.Layer): def __init__(self, inp, @@ -67,15 +45,20 @@ class InvertedResidual(nn.Layer): layers = [] if expand_ratio != 1: layers.append( - ConvBNReLU( - inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) + ConvNormActivation( + inp, + hidden_dim, + kernel_size=1, + norm_layer=norm_layer, + activation_layer=nn.ReLU6)) layers.extend([ - ConvBNReLU( + ConvNormActivation( hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, - norm_layer=norm_layer), + norm_layer=norm_layer, + activation_layer=nn.ReLU6), nn.Conv2D( hidden_dim, oup, 1, 1, 0, bias_attr=False), norm_layer(oup), @@ -90,23 +73,30 @@ class InvertedResidual(nn.Layer): class MobileNetV2(nn.Layer): - def __init__(self, scale=1.0, num_classes=1000, with_pool=True): - """MobileNetV2 model from - `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. + """MobileNetV2 model from + `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. + + Args: + scale (float): scale of channels in each layer. Default: 1.0. + num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer + will not be defined. Default: 1000. + with_pool (bool): use pool before the last fc layer or not. Default: True. - Args: - scale (float): scale of channels in each layer. Default: 1.0. - num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer - will not be defined. Default: 1000. - with_pool (bool): use pool before the last fc layer or not. Default: True. + Examples: + .. code-block:: python + + import paddle + from paddle.vision.models import MobileNetV2 - Examples: - .. 
code-block:: python + model = MobileNetV2() - from paddle.vision.models import MobileNetV2 + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + + print(out.shape) + """ - model = MobileNetV2() - """ + def __init__(self, scale=1.0, num_classes=1000, with_pool=True): super(MobileNetV2, self).__init__() self.num_classes = num_classes self.with_pool = with_pool @@ -130,8 +120,12 @@ class MobileNetV2(nn.Layer): self.last_channel = _make_divisible(last_channel * max(1.0, scale), round_nearest) features = [ - ConvBNReLU( - 3, input_channel, stride=2, norm_layer=norm_layer) + ConvNormActivation( + 3, + input_channel, + stride=2, + norm_layer=norm_layer, + activation_layer=nn.ReLU6) ] for t, c, n, s in inverted_residual_setting: @@ -148,11 +142,12 @@ class MobileNetV2(nn.Layer): input_channel = output_channel features.append( - ConvBNReLU( + ConvNormActivation( input_channel, self.last_channel, kernel_size=1, - norm_layer=norm_layer)) + norm_layer=norm_layer, + activation_layer=nn.ReLU6)) self.features = nn.Sequential(*features) @@ -199,6 +194,7 @@ def mobilenet_v2(pretrained=False, scale=1.0, **kwargs): Examples: .. code-block:: python + import paddle from paddle.vision.models import mobilenet_v2 # build model @@ -209,6 +205,11 @@ def mobilenet_v2(pretrained=False, scale=1.0, **kwargs): # build mobilenet v2 with scale=0.5 model = mobilenet_v2(scale=0.5) + + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + + print(out.shape) """ model = _mobilenet( 'mobilenetv2_' + str(scale), pretrained, scale=scale, **kwargs) diff --git a/python/paddle/vision/models/resnet.py b/python/paddle/vision/models/resnet.py index 5921ae10ee..27536b6a9c 100644 --- a/python/paddle/vision/models/resnet.py +++ b/python/paddle/vision/models/resnet.py @@ -33,12 +33,30 @@ model_urls = { '02f35f034ca3858e1e54d4036443c92d'), 'resnet152': ('https://paddle-hapi.bj.bcebos.com/models/resnet152.pdparams', '7ad16a2f1e7333859ff986138630fd7a'), - 'wide_resnet50_2': - ('https://paddle-hapi.bj.bcebos.com/models/wide_resnet50_2.pdparams', - '0282f804d73debdab289bd9fea3fa6dc'), - 'wide_resnet101_2': - ('https://paddle-hapi.bj.bcebos.com/models/wide_resnet101_2.pdparams', - 'd4360a2d23657f059216f5d5a1a9ac93'), + 'resnext50_32x4d': + ('https://paddle-hapi.bj.bcebos.com/models/resnext50_32x4d.pdparams', + 'dc47483169be7d6f018fcbb7baf8775d'), + "resnext50_64x4d": + ('https://paddle-hapi.bj.bcebos.com/models/resnext50_64x4d.pdparams', + '063d4b483e12b06388529450ad7576db'), + 'resnext101_32x4d': ( + 'https://paddle-hapi.bj.bcebos.com/models/resnext101_32x4d.pdparams', + '967b090039f9de2c8d06fe994fb9095f'), + 'resnext101_64x4d': ( + 'https://paddle-hapi.bj.bcebos.com/models/resnext101_64x4d.pdparams', + '98e04e7ca616a066699230d769d03008'), + 'resnext152_32x4d': ( + 'https://paddle-hapi.bj.bcebos.com/models/resnext152_32x4d.pdparams', + '18ff0beee21f2efc99c4b31786107121'), + 'resnext152_64x4d': ( + 'https://paddle-hapi.bj.bcebos.com/models/resnext152_64x4d.pdparams', + '77c4af00ca42c405fa7f841841959379'), + 'wide_resnet50_2': ( + 'https://paddle-hapi.bj.bcebos.com/models/wide_resnet50_2.pdparams', + '0282f804d73debdab289bd9fea3fa6dc'), + 'wide_resnet101_2': ( + 'https://paddle-hapi.bj.bcebos.com/models/wide_resnet101_2.pdparams', + 'd4360a2d23657f059216f5d5a1a9ac93'), } @@ -158,11 +176,12 @@ class ResNet(nn.Layer): Args: Block (BasicBlock|BottleneckBlock): block module of model. - depth (int): layers of resnet, default: 50. - width (int): base width of resnet, default: 64. - num_classes (int): output dim of last fc layer. 
If num_classes <=0, last fc layer + depth (int, optional): layers of resnet, Default: 50. + width (int, optional): base width per convolution group for each convolution block, Default: 64. + num_classes (int, optional): output dim of last fc layer. If num_classes <=0, last fc layer will not be defined. Default: 1000. - with_pool (bool): use pool before the last fc layer or not. Default: True. + with_pool (bool, optional): use pool before the last fc layer or not. Default: True. + groups (int, optional): number of groups for each convolution block, Default: 1. Examples: .. code-block:: python @@ -171,16 +190,23 @@ class ResNet(nn.Layer): from paddle.vision.models import ResNet from paddle.vision.models.resnet import BottleneckBlock, BasicBlock + # build ResNet with 18 layers + resnet18 = ResNet(BasicBlock, 18) + + # build ResNet with 50 layers resnet50 = ResNet(BottleneckBlock, 50) + # build Wide ResNet model wide_resnet50_2 = ResNet(BottleneckBlock, 50, width=64*2) - resnet18 = ResNet(BasicBlock, 18) + # build ResNeXt model + resnext50_32x4d = ResNet(BottleneckBlock, 50, width=4, groups=32) x = paddle.rand([1, 3, 224, 224]) out = resnet18(x) print(out.shape) + # [1, 1000] """ @@ -189,7 +215,8 @@ class ResNet(nn.Layer): depth=50, width=64, num_classes=1000, - with_pool=True): + with_pool=True, + groups=1): super(ResNet, self).__init__() layer_cfg = { 18: [2, 2, 2, 2], @@ -199,7 +226,7 @@ class ResNet(nn.Layer): 152: [3, 8, 36, 3] } layers = layer_cfg[depth] - self.groups = 1 + self.groups = groups self.base_width = width self.num_classes = num_classes self.with_pool = with_pool @@ -300,7 +327,7 @@ def resnet18(pretrained=False, **kwargs): `"Deep Residual Learning for Image Recognition" `_ Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False. Examples: .. code-block:: python @@ -318,6 +345,7 @@ def resnet18(pretrained=False, **kwargs): out = model(x) print(out.shape) + # [1, 1000] """ return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs) @@ -327,7 +355,7 @@ def resnet34(pretrained=False, **kwargs): `"Deep Residual Learning for Image Recognition" `_ Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False. Examples: .. code-block:: python @@ -345,6 +373,7 @@ def resnet34(pretrained=False, **kwargs): out = model(x) print(out.shape) + # [1, 1000] """ return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs) @@ -354,7 +383,7 @@ def resnet50(pretrained=False, **kwargs): `"Deep Residual Learning for Image Recognition" `_ Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False. Examples: .. code-block:: python @@ -372,6 +401,7 @@ def resnet50(pretrained=False, **kwargs): out = model(x) print(out.shape) + # [1, 1000] """ return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs) @@ -381,7 +411,7 @@ def resnet101(pretrained=False, **kwargs): `"Deep Residual Learning for Image Recognition" `_ Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False. Examples: .. 
code-block:: python @@ -399,6 +429,7 @@ def resnet101(pretrained=False, **kwargs): out = model(x) print(out.shape) + # [1, 1000] """ return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs) @@ -408,7 +439,7 @@ def resnet152(pretrained=False, **kwargs): `"Deep Residual Learning for Image Recognition" `_ Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False. Examples: .. code-block:: python @@ -426,16 +457,201 @@ def resnet152(pretrained=False, **kwargs): out = model(x) print(out.shape) + # [1, 1000] """ return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs) +def resnext50_32x4d(pretrained=False, **kwargs): + """ResNeXt-50 32x4d model from + `"Aggregated Residual Transformations for Deep Neural Networks" `_ + + Args: + pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False. + + Examples: + .. code-block:: python + + import paddle + from paddle.vision.models import resnext50_32x4d + + # build model + model = resnext50_32x4d() + + # build model and load imagenet pretrained weight + # model = resnext50_32x4d(pretrained=True) + + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + + print(out.shape) + # [1, 1000] + """ + kwargs['groups'] = 32 + kwargs['width'] = 4 + return _resnet('resnext50_32x4d', BottleneckBlock, 50, pretrained, **kwargs) + + +def resnext50_64x4d(pretrained=False, **kwargs): + """ResNeXt-50 64x4d model from + `"Aggregated Residual Transformations for Deep Neural Networks" `_ + + Args: + pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False. + + Examples: + .. code-block:: python + + import paddle + from paddle.vision.models import resnext50_64x4d + + # build model + model = resnext50_64x4d() + + # build model and load imagenet pretrained weight + # model = resnext50_64x4d(pretrained=True) + + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + + print(out.shape) + # [1, 1000] + """ + kwargs['groups'] = 64 + kwargs['width'] = 4 + return _resnet('resnext50_64x4d', BottleneckBlock, 50, pretrained, **kwargs) + + +def resnext101_32x4d(pretrained=False, **kwargs): + """ResNeXt-101 32x4d model from + `"Aggregated Residual Transformations for Deep Neural Networks" `_ + + Args: + pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False. + + Examples: + .. code-block:: python + + import paddle + from paddle.vision.models import resnext101_32x4d + + # build model + model = resnext101_32x4d() + + # build model and load imagenet pretrained weight + # model = resnext101_32x4d(pretrained=True) + + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + + print(out.shape) + # [1, 1000] + """ + kwargs['groups'] = 32 + kwargs['width'] = 4 + return _resnet('resnext101_32x4d', BottleneckBlock, 101, pretrained, + **kwargs) + + +def resnext101_64x4d(pretrained=False, **kwargs): + """ResNeXt-101 64x4d model from + `"Aggregated Residual Transformations for Deep Neural Networks" `_ + + Args: + pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False. + + Examples: + .. 
code-block:: python + + import paddle + from paddle.vision.models import resnext101_64x4d + + # build model + model = resnext101_64x4d() + + # build model and load imagenet pretrained weight + # model = resnext101_64x4d(pretrained=True) + + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + + print(out.shape) + # [1, 1000] + """ + kwargs['groups'] = 64 + kwargs['width'] = 4 + return _resnet('resnext101_64x4d', BottleneckBlock, 101, pretrained, + **kwargs) + + +def resnext152_32x4d(pretrained=False, **kwargs): + """ResNeXt-152 32x4d model from + `"Aggregated Residual Transformations for Deep Neural Networks" `_ + + Args: + pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False. + + Examples: + .. code-block:: python + + import paddle + from paddle.vision.models import resnext152_32x4d + + # build model + model = resnext152_32x4d() + + # build model and load imagenet pretrained weight + # model = resnext152_32x4d(pretrained=True) + + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + + print(out.shape) + # [1, 1000] + """ + kwargs['groups'] = 32 + kwargs['width'] = 4 + return _resnet('resnext152_32x4d', BottleneckBlock, 152, pretrained, + **kwargs) + + +def resnext152_64x4d(pretrained=False, **kwargs): + """ResNeXt-152 64x4d model from + `"Aggregated Residual Transformations for Deep Neural Networks" `_ + + Args: + pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False. + + Examples: + .. code-block:: python + + import paddle + from paddle.vision.models import resnext152_64x4d + + # build model + model = resnext152_64x4d() + + # build model and load imagenet pretrained weight + # model = resnext152_64x4d(pretrained=True) + + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + + print(out.shape) + # [1, 1000] + """ + kwargs['groups'] = 64 + kwargs['width'] = 4 + return _resnet('resnext152_64x4d', BottleneckBlock, 152, pretrained, + **kwargs) + + def wide_resnet50_2(pretrained=False, **kwargs): """Wide ResNet-50-2 model from `"Wide Residual Networks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False. Examples: .. code-block:: python @@ -453,6 +669,7 @@ def wide_resnet50_2(pretrained=False, **kwargs): out = model(x) print(out.shape) + # [1, 1000] """ kwargs['width'] = 64 * 2 return _resnet('wide_resnet50_2', BottleneckBlock, 50, pretrained, **kwargs) @@ -463,7 +680,7 @@ def wide_resnet101_2(pretrained=False, **kwargs): `"Wide Residual Networks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False. Examples: .. code-block:: python @@ -481,6 +698,7 @@ def wide_resnet101_2(pretrained=False, **kwargs): out = model(x) print(out.shape) + # [1, 1000] """ kwargs['width'] = 64 * 2 return _resnet('wide_resnet101_2', BottleneckBlock, 101, pretrained, diff --git a/python/paddle/vision/models/resnext.py b/python/paddle/vision/models/resnext.py deleted file mode 100644 index 2e1073c8ac..0000000000 --- a/python/paddle/vision/models/resnext.py +++ /dev/null @@ -1,364 +0,0 @@ -# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import math - -import paddle -import paddle.nn as nn -import paddle.nn.functional as F -from paddle.fluid.param_attr import ParamAttr -from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Linear, MaxPool2D -from paddle.nn.initializer import Uniform -from paddle.utils.download import get_weights_path_from_url - -__all__ = [] - -model_urls = { - 'resnext50_32x4d': - ('https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt50_32x4d_pretrained.pdparams', - 'bf04add2f7fd22efcbe91511bcd1eebe'), - "resnext50_64x4d": - ('https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt50_64x4d_pretrained.pdparams', - '46307df0e2d6d41d3b1c1d22b00abc69'), - 'resnext101_32x4d': - ('https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt101_32x4d_pretrained.pdparams', - '078ca145b3bea964ba0544303a43c36d'), - 'resnext101_64x4d': - ('https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt101_64x4d_pretrained.pdparams', - '4edc0eb32d3cc5d80eff7cab32cd5c64'), - 'resnext152_32x4d': - ('https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt152_32x4d_pretrained.pdparams', - '7971cc994d459af167c502366f866378'), - 'resnext152_64x4d': - ('https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt152_64x4d_pretrained.pdparams', - '836943f03709efec364d486c57d132de'), -} - - -class ConvBNLayer(nn.Layer): - def __init__(self, - num_channels, - num_filters, - filter_size, - stride=1, - groups=1, - act=None): - super(ConvBNLayer, self).__init__() - self._conv = Conv2D( - in_channels=num_channels, - out_channels=num_filters, - kernel_size=filter_size, - stride=stride, - padding=(filter_size - 1) // 2, - groups=groups, - bias_attr=False) - self._batch_norm = BatchNorm(num_filters, act=act) - - def forward(self, inputs): - x = self._conv(inputs) - x = self._batch_norm(x) - return x - - -class BottleneckBlock(nn.Layer): - def __init__(self, - num_channels, - num_filters, - stride, - cardinality, - shortcut=True): - super(BottleneckBlock, self).__init__() - self.conv0 = ConvBNLayer( - num_channels=num_channels, - num_filters=num_filters, - filter_size=1, - act='relu') - self.conv1 = ConvBNLayer( - num_channels=num_filters, - num_filters=num_filters, - filter_size=3, - groups=cardinality, - stride=stride, - act='relu') - self.conv2 = ConvBNLayer( - num_channels=num_filters, - num_filters=num_filters * 2 if cardinality == 32 else num_filters, - filter_size=1, - act=None) - - if not shortcut: - self.short = ConvBNLayer( - num_channels=num_channels, - num_filters=num_filters * 2 - if cardinality == 32 else num_filters, - filter_size=1, - stride=stride) - - self.shortcut = shortcut - - def forward(self, inputs): - x = self.conv0(inputs) - conv1 = self.conv1(x) - conv2 = self.conv2(conv1) - - if self.shortcut: - short = inputs - else: - short = self.short(inputs) - - x = paddle.add(x=short, y=conv2) - x = F.relu(x) - return x - - -class ResNeXt(nn.Layer): - """ResNeXt model from - `"Aggregated Residual Transformations for Deep Neural Networks" `_ - - 
Args: - depth (int, optional): depth of resnext. Default: 50. - cardinality (int, optional): cardinality of resnext. Default: 32. - num_classes (int, optional): output dim of last fc layer. If num_classes <=0, last fc layer - will not be defined. Default: 1000. - with_pool (bool, optional): use pool before the last fc layer or not. Default: True. - - Examples: - .. code-block:: python - - import paddle - from paddle.vision.models import ResNeXt - - resnext50_32x4d = ResNeXt(depth=50, cardinality=32) - - """ - - def __init__(self, - depth=50, - cardinality=32, - num_classes=1000, - with_pool=True): - super(ResNeXt, self).__init__() - - self.depth = depth - self.cardinality = cardinality - self.num_classes = num_classes - self.with_pool = with_pool - - supported_depth = [50, 101, 152] - assert depth in supported_depth, \ - "supported layers are {} but input layer is {}".format( - supported_depth, depth) - supported_cardinality = [32, 64] - assert cardinality in supported_cardinality, \ - "supported cardinality is {} but input cardinality is {}" \ - .format(supported_cardinality, cardinality) - layer_cfg = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]} - layers = layer_cfg[depth] - num_channels = [64, 256, 512, 1024] - num_filters = [128, 256, 512, - 1024] if cardinality == 32 else [256, 512, 1024, 2048] - - self.conv = ConvBNLayer( - num_channels=3, num_filters=64, filter_size=7, stride=2, act='relu') - self.pool2d_max = MaxPool2D(kernel_size=3, stride=2, padding=1) - - self.block_list = [] - for block in range(len(layers)): - shortcut = False - for i in range(layers[block]): - bottleneck_block = self.add_sublayer( - 'bb_%d_%d' % (block, i), - BottleneckBlock( - num_channels=num_channels[block] if i == 0 else - num_filters[block] * int(64 // self.cardinality), - num_filters=num_filters[block], - stride=2 if i == 0 and block != 0 else 1, - cardinality=self.cardinality, - shortcut=shortcut)) - self.block_list.append(bottleneck_block) - shortcut = True - - if with_pool: - self.pool2d_avg = AdaptiveAvgPool2D(1) - - if num_classes > 0: - self.pool2d_avg_channels = num_channels[-1] * 2 - stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0) - self.out = Linear( - self.pool2d_avg_channels, - num_classes, - weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv))) - - def forward(self, inputs): - with paddle.static.amp.fp16_guard(): - x = self.conv(inputs) - x = self.pool2d_max(x) - for block in self.block_list: - x = block(x) - if self.with_pool: - x = self.pool2d_avg(x) - if self.num_classes > 0: - x = paddle.reshape(x, shape=[-1, self.pool2d_avg_channels]) - x = self.out(x) - return x - - -def _resnext(arch, depth, cardinality, pretrained, **kwargs): - model = ResNeXt(depth=depth, cardinality=cardinality, **kwargs) - if pretrained: - assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format( - arch) - weight_path = get_weights_path_from_url(model_urls[arch][0], - model_urls[arch][1]) - - param = paddle.load(weight_path) - model.set_dict(param) - - return model - - -def resnext50_32x4d(pretrained=False, **kwargs): - """ResNeXt-50 32x4d model from - `"Aggregated Residual Transformations for Deep Neural Networks" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - - Examples: - .. 
code-block:: python - - import paddle - from paddle.vision.models import resnext50_32x4d - - # build model - model = resnext50_32x4d() - - # build model and load imagenet pretrained weight - # model = resnext50_32x4d(pretrained=True) - """ - return _resnext('resnext50_32x4d', 50, 32, pretrained, **kwargs) - - -def resnext50_64x4d(pretrained=False, **kwargs): - """ResNeXt-50 64x4d model from - `"Aggregated Residual Transformations for Deep Neural Networks" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - - Examples: - .. code-block:: python - - import paddle - from paddle.vision.models import resnext50_64x4d - - # build model - model = resnext50_64x4d() - - # build model and load imagenet pretrained weight - # model = resnext50_64x4d(pretrained=True) - """ - return _resnext('resnext50_64x4d', 50, 64, pretrained, **kwargs) - - -def resnext101_32x4d(pretrained=False, **kwargs): - """ResNeXt-101 32x4d model from - `"Aggregated Residual Transformations for Deep Neural Networks" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - - Examples: - .. code-block:: python - - import paddle - from paddle.vision.models import resnext101_32x4d - - # build model - model = resnext101_32x4d() - - # build model and load imagenet pretrained weight - # model = resnext101_32x4d(pretrained=True) - """ - return _resnext('resnext101_32x4d', 101, 32, pretrained, **kwargs) - - -def resnext101_64x4d(pretrained=False, **kwargs): - """ResNeXt-101 64x4d model from - `"Aggregated Residual Transformations for Deep Neural Networks" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - - Examples: - .. code-block:: python - - import paddle - from paddle.vision.models import resnext101_64x4d - - # build model - model = resnext101_64x4d() - - # build model and load imagenet pretrained weight - # model = resnext101_64x4d(pretrained=True) - """ - return _resnext('resnext101_64x4d', 101, 64, pretrained, **kwargs) - - -def resnext152_32x4d(pretrained=False, **kwargs): - """ResNeXt-152 32x4d model from - `"Aggregated Residual Transformations for Deep Neural Networks" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - - Examples: - .. code-block:: python - - import paddle - from paddle.vision.models import resnext152_32x4d - - # build model - model = resnext152_32x4d() - - # build model and load imagenet pretrained weight - # model = resnext152_32x4d(pretrained=True) - """ - return _resnext('resnext152_32x4d', 152, 32, pretrained, **kwargs) - - -def resnext152_64x4d(pretrained=False, **kwargs): - """ResNeXt-152 64x4d model from - `"Aggregated Residual Transformations for Deep Neural Networks" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - - Examples: - .. 
code-block:: python - - import paddle - from paddle.vision.models import resnext152_64x4d - - # build model - model = resnext152_64x4d() - - # build model and load imagenet pretrained weight - # model = resnext152_64x4d(pretrained=True) - """ - return _resnext('resnext152_64x4d', 152, 64, pretrained, **kwargs) diff --git a/python/paddle/vision/models/shufflenetv2.py b/python/paddle/vision/models/shufflenetv2.py index 041f3fc749..90e967ee22 100644 --- a/python/paddle/vision/models/shufflenetv2.py +++ b/python/paddle/vision/models/shufflenetv2.py @@ -18,37 +18,50 @@ from __future__ import print_function import paddle import paddle.nn as nn -from paddle.fluid.param_attr import ParamAttr -from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Linear, MaxPool2D +from paddle.nn import AdaptiveAvgPool2D, Linear, MaxPool2D from paddle.utils.download import get_weights_path_from_url +from ..ops import ConvNormActivation + __all__ = [] model_urls = { "shufflenet_v2_x0_25": ( - "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x0_25_pretrained.pdparams", - "e753404cbd95027759c5f56ecd6c9c4b", ), + "https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_x0_25.pdparams", + "1e509b4c140eeb096bb16e214796d03b", ), "shufflenet_v2_x0_33": ( - "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x0_33_pretrained.pdparams", - "776e3cf9a4923abdfce789c45b8fe1f2", ), + "https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_x0_33.pdparams", + "3d7b3ab0eaa5c0927ff1026d31b729bd", ), "shufflenet_v2_x0_5": ( - "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x0_5_pretrained.pdparams", - "e3649cf531566917e2969487d2bc6b60", ), + "https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_x0_5.pdparams", + "5e5cee182a7793c4e4c73949b1a71bd4", ), "shufflenet_v2_x1_0": ( - "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x1_0_pretrained.pdparams", - "7821c348ea34e58847c43a08a4ac0bdf", ), + "https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_x1_0.pdparams", + "122d42478b9e81eb49f8a9ede327b1a4", ), "shufflenet_v2_x1_5": ( - "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x1_5_pretrained.pdparams", - "93a07fa557ab2d8803550f39e5b6c391", ), + "https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_x1_5.pdparams", + "faced5827380d73531d0ee027c67826d", ), "shufflenet_v2_x2_0": ( - "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x2_0_pretrained.pdparams", - "4ab1f622fd0d341e0f84b4e057797563", ), + "https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_x2_0.pdparams", + "cd3dddcd8305e7bcd8ad14d1c69a5784", ), "shufflenet_v2_swish": ( - "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_swish_pretrained.pdparams", - "daff38b3df1b3748fccbb13cfdf02519", ), + "https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_swish.pdparams", + "adde0aa3b023e5b0c94a68be1c394b84", ), } +def create_activation_layer(act): + if act == "swish": + return nn.Swish + elif act == "relu": + return nn.ReLU + elif act is None: + return None + else: + raise RuntimeError( + "The activation function is not supported: {}".format(act)) + + def channel_shuffle(x, groups): batch_size, num_channels, height, width = x.shape[0:4] channels_per_group = num_channels // groups @@ -65,61 +78,37 @@ def channel_shuffle(x, groups): return x -class ConvBNLayer(nn.Layer): +class InvertedResidual(nn.Layer): def __init__(self, in_channels, out_channels, - kernel_size, stride, - padding, - groups=1, - act=None): - 
super(ConvBNLayer, self).__init__() - self._conv = Conv2D( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - groups=groups, - weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()), - bias_attr=False, ) - - self._batch_norm = BatchNorm(out_channels, act=act) - - def forward(self, inputs): - x = self._conv(inputs) - x = self._batch_norm(x) - return x - - -class InvertedResidual(nn.Layer): - def __init__(self, in_channels, out_channels, stride, act="relu"): + activation_layer=nn.ReLU): super(InvertedResidual, self).__init__() - self._conv_pw = ConvBNLayer( + self._conv_pw = ConvNormActivation( in_channels=in_channels // 2, out_channels=out_channels // 2, kernel_size=1, stride=1, padding=0, groups=1, - act=act) - self._conv_dw = ConvBNLayer( + activation_layer=activation_layer) + self._conv_dw = ConvNormActivation( in_channels=out_channels // 2, out_channels=out_channels // 2, kernel_size=3, stride=stride, padding=1, groups=out_channels // 2, - act=None) - self._conv_linear = ConvBNLayer( + activation_layer=None) + self._conv_linear = ConvNormActivation( in_channels=out_channels // 2, out_channels=out_channels // 2, kernel_size=1, stride=1, padding=0, groups=1, - act=act) + activation_layer=activation_layer) def forward(self, inputs): x1, x2 = paddle.split( @@ -134,51 +123,55 @@ class InvertedResidual(nn.Layer): class InvertedResidualDS(nn.Layer): - def __init__(self, in_channels, out_channels, stride, act="relu"): + def __init__(self, + in_channels, + out_channels, + stride, + activation_layer=nn.ReLU): super(InvertedResidualDS, self).__init__() # branch1 - self._conv_dw_1 = ConvBNLayer( + self._conv_dw_1 = ConvNormActivation( in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels, - act=None) - self._conv_linear_1 = ConvBNLayer( + activation_layer=None) + self._conv_linear_1 = ConvNormActivation( in_channels=in_channels, out_channels=out_channels // 2, kernel_size=1, stride=1, padding=0, groups=1, - act=act) + activation_layer=activation_layer) # branch2 - self._conv_pw_2 = ConvBNLayer( + self._conv_pw_2 = ConvNormActivation( in_channels=in_channels, out_channels=out_channels // 2, kernel_size=1, stride=1, padding=0, groups=1, - act=act) - self._conv_dw_2 = ConvBNLayer( + activation_layer=activation_layer) + self._conv_dw_2 = ConvNormActivation( in_channels=out_channels // 2, out_channels=out_channels // 2, kernel_size=3, stride=stride, padding=1, groups=out_channels // 2, - act=None) - self._conv_linear_2 = ConvBNLayer( + activation_layer=None) + self._conv_linear_2 = ConvNormActivation( in_channels=out_channels // 2, out_channels=out_channels // 2, kernel_size=1, stride=1, padding=0, groups=1, - act=act) + activation_layer=activation_layer) def forward(self, inputs): x1 = self._conv_dw_1(inputs) @@ -221,6 +214,7 @@ class ShuffleNetV2(nn.Layer): self.num_classes = num_classes self.with_pool = with_pool stage_repeats = [4, 8, 4] + activation_layer = create_activation_layer(act) if scale == 0.25: stage_out_channels = [-1, 24, 24, 48, 96, 512] @@ -238,13 +232,13 @@ class ShuffleNetV2(nn.Layer): raise NotImplementedError("This scale size:[" + str(scale) + "] is not implemented!") # 1. conv1 - self._conv1 = ConvBNLayer( + self._conv1 = ConvNormActivation( in_channels=3, out_channels=stage_out_channels[1], kernel_size=3, stride=2, padding=1, - act=act) + activation_layer=activation_layer) self._max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1) # 2. 
bottleneck sequences @@ -257,7 +251,7 @@ class ShuffleNetV2(nn.Layer): in_channels=stage_out_channels[stage_id + 1], out_channels=stage_out_channels[stage_id + 2], stride=2, - act=act), + activation_layer=activation_layer), name=str(stage_id + 2) + "_" + str(i + 1)) else: block = self.add_sublayer( @@ -265,17 +259,17 @@ class ShuffleNetV2(nn.Layer): in_channels=stage_out_channels[stage_id + 2], out_channels=stage_out_channels[stage_id + 2], stride=1, - act=act), + activation_layer=activation_layer), name=str(stage_id + 2) + "_" + str(i + 1)) self._block_list.append(block) # 3. last_conv - self._last_conv = ConvBNLayer( + self._last_conv = ConvNormActivation( in_channels=stage_out_channels[-2], out_channels=stage_out_channels[-1], kernel_size=1, stride=1, padding=0, - act=act) + activation_layer=activation_layer) # 4. pool if with_pool: self._pool2d_avg = AdaptiveAvgPool2D(1) diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index 2d60fd4561..e4dd4c797f 100644 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -1335,13 +1335,13 @@ class ConvNormActivation(Sequential): Args: in_channels (int): Number of channels in the input image out_channels (int): Number of channels produced by the Convolution-Normalzation-Activation block - kernel_size: (int, optional): Size of the convolving kernel. Default: 3 - stride (int, optional): Stride of the convolution. Default: 1 - padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, + kernel_size: (int|list|tuple, optional): Size of the convolving kernel. Default: 3 + stride (int|list|tuple, optional): Stride of the convolution. Default: 1 + padding (int|str|tuple|list, optional): Padding added to all four sides of the input. Default: None, in wich case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation`` groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 norm_layer (Callable[..., paddle.nn.Layer], optional): Norm layer that will be stacked on top of the convolutiuon layer. - If ``None`` this layer wont be used. Default: ``paddle.nn.BatchNorm2d`` + If ``None`` this layer wont be used. Default: ``paddle.nn.BatchNorm2D`` activation_layer (Callable[..., paddle.nn.Layer], optional): Activation function which will be stacked on top of the normalization layer (if not ``None``), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``paddle.nn.ReLU`` dilation (int): Spacing between kernel elements. Default: 1 -- GitLab
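
For reference, below is a minimal usage sketch of the interfaces this patch exposes, assuming a Paddle build that already contains it. The ResNeXt variants are now configurations of the shared ResNet class (cardinality maps to ``groups`` and per-group width to ``width``), the removed ``paddle.vision.models.resnext`` module is replaced by helper functions in ``paddle.vision.models.resnet``, and the per-model ``ConvBNLayer`` classes are replaced by ``paddle.vision.ops.ConvNormActivation``. The shape comments follow the docstring examples in the diff.

.. code-block:: python

    import paddle
    import paddle.nn as nn
    from paddle.vision.models import ResNet, resnext50_32x4d
    from paddle.vision.models.resnet import BottleneckBlock
    from paddle.vision.ops import ConvNormActivation

    # ResNeXt-50 32x4d via the helper added by this patch
    model = resnext50_32x4d()

    # Equivalent construction through the shared ResNet class:
    # cardinality -> groups, per-group width -> width
    same_arch = ResNet(BottleneckBlock, 50, width=4, groups=32)

    x = paddle.rand([1, 3, 224, 224])
    out = model(x)
    print(out.shape)  # [1, 1000]

    # Conv + BatchNorm + activation blocks previously built with per-model
    # ConvBNLayer classes are now expressed with the shared op, e.g. the
    # first InceptionV3 stem convolution:
    block = ConvNormActivation(
        in_channels=3,
        out_channels=32,
        kernel_size=3,
        stride=2,
        padding=0,
        activation_layer=nn.ReLU)
    feat = block(x)
    print(feat.shape)  # [1, 32, 111, 111]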