未验证 提交 40ac92b2 编写于 作者: N Nyakku Shigure 提交者: GitHub

[cherry-pick] refactor vision models (#42252)

* reuse ConvNormActivation in some vision models (#40431)

* reuse ConvNormActivation in some vision models

* reimplement ResNeXt based on ResNet (#40588)

* refactor resnext
上级 5eba3847
......@@ -34,6 +34,12 @@ from .models import resnet34 # noqa: F401
from .models import resnet50 # noqa: F401
from .models import resnet101 # noqa: F401
from .models import resnet152 # noqa: F401
from .models import resnext50_32x4d # noqa: F401
from .models import resnext50_64x4d # noqa: F401
from .models import resnext101_32x4d # noqa: F401
from .models import resnext101_64x4d # noqa: F401
from .models import resnext152_32x4d # noqa: F401
from .models import resnext152_64x4d # noqa: F401
from .models import wide_resnet50_2 # noqa: F401
from .models import wide_resnet101_2 # noqa: F401
from .models import MobileNetV1 # noqa: F401
......@@ -61,13 +67,6 @@ from .models import densenet201 # noqa: F401
from .models import densenet264 # noqa: F401
from .models import AlexNet # noqa: F401
from .models import alexnet # noqa: F401
from .models import ResNeXt # noqa: F401
from .models import resnext50_32x4d # noqa: F401
from .models import resnext50_64x4d # noqa: F401
from .models import resnext101_32x4d # noqa: F401
from .models import resnext101_64x4d # noqa: F401
from .models import resnext152_32x4d # noqa: F401
from .models import resnext152_64x4d # noqa: F401
from .models import InceptionV3 # noqa: F401
from .models import inception_v3 # noqa: F401
from .models import GoogLeNet # noqa: F401
......
......@@ -18,6 +18,12 @@ from .resnet import resnet34 # noqa: F401
from .resnet import resnet50 # noqa: F401
from .resnet import resnet101 # noqa: F401
from .resnet import resnet152 # noqa: F401
from .resnet import resnext50_32x4d # noqa: F401
from .resnet import resnext50_64x4d # noqa: F401
from .resnet import resnext101_32x4d # noqa: F401
from .resnet import resnext101_64x4d # noqa: F401
from .resnet import resnext152_32x4d # noqa: F401
from .resnet import resnext152_64x4d # noqa: F401
from .resnet import wide_resnet50_2 # noqa: F401
from .resnet import wide_resnet101_2 # noqa: F401
from .mobilenetv1 import MobileNetV1 # noqa: F401
......@@ -42,13 +48,6 @@ from .densenet import densenet201 # noqa: F401
from .densenet import densenet264 # noqa: F401
from .alexnet import AlexNet # noqa: F401
from .alexnet import alexnet # noqa: F401
from .resnext import ResNeXt # noqa: F401
from .resnext import resnext50_32x4d # noqa: F401
from .resnext import resnext50_64x4d # noqa: F401
from .resnext import resnext101_32x4d # noqa: F401
from .resnext import resnext101_64x4d # noqa: F401
from .resnext import resnext152_32x4d # noqa: F401
from .resnext import resnext152_64x4d # noqa: F401
from .inceptionv3 import InceptionV3 # noqa: F401
from .inceptionv3 import inception_v3 # noqa: F401
from .squeezenet import SqueezeNet # noqa: F401
......@@ -72,6 +71,12 @@ __all__ = [ #noqa
'resnet50',
'resnet101',
'resnet152',
'resnext50_32x4d',
'resnext50_64x4d',
'resnext101_32x4d',
'resnext101_64x4d',
'resnext152_32x4d',
'resnext152_64x4d',
'wide_resnet50_2',
'wide_resnet101_2',
'VGG',
......@@ -96,13 +101,6 @@ __all__ = [ #noqa
'densenet264',
'AlexNet',
'alexnet',
'ResNeXt',
'resnext50_32x4d',
'resnext50_64x4d',
'resnext101_32x4d',
'resnext101_64x4d',
'resnext152_32x4d',
'resnext152_64x4d',
'InceptionV3',
'inception_v3',
'SqueezeNet',
......
......@@ -19,75 +19,60 @@ from __future__ import print_function
import math
import paddle
import paddle.nn as nn
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform
from paddle.fluid.param_attr import ParamAttr
from paddle.utils.download import get_weights_path_from_url
from ..ops import ConvNormActivation
__all__ = []
model_urls = {
"inception_v3":
("https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/InceptionV3_pretrained.pdparams",
"e4d0905a818f6bb7946e881777a8a935")
("https://paddle-hapi.bj.bcebos.com/models/inception_v3.pdparams",
"649a4547c3243e8b59c656f41fe330b8")
}
class ConvBNLayer(nn.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
padding=0,
groups=1,
act="relu"):
super().__init__()
self.act = act
self.conv = Conv2D(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=padding,
groups=groups,
bias_attr=False)
self.bn = BatchNorm(num_filters)
self.relu = nn.ReLU()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.act:
x = self.relu(x)
return x
class InceptionStem(nn.Layer):
def __init__(self):
super().__init__()
self.conv_1a_3x3 = ConvBNLayer(
num_channels=3, num_filters=32, filter_size=3, stride=2, act="relu")
self.conv_2a_3x3 = ConvBNLayer(
num_channels=32,
num_filters=32,
filter_size=3,
self.conv_1a_3x3 = ConvNormActivation(
in_channels=3,
out_channels=32,
kernel_size=3,
stride=2,
padding=0,
activation_layer=nn.ReLU)
self.conv_2a_3x3 = ConvNormActivation(
in_channels=32,
out_channels=32,
kernel_size=3,
stride=1,
act="relu")
self.conv_2b_3x3 = ConvBNLayer(
num_channels=32,
num_filters=64,
filter_size=3,
padding=0,
activation_layer=nn.ReLU)
self.conv_2b_3x3 = ConvNormActivation(
in_channels=32,
out_channels=64,
kernel_size=3,
padding=1,
act="relu")
activation_layer=nn.ReLU)
self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=0)
self.conv_3b_1x1 = ConvBNLayer(
num_channels=64, num_filters=80, filter_size=1, act="relu")
self.conv_4a_3x3 = ConvBNLayer(
num_channels=80, num_filters=192, filter_size=3, act="relu")
self.conv_3b_1x1 = ConvNormActivation(
in_channels=64,
out_channels=80,
kernel_size=1,
padding=0,
activation_layer=nn.ReLU)
self.conv_4a_3x3 = ConvNormActivation(
in_channels=80,
out_channels=192,
kernel_size=3,
padding=0,
activation_layer=nn.ReLU)
def forward(self, x):
x = self.conv_1a_3x3(x)
......@@ -103,47 +88,53 @@ class InceptionStem(nn.Layer):
class InceptionA(nn.Layer):
def __init__(self, num_channels, pool_features):
super().__init__()
self.branch1x1 = ConvBNLayer(
num_channels=num_channels,
num_filters=64,
filter_size=1,
act="relu")
self.branch5x5_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=48,
filter_size=1,
act="relu")
self.branch5x5_2 = ConvBNLayer(
num_channels=48,
num_filters=64,
filter_size=5,
self.branch1x1 = ConvNormActivation(
in_channels=num_channels,
out_channels=64,
kernel_size=1,
padding=0,
activation_layer=nn.ReLU)
self.branch5x5_1 = ConvNormActivation(
in_channels=num_channels,
out_channels=48,
kernel_size=1,
padding=0,
activation_layer=nn.ReLU)
self.branch5x5_2 = ConvNormActivation(
in_channels=48,
out_channels=64,
kernel_size=5,
padding=2,
act="relu")
self.branch3x3dbl_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=64,
filter_size=1,
act="relu")
self.branch3x3dbl_2 = ConvBNLayer(
num_channels=64,
num_filters=96,
filter_size=3,
activation_layer=nn.ReLU)
self.branch3x3dbl_1 = ConvNormActivation(
in_channels=num_channels,
out_channels=64,
kernel_size=1,
padding=0,
activation_layer=nn.ReLU)
self.branch3x3dbl_2 = ConvNormActivation(
in_channels=64,
out_channels=96,
kernel_size=3,
padding=1,
act="relu")
self.branch3x3dbl_3 = ConvBNLayer(
num_channels=96,
num_filters=96,
filter_size=3,
activation_layer=nn.ReLU)
self.branch3x3dbl_3 = ConvNormActivation(
in_channels=96,
out_channels=96,
kernel_size=3,
padding=1,
act="relu")
activation_layer=nn.ReLU)
self.branch_pool = AvgPool2D(
kernel_size=3, stride=1, padding=1, exclusive=False)
self.branch_pool_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=pool_features,
filter_size=1,
act="relu")
self.branch_pool_conv = ConvNormActivation(
in_channels=num_channels,
out_channels=pool_features,
kernel_size=1,
padding=0,
activation_layer=nn.ReLU)
def forward(self, x):
branch1x1 = self.branch1x1(x)
......@@ -164,29 +155,34 @@ class InceptionA(nn.Layer):
class InceptionB(nn.Layer):
def __init__(self, num_channels):
super().__init__()
self.branch3x3 = ConvBNLayer(
num_channels=num_channels,
num_filters=384,
filter_size=3,
self.branch3x3 = ConvNormActivation(
in_channels=num_channels,
out_channels=384,
kernel_size=3,
stride=2,
act="relu")
self.branch3x3dbl_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=64,
filter_size=1,
act="relu")
self.branch3x3dbl_2 = ConvBNLayer(
num_channels=64,
num_filters=96,
filter_size=3,
padding=0,
activation_layer=nn.ReLU)
self.branch3x3dbl_1 = ConvNormActivation(
in_channels=num_channels,
out_channels=64,
kernel_size=1,
padding=0,
activation_layer=nn.ReLU)
self.branch3x3dbl_2 = ConvNormActivation(
in_channels=64,
out_channels=96,
kernel_size=3,
padding=1,
act="relu")
self.branch3x3dbl_3 = ConvBNLayer(
num_channels=96,
num_filters=96,
filter_size=3,
activation_layer=nn.ReLU)
self.branch3x3dbl_3 = ConvNormActivation(
in_channels=96,
out_channels=96,
kernel_size=3,
stride=2,
act="relu")
padding=0,
activation_layer=nn.ReLU)
self.branch_pool = MaxPool2D(kernel_size=3, stride=2)
def forward(self, x):
......@@ -206,70 +202,74 @@ class InceptionB(nn.Layer):
class InceptionC(nn.Layer):
def __init__(self, num_channels, channels_7x7):
super().__init__()
self.branch1x1 = ConvBNLayer(
num_channels=num_channels,
num_filters=192,
filter_size=1,
act="relu")
self.branch7x7_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=channels_7x7,
filter_size=1,
self.branch1x1 = ConvNormActivation(
in_channels=num_channels,
out_channels=192,
kernel_size=1,
padding=0,
activation_layer=nn.ReLU)
self.branch7x7_1 = ConvNormActivation(
in_channels=num_channels,
out_channels=channels_7x7,
kernel_size=1,
stride=1,
act="relu")
self.branch7x7_2 = ConvBNLayer(
num_channels=channels_7x7,
num_filters=channels_7x7,
filter_size=(1, 7),
padding=0,
activation_layer=nn.ReLU)
self.branch7x7_2 = ConvNormActivation(
in_channels=channels_7x7,
out_channels=channels_7x7,
kernel_size=(1, 7),
stride=1,
padding=(0, 3),
act="relu")
self.branch7x7_3 = ConvBNLayer(
num_channels=channels_7x7,
num_filters=192,
filter_size=(7, 1),
activation_layer=nn.ReLU)
self.branch7x7_3 = ConvNormActivation(
in_channels=channels_7x7,
out_channels=192,
kernel_size=(7, 1),
stride=1,
padding=(3, 0),
act="relu")
self.branch7x7dbl_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=channels_7x7,
filter_size=1,
act="relu")
self.branch7x7dbl_2 = ConvBNLayer(
num_channels=channels_7x7,
num_filters=channels_7x7,
filter_size=(7, 1),
activation_layer=nn.ReLU)
self.branch7x7dbl_1 = ConvNormActivation(
in_channels=num_channels,
out_channels=channels_7x7,
kernel_size=1,
padding=0,
activation_layer=nn.ReLU)
self.branch7x7dbl_2 = ConvNormActivation(
in_channels=channels_7x7,
out_channels=channels_7x7,
kernel_size=(7, 1),
padding=(3, 0),
act="relu")
self.branch7x7dbl_3 = ConvBNLayer(
num_channels=channels_7x7,
num_filters=channels_7x7,
filter_size=(1, 7),
activation_layer=nn.ReLU)
self.branch7x7dbl_3 = ConvNormActivation(
in_channels=channels_7x7,
out_channels=channels_7x7,
kernel_size=(1, 7),
padding=(0, 3),
act="relu")
self.branch7x7dbl_4 = ConvBNLayer(
num_channels=channels_7x7,
num_filters=channels_7x7,
filter_size=(7, 1),
activation_layer=nn.ReLU)
self.branch7x7dbl_4 = ConvNormActivation(
in_channels=channels_7x7,
out_channels=channels_7x7,
kernel_size=(7, 1),
padding=(3, 0),
act="relu")
self.branch7x7dbl_5 = ConvBNLayer(
num_channels=channels_7x7,
num_filters=192,
filter_size=(1, 7),
activation_layer=nn.ReLU)
self.branch7x7dbl_5 = ConvNormActivation(
in_channels=channels_7x7,
out_channels=192,
kernel_size=(1, 7),
padding=(0, 3),
act="relu")
activation_layer=nn.ReLU)
self.branch_pool = AvgPool2D(
kernel_size=3, stride=1, padding=1, exclusive=False)
self.branch_pool_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=192,
filter_size=1,
act="relu")
self.branch_pool_conv = ConvNormActivation(
in_channels=num_channels,
out_channels=192,
kernel_size=1,
padding=0,
activation_layer=nn.ReLU)
def forward(self, x):
branch1x1 = self.branch1x1(x)
......@@ -296,40 +296,46 @@ class InceptionC(nn.Layer):
class InceptionD(nn.Layer):
def __init__(self, num_channels):
super().__init__()
self.branch3x3_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=192,
filter_size=1,
act="relu")
self.branch3x3_2 = ConvBNLayer(
num_channels=192,
num_filters=320,
filter_size=3,
self.branch3x3_1 = ConvNormActivation(
in_channels=num_channels,
out_channels=192,
kernel_size=1,
padding=0,
activation_layer=nn.ReLU)
self.branch3x3_2 = ConvNormActivation(
in_channels=192,
out_channels=320,
kernel_size=3,
stride=2,
act="relu")
self.branch7x7x3_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=192,
filter_size=1,
act="relu")
self.branch7x7x3_2 = ConvBNLayer(
num_channels=192,
num_filters=192,
filter_size=(1, 7),
padding=0,
activation_layer=nn.ReLU)
self.branch7x7x3_1 = ConvNormActivation(
in_channels=num_channels,
out_channels=192,
kernel_size=1,
padding=0,
activation_layer=nn.ReLU)
self.branch7x7x3_2 = ConvNormActivation(
in_channels=192,
out_channels=192,
kernel_size=(1, 7),
padding=(0, 3),
act="relu")
self.branch7x7x3_3 = ConvBNLayer(
num_channels=192,
num_filters=192,
filter_size=(7, 1),
activation_layer=nn.ReLU)
self.branch7x7x3_3 = ConvNormActivation(
in_channels=192,
out_channels=192,
kernel_size=(7, 1),
padding=(3, 0),
act="relu")
self.branch7x7x3_4 = ConvBNLayer(
num_channels=192,
num_filters=192,
filter_size=3,
activation_layer=nn.ReLU)
self.branch7x7x3_4 = ConvNormActivation(
in_channels=192,
out_channels=192,
kernel_size=3,
stride=2,
act="relu")
padding=0,
activation_layer=nn.ReLU)
self.branch_pool = MaxPool2D(kernel_size=3, stride=2)
def forward(self, x):
......@@ -350,59 +356,64 @@ class InceptionD(nn.Layer):
class InceptionE(nn.Layer):
def __init__(self, num_channels):
super().__init__()
self.branch1x1 = ConvBNLayer(
num_channels=num_channels,
num_filters=320,
filter_size=1,
act="relu")
self.branch3x3_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=384,
filter_size=1,
act="relu")
self.branch3x3_2a = ConvBNLayer(
num_channels=384,
num_filters=384,
filter_size=(1, 3),
self.branch1x1 = ConvNormActivation(
in_channels=num_channels,
out_channels=320,
kernel_size=1,
padding=0,
activation_layer=nn.ReLU)
self.branch3x3_1 = ConvNormActivation(
in_channels=num_channels,
out_channels=384,
kernel_size=1,
padding=0,
activation_layer=nn.ReLU)
self.branch3x3_2a = ConvNormActivation(
in_channels=384,
out_channels=384,
kernel_size=(1, 3),
padding=(0, 1),
act="relu")
self.branch3x3_2b = ConvBNLayer(
num_channels=384,
num_filters=384,
filter_size=(3, 1),
activation_layer=nn.ReLU)
self.branch3x3_2b = ConvNormActivation(
in_channels=384,
out_channels=384,
kernel_size=(3, 1),
padding=(1, 0),
act="relu")
self.branch3x3dbl_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=448,
filter_size=1,
act="relu")
self.branch3x3dbl_2 = ConvBNLayer(
num_channels=448,
num_filters=384,
filter_size=3,
activation_layer=nn.ReLU)
self.branch3x3dbl_1 = ConvNormActivation(
in_channels=num_channels,
out_channels=448,
kernel_size=1,
padding=0,
activation_layer=nn.ReLU)
self.branch3x3dbl_2 = ConvNormActivation(
in_channels=448,
out_channels=384,
kernel_size=3,
padding=1,
act="relu")
self.branch3x3dbl_3a = ConvBNLayer(
num_channels=384,
num_filters=384,
filter_size=(1, 3),
activation_layer=nn.ReLU)
self.branch3x3dbl_3a = ConvNormActivation(
in_channels=384,
out_channels=384,
kernel_size=(1, 3),
padding=(0, 1),
act="relu")
self.branch3x3dbl_3b = ConvBNLayer(
num_channels=384,
num_filters=384,
filter_size=(3, 1),
activation_layer=nn.ReLU)
self.branch3x3dbl_3b = ConvNormActivation(
in_channels=384,
out_channels=384,
kernel_size=(3, 1),
padding=(1, 0),
act="relu")
activation_layer=nn.ReLU)
self.branch_pool = AvgPool2D(
kernel_size=3, stride=1, padding=1, exclusive=False)
self.branch_pool_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=192,
filter_size=1,
act="relu")
self.branch_pool_conv = ConvNormActivation(
in_channels=num_channels,
out_channels=192,
kernel_size=1,
padding=0,
activation_layer=nn.ReLU)
def forward(self, x):
branch1x1 = self.branch1x1(x)
......
......@@ -16,59 +16,31 @@ import paddle
import paddle.nn as nn
from paddle.utils.download import get_weights_path_from_url
from ..ops import ConvNormActivation
__all__ = []
model_urls = {
'mobilenetv1_1.0':
('https://paddle-hapi.bj.bcebos.com/models/mobilenet_v1_x1.0.pdparams',
'42a154c2f26f86e7457d6daded114e8c')
('https://paddle-hapi.bj.bcebos.com/models/mobilenetv1_1.0.pdparams',
'3033ab1975b1670bef51545feb65fc45')
}
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
num_groups=1):
super(ConvBNLayer, self).__init__()
self._conv = nn.Conv2D(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
groups=num_groups,
bias_attr=False)
self._norm_layer = nn.BatchNorm2D(out_channels)
self._act = nn.ReLU()
def forward(self, x):
x = self._conv(x)
x = self._norm_layer(x)
x = self._act(x)
return x
class DepthwiseSeparable(nn.Layer):
def __init__(self, in_channels, out_channels1, out_channels2, num_groups,
stride, scale):
super(DepthwiseSeparable, self).__init__()
self._depthwise_conv = ConvBNLayer(
self._depthwise_conv = ConvNormActivation(
in_channels,
int(out_channels1 * scale),
kernel_size=3,
stride=stride,
padding=1,
num_groups=int(num_groups * scale))
groups=int(num_groups * scale))
self._pointwise_conv = ConvBNLayer(
self._pointwise_conv = ConvNormActivation(
int(out_channels1 * scale),
int(out_channels2 * scale),
kernel_size=1,
......@@ -94,9 +66,15 @@ class MobileNetV1(nn.Layer):
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import MobileNetV1
model = MobileNetV1()
x = paddle.rand([1, 3, 224, 224])
out = model(x)
print(out.shape)
"""
def __init__(self, scale=1.0, num_classes=1000, with_pool=True):
......@@ -106,7 +84,7 @@ class MobileNetV1(nn.Layer):
self.num_classes = num_classes
self.with_pool = with_pool
self.conv1 = ConvBNLayer(
self.conv1 = ConvNormActivation(
in_channels=3,
out_channels=int(32 * scale),
kernel_size=3,
......@@ -257,6 +235,7 @@ def mobilenet_v1(pretrained=False, scale=1.0, **kwargs):
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import mobilenet_v1
# build model
......@@ -266,7 +245,12 @@ def mobilenet_v1(pretrained=False, scale=1.0, **kwargs):
# model = mobilenet_v1(pretrained=True)
# build mobilenet v1 with scale=0.5
model = mobilenet_v1(scale=0.5)
model_scale = mobilenet_v1(scale=0.5)
x = paddle.rand([1, 3, 224, 224])
out = model(x)
print(out.shape)
"""
model = _mobilenet(
'mobilenetv1_' + str(scale), pretrained, scale=scale, **kwargs)
......
......@@ -17,6 +17,7 @@ import paddle.nn as nn
from paddle.utils.download import get_weights_path_from_url
from .utils import _make_divisible
from ..ops import ConvNormActivation
__all__ = []
......@@ -27,29 +28,6 @@ model_urls = {
}
class ConvBNReLU(nn.Sequential):
def __init__(self,
in_planes,
out_planes,
kernel_size=3,
stride=1,
groups=1,
norm_layer=nn.BatchNorm2D):
padding = (kernel_size - 1) // 2
super(ConvBNReLU, self).__init__(
nn.Conv2D(
in_planes,
out_planes,
kernel_size,
stride,
padding,
groups=groups,
bias_attr=False),
norm_layer(out_planes),
nn.ReLU6())
class InvertedResidual(nn.Layer):
def __init__(self,
inp,
......@@ -67,15 +45,20 @@ class InvertedResidual(nn.Layer):
layers = []
if expand_ratio != 1:
layers.append(
ConvBNReLU(
inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
ConvNormActivation(
inp,
hidden_dim,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=nn.ReLU6))
layers.extend([
ConvBNReLU(
ConvNormActivation(
hidden_dim,
hidden_dim,
stride=stride,
groups=hidden_dim,
norm_layer=norm_layer),
norm_layer=norm_layer,
activation_layer=nn.ReLU6),
nn.Conv2D(
hidden_dim, oup, 1, 1, 0, bias_attr=False),
norm_layer(oup),
......@@ -90,23 +73,30 @@ class InvertedResidual(nn.Layer):
class MobileNetV2(nn.Layer):
def __init__(self, scale=1.0, num_classes=1000, with_pool=True):
"""MobileNetV2 model from
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
"""MobileNetV2 model from
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
Args:
scale (float): scale of channels in each layer. Default: 1.0.
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
with_pool (bool): use pool before the last fc layer or not. Default: True.
Args:
scale (float): scale of channels in each layer. Default: 1.0.
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
with_pool (bool): use pool before the last fc layer or not. Default: True.
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import MobileNetV2
Examples:
.. code-block:: python
model = MobileNetV2()
from paddle.vision.models import MobileNetV2
x = paddle.rand([1, 3, 224, 224])
out = model(x)
print(out.shape)
"""
model = MobileNetV2()
"""
def __init__(self, scale=1.0, num_classes=1000, with_pool=True):
super(MobileNetV2, self).__init__()
self.num_classes = num_classes
self.with_pool = with_pool
......@@ -130,8 +120,12 @@ class MobileNetV2(nn.Layer):
self.last_channel = _make_divisible(last_channel * max(1.0, scale),
round_nearest)
features = [
ConvBNReLU(
3, input_channel, stride=2, norm_layer=norm_layer)
ConvNormActivation(
3,
input_channel,
stride=2,
norm_layer=norm_layer,
activation_layer=nn.ReLU6)
]
for t, c, n, s in inverted_residual_setting:
......@@ -148,11 +142,12 @@ class MobileNetV2(nn.Layer):
input_channel = output_channel
features.append(
ConvBNReLU(
ConvNormActivation(
input_channel,
self.last_channel,
kernel_size=1,
norm_layer=norm_layer))
norm_layer=norm_layer,
activation_layer=nn.ReLU6))
self.features = nn.Sequential(*features)
......@@ -199,6 +194,7 @@ def mobilenet_v2(pretrained=False, scale=1.0, **kwargs):
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import mobilenet_v2
# build model
......@@ -209,6 +205,11 @@ def mobilenet_v2(pretrained=False, scale=1.0, **kwargs):
# build mobilenet v2 with scale=0.5
model = mobilenet_v2(scale=0.5)
x = paddle.rand([1, 3, 224, 224])
out = model(x)
print(out.shape)
"""
model = _mobilenet(
'mobilenetv2_' + str(scale), pretrained, scale=scale, **kwargs)
......
......@@ -33,12 +33,30 @@ model_urls = {
'02f35f034ca3858e1e54d4036443c92d'),
'resnet152': ('https://paddle-hapi.bj.bcebos.com/models/resnet152.pdparams',
'7ad16a2f1e7333859ff986138630fd7a'),
'wide_resnet50_2':
('https://paddle-hapi.bj.bcebos.com/models/wide_resnet50_2.pdparams',
'0282f804d73debdab289bd9fea3fa6dc'),
'wide_resnet101_2':
('https://paddle-hapi.bj.bcebos.com/models/wide_resnet101_2.pdparams',
'd4360a2d23657f059216f5d5a1a9ac93'),
'resnext50_32x4d':
('https://paddle-hapi.bj.bcebos.com/models/resnext50_32x4d.pdparams',
'dc47483169be7d6f018fcbb7baf8775d'),
"resnext50_64x4d":
('https://paddle-hapi.bj.bcebos.com/models/resnext50_64x4d.pdparams',
'063d4b483e12b06388529450ad7576db'),
'resnext101_32x4d': (
'https://paddle-hapi.bj.bcebos.com/models/resnext101_32x4d.pdparams',
'967b090039f9de2c8d06fe994fb9095f'),
'resnext101_64x4d': (
'https://paddle-hapi.bj.bcebos.com/models/resnext101_64x4d.pdparams',
'98e04e7ca616a066699230d769d03008'),
'resnext152_32x4d': (
'https://paddle-hapi.bj.bcebos.com/models/resnext152_32x4d.pdparams',
'18ff0beee21f2efc99c4b31786107121'),
'resnext152_64x4d': (
'https://paddle-hapi.bj.bcebos.com/models/resnext152_64x4d.pdparams',
'77c4af00ca42c405fa7f841841959379'),
'wide_resnet50_2': (
'https://paddle-hapi.bj.bcebos.com/models/wide_resnet50_2.pdparams',
'0282f804d73debdab289bd9fea3fa6dc'),
'wide_resnet101_2': (
'https://paddle-hapi.bj.bcebos.com/models/wide_resnet101_2.pdparams',
'd4360a2d23657f059216f5d5a1a9ac93'),
}
......@@ -158,11 +176,12 @@ class ResNet(nn.Layer):
Args:
Block (BasicBlock|BottleneckBlock): block module of model.
depth (int): layers of resnet, default: 50.
width (int): base width of resnet, default: 64.
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
depth (int, optional): layers of resnet, Default: 50.
width (int, optional): base width per convolution group for each convolution block, Default: 64.
num_classes (int, optional): output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
with_pool (bool): use pool before the last fc layer or not. Default: True.
with_pool (bool, optional): use pool before the last fc layer or not. Default: True.
groups (int, optional): number of groups for each convolution block, Default: 1.
Examples:
.. code-block:: python
......@@ -171,16 +190,23 @@ class ResNet(nn.Layer):
from paddle.vision.models import ResNet
from paddle.vision.models.resnet import BottleneckBlock, BasicBlock
# build ResNet with 18 layers
resnet18 = ResNet(BasicBlock, 18)
# build ResNet with 50 layers
resnet50 = ResNet(BottleneckBlock, 50)
# build Wide ResNet model
wide_resnet50_2 = ResNet(BottleneckBlock, 50, width=64*2)
resnet18 = ResNet(BasicBlock, 18)
# build ResNeXt model
resnext50_32x4d = ResNet(BottleneckBlock, 50, width=4, groups=32)
x = paddle.rand([1, 3, 224, 224])
out = resnet18(x)
print(out.shape)
# [1, 1000]
"""
......@@ -189,7 +215,8 @@ class ResNet(nn.Layer):
depth=50,
width=64,
num_classes=1000,
with_pool=True):
with_pool=True,
groups=1):
super(ResNet, self).__init__()
layer_cfg = {
18: [2, 2, 2, 2],
......@@ -199,7 +226,7 @@ class ResNet(nn.Layer):
152: [3, 8, 36, 3]
}
layers = layer_cfg[depth]
self.groups = 1
self.groups = groups
self.base_width = width
self.num_classes = num_classes
self.with_pool = with_pool
......@@ -300,7 +327,7 @@ def resnet18(pretrained=False, **kwargs):
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
Examples:
.. code-block:: python
......@@ -318,6 +345,7 @@ def resnet18(pretrained=False, **kwargs):
out = model(x)
print(out.shape)
# [1, 1000]
"""
return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs)
......@@ -327,7 +355,7 @@ def resnet34(pretrained=False, **kwargs):
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
Examples:
.. code-block:: python
......@@ -345,6 +373,7 @@ def resnet34(pretrained=False, **kwargs):
out = model(x)
print(out.shape)
# [1, 1000]
"""
return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs)
......@@ -354,7 +383,7 @@ def resnet50(pretrained=False, **kwargs):
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
Examples:
.. code-block:: python
......@@ -372,6 +401,7 @@ def resnet50(pretrained=False, **kwargs):
out = model(x)
print(out.shape)
# [1, 1000]
"""
return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs)
......@@ -381,7 +411,7 @@ def resnet101(pretrained=False, **kwargs):
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
Examples:
.. code-block:: python
......@@ -399,6 +429,7 @@ def resnet101(pretrained=False, **kwargs):
out = model(x)
print(out.shape)
# [1, 1000]
"""
return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs)
......@@ -408,7 +439,7 @@ def resnet152(pretrained=False, **kwargs):
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
Examples:
.. code-block:: python
......@@ -426,16 +457,201 @@ def resnet152(pretrained=False, **kwargs):
out = model(x)
print(out.shape)
# [1, 1000]
"""
return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs)
def resnext50_32x4d(pretrained=False, **kwargs):
"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import resnext50_32x4d
# build model
model = resnext50_32x4d()
# build model and load imagenet pretrained weight
# model = resnext50_32x4d(pretrained=True)
x = paddle.rand([1, 3, 224, 224])
out = model(x)
print(out.shape)
# [1, 1000]
"""
kwargs['groups'] = 32
kwargs['width'] = 4
return _resnet('resnext50_32x4d', BottleneckBlock, 50, pretrained, **kwargs)
def resnext50_64x4d(pretrained=False, **kwargs):
"""ResNeXt-50 64x4d model from
`"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import resnext50_64x4d
# build model
model = resnext50_64x4d()
# build model and load imagenet pretrained weight
# model = resnext50_64x4d(pretrained=True)
x = paddle.rand([1, 3, 224, 224])
out = model(x)
print(out.shape)
# [1, 1000]
"""
kwargs['groups'] = 64
kwargs['width'] = 4
return _resnet('resnext50_64x4d', BottleneckBlock, 50, pretrained, **kwargs)
def resnext101_32x4d(pretrained=False, **kwargs):
"""ResNeXt-101 32x4d model from
`"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import resnext101_32x4d
# build model
model = resnext101_32x4d()
# build model and load imagenet pretrained weight
# model = resnext101_32x4d(pretrained=True)
x = paddle.rand([1, 3, 224, 224])
out = model(x)
print(out.shape)
# [1, 1000]
"""
kwargs['groups'] = 32
kwargs['width'] = 4
return _resnet('resnext101_32x4d', BottleneckBlock, 101, pretrained,
**kwargs)
def resnext101_64x4d(pretrained=False, **kwargs):
"""ResNeXt-101 64x4d model from
`"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import resnext101_64x4d
# build model
model = resnext101_64x4d()
# build model and load imagenet pretrained weight
# model = resnext101_64x4d(pretrained=True)
x = paddle.rand([1, 3, 224, 224])
out = model(x)
print(out.shape)
# [1, 1000]
"""
kwargs['groups'] = 64
kwargs['width'] = 4
return _resnet('resnext101_64x4d', BottleneckBlock, 101, pretrained,
**kwargs)
def resnext152_32x4d(pretrained=False, **kwargs):
"""ResNeXt-152 32x4d model from
`"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import resnext152_32x4d
# build model
model = resnext152_32x4d()
# build model and load imagenet pretrained weight
# model = resnext152_32x4d(pretrained=True)
x = paddle.rand([1, 3, 224, 224])
out = model(x)
print(out.shape)
# [1, 1000]
"""
kwargs['groups'] = 32
kwargs['width'] = 4
return _resnet('resnext152_32x4d', BottleneckBlock, 152, pretrained,
**kwargs)
def resnext152_64x4d(pretrained=False, **kwargs):
"""ResNeXt-152 64x4d model from
`"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import resnext152_64x4d
# build model
model = resnext152_64x4d()
# build model and load imagenet pretrained weight
# model = resnext152_64x4d(pretrained=True)
x = paddle.rand([1, 3, 224, 224])
out = model(x)
print(out.shape)
# [1, 1000]
"""
kwargs['groups'] = 64
kwargs['width'] = 4
return _resnet('resnext152_64x4d', BottleneckBlock, 152, pretrained,
**kwargs)
def wide_resnet50_2(pretrained=False, **kwargs):
"""Wide ResNet-50-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
Examples:
.. code-block:: python
......@@ -453,6 +669,7 @@ def wide_resnet50_2(pretrained=False, **kwargs):
out = model(x)
print(out.shape)
# [1, 1000]
"""
kwargs['width'] = 64 * 2
return _resnet('wide_resnet50_2', BottleneckBlock, 50, pretrained, **kwargs)
......@@ -463,7 +680,7 @@ def wide_resnet101_2(pretrained=False, **kwargs):
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
Examples:
.. code-block:: python
......@@ -481,6 +698,7 @@ def wide_resnet101_2(pretrained=False, **kwargs):
out = model(x)
print(out.shape)
# [1, 1000]
"""
kwargs['width'] = 64 * 2
return _resnet('wide_resnet101_2', BottleneckBlock, 101, pretrained,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.fluid.param_attr import ParamAttr
from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Linear, MaxPool2D
from paddle.nn.initializer import Uniform
from paddle.utils.download import get_weights_path_from_url
__all__ = []
model_urls = {
'resnext50_32x4d':
('https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt50_32x4d_pretrained.pdparams',
'bf04add2f7fd22efcbe91511bcd1eebe'),
"resnext50_64x4d":
('https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt50_64x4d_pretrained.pdparams',
'46307df0e2d6d41d3b1c1d22b00abc69'),
'resnext101_32x4d':
('https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt101_32x4d_pretrained.pdparams',
'078ca145b3bea964ba0544303a43c36d'),
'resnext101_64x4d':
('https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt101_64x4d_pretrained.pdparams',
'4edc0eb32d3cc5d80eff7cab32cd5c64'),
'resnext152_32x4d':
('https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt152_32x4d_pretrained.pdparams',
'7971cc994d459af167c502366f866378'),
'resnext152_64x4d':
('https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt152_64x4d_pretrained.pdparams',
'836943f03709efec364d486c57d132de'),
}
class ConvBNLayer(nn.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
bias_attr=False)
self._batch_norm = BatchNorm(num_filters, act=act)
def forward(self, inputs):
x = self._conv(inputs)
x = self._batch_norm(x)
return x
class BottleneckBlock(nn.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
cardinality,
shortcut=True):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act='relu')
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
groups=cardinality,
stride=stride,
act='relu')
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 2 if cardinality == 32 else num_filters,
filter_size=1,
act=None)
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * 2
if cardinality == 32 else num_filters,
filter_size=1,
stride=stride)
self.shortcut = shortcut
def forward(self, inputs):
x = self.conv0(inputs)
conv1 = self.conv1(x)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
x = paddle.add(x=short, y=conv2)
x = F.relu(x)
return x
class ResNeXt(nn.Layer):
"""ResNeXt model from
`"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
depth (int, optional): depth of resnext. Default: 50.
cardinality (int, optional): cardinality of resnext. Default: 32.
num_classes (int, optional): output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
with_pool (bool, optional): use pool before the last fc layer or not. Default: True.
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import ResNeXt
resnext50_32x4d = ResNeXt(depth=50, cardinality=32)
"""
def __init__(self,
depth=50,
cardinality=32,
num_classes=1000,
with_pool=True):
super(ResNeXt, self).__init__()
self.depth = depth
self.cardinality = cardinality
self.num_classes = num_classes
self.with_pool = with_pool
supported_depth = [50, 101, 152]
assert depth in supported_depth, \
"supported layers are {} but input layer is {}".format(
supported_depth, depth)
supported_cardinality = [32, 64]
assert cardinality in supported_cardinality, \
"supported cardinality is {} but input cardinality is {}" \
.format(supported_cardinality, cardinality)
layer_cfg = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}
layers = layer_cfg[depth]
num_channels = [64, 256, 512, 1024]
num_filters = [128, 256, 512,
1024] if cardinality == 32 else [256, 512, 1024, 2048]
self.conv = ConvBNLayer(
num_channels=3, num_filters=64, filter_size=7, stride=2, act='relu')
self.pool2d_max = MaxPool2D(kernel_size=3, stride=2, padding=1)
self.block_list = []
for block in range(len(layers)):
shortcut = False
for i in range(layers[block]):
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
num_channels=num_channels[block] if i == 0 else
num_filters[block] * int(64 // self.cardinality),
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
cardinality=self.cardinality,
shortcut=shortcut))
self.block_list.append(bottleneck_block)
shortcut = True
if with_pool:
self.pool2d_avg = AdaptiveAvgPool2D(1)
if num_classes > 0:
self.pool2d_avg_channels = num_channels[-1] * 2
stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0)
self.out = Linear(
self.pool2d_avg_channels,
num_classes,
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
def forward(self, inputs):
with paddle.static.amp.fp16_guard():
x = self.conv(inputs)
x = self.pool2d_max(x)
for block in self.block_list:
x = block(x)
if self.with_pool:
x = self.pool2d_avg(x)
if self.num_classes > 0:
x = paddle.reshape(x, shape=[-1, self.pool2d_avg_channels])
x = self.out(x)
return x
def _resnext(arch, depth, cardinality, pretrained, **kwargs):
model = ResNeXt(depth=depth, cardinality=cardinality, **kwargs)
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])
param = paddle.load(weight_path)
model.set_dict(param)
return model
def resnext50_32x4d(pretrained=False, **kwargs):
"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import resnext50_32x4d
# build model
model = resnext50_32x4d()
# build model and load imagenet pretrained weight
# model = resnext50_32x4d(pretrained=True)
"""
return _resnext('resnext50_32x4d', 50, 32, pretrained, **kwargs)
def resnext50_64x4d(pretrained=False, **kwargs):
"""ResNeXt-50 64x4d model from
`"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import resnext50_64x4d
# build model
model = resnext50_64x4d()
# build model and load imagenet pretrained weight
# model = resnext50_64x4d(pretrained=True)
"""
return _resnext('resnext50_64x4d', 50, 64, pretrained, **kwargs)
def resnext101_32x4d(pretrained=False, **kwargs):
"""ResNeXt-101 32x4d model from
`"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import resnext101_32x4d
# build model
model = resnext101_32x4d()
# build model and load imagenet pretrained weight
# model = resnext101_32x4d(pretrained=True)
"""
return _resnext('resnext101_32x4d', 101, 32, pretrained, **kwargs)
def resnext101_64x4d(pretrained=False, **kwargs):
"""ResNeXt-101 64x4d model from
`"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import resnext101_64x4d
# build model
model = resnext101_64x4d()
# build model and load imagenet pretrained weight
# model = resnext101_64x4d(pretrained=True)
"""
return _resnext('resnext101_64x4d', 101, 64, pretrained, **kwargs)
def resnext152_32x4d(pretrained=False, **kwargs):
"""ResNeXt-152 32x4d model from
`"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import resnext152_32x4d
# build model
model = resnext152_32x4d()
# build model and load imagenet pretrained weight
# model = resnext152_32x4d(pretrained=True)
"""
return _resnext('resnext152_32x4d', 152, 32, pretrained, **kwargs)
def resnext152_64x4d(pretrained=False, **kwargs):
"""ResNeXt-152 64x4d model from
`"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import resnext152_64x4d
# build model
model = resnext152_64x4d()
# build model and load imagenet pretrained weight
# model = resnext152_64x4d(pretrained=True)
"""
return _resnext('resnext152_64x4d', 152, 64, pretrained, **kwargs)
......@@ -18,37 +18,50 @@ from __future__ import print_function
import paddle
import paddle.nn as nn
from paddle.fluid.param_attr import ParamAttr
from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Linear, MaxPool2D
from paddle.nn import AdaptiveAvgPool2D, Linear, MaxPool2D
from paddle.utils.download import get_weights_path_from_url
from ..ops import ConvNormActivation
__all__ = []
model_urls = {
"shufflenet_v2_x0_25": (
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x0_25_pretrained.pdparams",
"e753404cbd95027759c5f56ecd6c9c4b", ),
"https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_x0_25.pdparams",
"1e509b4c140eeb096bb16e214796d03b", ),
"shufflenet_v2_x0_33": (
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x0_33_pretrained.pdparams",
"776e3cf9a4923abdfce789c45b8fe1f2", ),
"https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_x0_33.pdparams",
"3d7b3ab0eaa5c0927ff1026d31b729bd", ),
"shufflenet_v2_x0_5": (
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x0_5_pretrained.pdparams",
"e3649cf531566917e2969487d2bc6b60", ),
"https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_x0_5.pdparams",
"5e5cee182a7793c4e4c73949b1a71bd4", ),
"shufflenet_v2_x1_0": (
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x1_0_pretrained.pdparams",
"7821c348ea34e58847c43a08a4ac0bdf", ),
"https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_x1_0.pdparams",
"122d42478b9e81eb49f8a9ede327b1a4", ),
"shufflenet_v2_x1_5": (
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x1_5_pretrained.pdparams",
"93a07fa557ab2d8803550f39e5b6c391", ),
"https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_x1_5.pdparams",
"faced5827380d73531d0ee027c67826d", ),
"shufflenet_v2_x2_0": (
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x2_0_pretrained.pdparams",
"4ab1f622fd0d341e0f84b4e057797563", ),
"https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_x2_0.pdparams",
"cd3dddcd8305e7bcd8ad14d1c69a5784", ),
"shufflenet_v2_swish": (
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_swish_pretrained.pdparams",
"daff38b3df1b3748fccbb13cfdf02519", ),
"https://paddle-hapi.bj.bcebos.com/models/shufflenet_v2_swish.pdparams",
"adde0aa3b023e5b0c94a68be1c394b84", ),
}
def create_activation_layer(act):
if act == "swish":
return nn.Swish
elif act == "relu":
return nn.ReLU
elif act is None:
return None
else:
raise RuntimeError(
"The activation function is not supported: {}".format(act))
def channel_shuffle(x, groups):
batch_size, num_channels, height, width = x.shape[0:4]
channels_per_group = num_channels // groups
......@@ -65,61 +78,37 @@ def channel_shuffle(x, groups):
return x
class ConvBNLayer(nn.Layer):
class InvertedResidual(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups=1,
act=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
bias_attr=False, )
self._batch_norm = BatchNorm(out_channels, act=act)
def forward(self, inputs):
x = self._conv(inputs)
x = self._batch_norm(x)
return x
class InvertedResidual(nn.Layer):
def __init__(self, in_channels, out_channels, stride, act="relu"):
activation_layer=nn.ReLU):
super(InvertedResidual, self).__init__()
self._conv_pw = ConvBNLayer(
self._conv_pw = ConvNormActivation(
in_channels=in_channels // 2,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
padding=0,
groups=1,
act=act)
self._conv_dw = ConvBNLayer(
activation_layer=activation_layer)
self._conv_dw = ConvNormActivation(
in_channels=out_channels // 2,
out_channels=out_channels // 2,
kernel_size=3,
stride=stride,
padding=1,
groups=out_channels // 2,
act=None)
self._conv_linear = ConvBNLayer(
activation_layer=None)
self._conv_linear = ConvNormActivation(
in_channels=out_channels // 2,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
padding=0,
groups=1,
act=act)
activation_layer=activation_layer)
def forward(self, inputs):
x1, x2 = paddle.split(
......@@ -134,51 +123,55 @@ class InvertedResidual(nn.Layer):
class InvertedResidualDS(nn.Layer):
def __init__(self, in_channels, out_channels, stride, act="relu"):
def __init__(self,
in_channels,
out_channels,
stride,
activation_layer=nn.ReLU):
super(InvertedResidualDS, self).__init__()
# branch1
self._conv_dw_1 = ConvBNLayer(
self._conv_dw_1 = ConvNormActivation(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=stride,
padding=1,
groups=in_channels,
act=None)
self._conv_linear_1 = ConvBNLayer(
activation_layer=None)
self._conv_linear_1 = ConvNormActivation(
in_channels=in_channels,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
padding=0,
groups=1,
act=act)
activation_layer=activation_layer)
# branch2
self._conv_pw_2 = ConvBNLayer(
self._conv_pw_2 = ConvNormActivation(
in_channels=in_channels,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
padding=0,
groups=1,
act=act)
self._conv_dw_2 = ConvBNLayer(
activation_layer=activation_layer)
self._conv_dw_2 = ConvNormActivation(
in_channels=out_channels // 2,
out_channels=out_channels // 2,
kernel_size=3,
stride=stride,
padding=1,
groups=out_channels // 2,
act=None)
self._conv_linear_2 = ConvBNLayer(
activation_layer=None)
self._conv_linear_2 = ConvNormActivation(
in_channels=out_channels // 2,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
padding=0,
groups=1,
act=act)
activation_layer=activation_layer)
def forward(self, inputs):
x1 = self._conv_dw_1(inputs)
......@@ -221,6 +214,7 @@ class ShuffleNetV2(nn.Layer):
self.num_classes = num_classes
self.with_pool = with_pool
stage_repeats = [4, 8, 4]
activation_layer = create_activation_layer(act)
if scale == 0.25:
stage_out_channels = [-1, 24, 24, 48, 96, 512]
......@@ -238,13 +232,13 @@ class ShuffleNetV2(nn.Layer):
raise NotImplementedError("This scale size:[" + str(scale) +
"] is not implemented!")
# 1. conv1
self._conv1 = ConvBNLayer(
self._conv1 = ConvNormActivation(
in_channels=3,
out_channels=stage_out_channels[1],
kernel_size=3,
stride=2,
padding=1,
act=act)
activation_layer=activation_layer)
self._max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1)
# 2. bottleneck sequences
......@@ -257,7 +251,7 @@ class ShuffleNetV2(nn.Layer):
in_channels=stage_out_channels[stage_id + 1],
out_channels=stage_out_channels[stage_id + 2],
stride=2,
act=act),
activation_layer=activation_layer),
name=str(stage_id + 2) + "_" + str(i + 1))
else:
block = self.add_sublayer(
......@@ -265,17 +259,17 @@ class ShuffleNetV2(nn.Layer):
in_channels=stage_out_channels[stage_id + 2],
out_channels=stage_out_channels[stage_id + 2],
stride=1,
act=act),
activation_layer=activation_layer),
name=str(stage_id + 2) + "_" + str(i + 1))
self._block_list.append(block)
# 3. last_conv
self._last_conv = ConvBNLayer(
self._last_conv = ConvNormActivation(
in_channels=stage_out_channels[-2],
out_channels=stage_out_channels[-1],
kernel_size=1,
stride=1,
padding=0,
act=act)
activation_layer=activation_layer)
# 4. pool
if with_pool:
self._pool2d_avg = AdaptiveAvgPool2D(1)
......
......@@ -1335,13 +1335,13 @@ class ConvNormActivation(Sequential):
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the Convolution-Normalzation-Activation block
kernel_size: (int, optional): Size of the convolving kernel. Default: 3
stride (int, optional): Stride of the convolution. Default: 1
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None,
kernel_size: (int|list|tuple, optional): Size of the convolving kernel. Default: 3
stride (int|list|tuple, optional): Stride of the convolution. Default: 1
padding (int|str|tuple|list, optional): Padding added to all four sides of the input. Default: None,
in wich case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
norm_layer (Callable[..., paddle.nn.Layer], optional): Norm layer that will be stacked on top of the convolutiuon layer.
If ``None`` this layer wont be used. Default: ``paddle.nn.BatchNorm2d``
If ``None`` this layer wont be used. Default: ``paddle.nn.BatchNorm2D``
activation_layer (Callable[..., paddle.nn.Layer], optional): Activation function which will be stacked on top of the normalization
layer (if not ``None``), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``paddle.nn.ReLU``
dilation (int): Spacing between kernel elements. Default: 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册