diff --git a/ppcls/arch/backbone/legendary_models/__init__.py b/ppcls/arch/backbone/legendary_models/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..1f837dac833746b8e87bb5f180ef32a16dbb1ad9 100644 --- a/ppcls/arch/backbone/legendary_models/__init__.py +++ b/ppcls/arch/backbone/legendary_models/__init__.py @@ -0,0 +1,6 @@ +from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNet18_vd, ResNet34_vd, ResNet50_vd, ResNet101_vd, ResNet152_vd +from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W64_C +from .mobilenet_v1 import MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x0_75, MobileNetV1 +from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25 +from .inception_v3 import InceptionV3 +from .vgg import VGG11, VGG13, VGG16, VGG19 diff --git a/ppcls/arch/backbone/legendary_models/hrnet.py b/ppcls/arch/backbone/legendary_models/hrnet.py index 8fe291e135eac46b04b4e86eb7d59f769e4213e2..51ad4e4f51b7104de30962d72849c9e032229e67 100644 --- a/ppcls/arch/backbone/legendary_models/hrnet.py +++ b/ppcls/arch/backbone/legendary_models/hrnet.py @@ -24,29 +24,40 @@ from paddle.nn.functional import upsample from paddle.nn.initializer import Uniform from ppcls.arch.backbone.base.theseus_layer import TheseusLayer, Identity +from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url MODEL_URLS = { - "HRNet_W18_C": "", - "HRNet_W30_C": "", - "HRNet_W32_C": "", - "HRNet_W40_C": "", - "HRNet_W44_C": "", - "HRNet_W48_C": "", - "HRNet_W60_C": "", - "HRNet_W64_C": "", - "SE_HRNet_W18_C": "", - "SE_HRNet_W30_C": "", - "SE_HRNet_W32_C": "", - "SE_HRNet_W40_C": "", - "SE_HRNet_W44_C": "", - "SE_HRNet_W48_C": "", - "SE_HRNet_W60_C": "", - "SE_HRNet_W64_C": "", + "HRNet_W18_C": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W18_C_pretrained.pdparams", + "HRNet_W30_C": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W30_C_pretrained.pdparams", + "HRNet_W32_C": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W32_C_pretrained.pdparams", + "HRNet_W40_C": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W40_C_pretrained.pdparams", + "HRNet_W44_C": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W44_C_pretrained.pdparams", + "HRNet_W48_C": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W48_C_pretrained.pdparams", + "HRNet_W64_C": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W64_C_pretrained.pdparams" } __all__ = list(MODEL_URLS.keys()) +def _create_act(act): + if act == "hardswish": + return nn.Hardswish() + elif act == "relu": + return nn.ReLU() + elif act is None: + return Identity() + else: + raise RuntimeError( + "The activation function is not supported: {}".format(act)) + + class ConvBNLayer(TheseusLayer): def __init__(self, num_channels, @@ -55,7 +66,7 @@ class ConvBNLayer(TheseusLayer): stride=1, groups=1, act="relu"): - super(ConvBNLayer, self).__init__() + super().__init__() self.conv = nn.Conv2D( in_channels=num_channels, @@ -65,10 +76,8 @@ class ConvBNLayer(TheseusLayer): padding=(filter_size - 1) // 2, groups=groups, bias_attr=False) - self.bn = nn.BatchNorm( - num_filters, - act=None) - self.act = create_act(act) + self.bn = nn.BatchNorm(num_filters, act=None) + self.act = _create_act(act) def forward(self, x): x = self.conv(x) @@ -77,18 +86,6 @@ class ConvBNLayer(TheseusLayer): return x -def create_act(act): - if act == 'hardswish': - return nn.Hardswish() - elif act == 'relu': - return nn.ReLU() - elif act is None: - return Identity() - else: - raise RuntimeError( - 'The activation function is not supported: {}'.format(act)) - - class BottleneckBlock(TheseusLayer): def __init__(self, num_channels, @@ -96,7 +93,7 @@ class BottleneckBlock(TheseusLayer): has_se, stride=1, downsample=False): - super(BottleneckBlock, self).__init__() + super().__init__() self.has_se = has_se self.downsample = downsample @@ -147,11 +144,8 @@ class BottleneckBlock(TheseusLayer): class BasicBlock(nn.Layer): - def __init__(self, - num_channels, - num_filters, - has_se=False): - super(BasicBlock, self).__init__() + def __init__(self, num_channels, num_filters, has_se=False): + super().__init__() self.has_se = has_se @@ -190,9 +184,9 @@ class BasicBlock(nn.Layer): class SELayer(TheseusLayer): def __init__(self, num_channels, num_filters, reduction_ratio): - super(SELayer, self).__init__() + super().__init__() - self.pool2d_gap = nn.AdaptiveAvgPool2D(1) + self.avg_pool = nn.AdaptiveAvgPool2D(1) self._num_channels = num_channels @@ -201,8 +195,7 @@ class SELayer(TheseusLayer): self.fc_squeeze = nn.Linear( num_channels, med_ch, - weight_attr=ParamAttr( - initializer=Uniform(-stdv, stdv))) + weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv))) self.relu = nn.ReLU() stdv = 1.0 / math.sqrt(med_ch * 1.0) self.fc_excitation = nn.Linear( @@ -213,7 +206,7 @@ class SELayer(TheseusLayer): def forward(self, x, res_dict=None): residual = x - x = self.pool2d_gap(x) + x = self.avg_pool(x) x = paddle.squeeze(x, axis=[2, 3]) x = self.fc_squeeze(x) x = self.relu(x) @@ -225,11 +218,8 @@ class SELayer(TheseusLayer): class Stage(TheseusLayer): - def __init__(self, - num_modules, - num_filters, - has_se=False): - super(Stage, self).__init__() + def __init__(self, num_modules, num_filters, has_se=False): + super().__init__() self._num_modules = num_modules @@ -237,8 +227,7 @@ class Stage(TheseusLayer): for i in range(num_modules): self.stage_func_list.append( HighResolutionModule( - num_filters=num_filters, - has_se=has_se)) + num_filters=num_filters, has_se=has_se)) def forward(self, x, res_dict=None): x = x @@ -248,10 +237,8 @@ class Stage(TheseusLayer): class HighResolutionModule(TheseusLayer): - def __init__(self, - num_filters, - has_se=False): - super(HighResolutionModule, self).__init__() + def __init__(self, num_filters, has_se=False): + super().__init__() self.basic_block_list = nn.LayerList() @@ -261,11 +248,11 @@ class HighResolutionModule(TheseusLayer): BasicBlock( num_channels=num_filters[i], num_filters=num_filters[i], - has_se=has_se) for j in range(4)])) + has_se=has_se) for j in range(4) + ])) self.fuse_func = FuseLayers( - in_channels=num_filters, - out_channels=num_filters) + in_channels=num_filters, out_channels=num_filters) def forward(self, x, res_dict=None): out = [] @@ -279,10 +266,8 @@ class HighResolutionModule(TheseusLayer): class FuseLayers(TheseusLayer): - def __init__(self, - in_channels, - out_channels): - super(FuseLayers, self).__init__() + def __init__(self, in_channels, out_channels): + super().__init__() self._actual_ch = len(in_channels) self._in_channels = in_channels @@ -352,7 +337,7 @@ class LastClsOut(TheseusLayer): num_channel_list, has_se, num_filters_list=[32, 64, 128, 256]): - super(LastClsOut, self).__init__() + super().__init__() self.func_list = nn.LayerList() for idx in range(len(num_channel_list)): @@ -378,9 +363,12 @@ class HRNet(TheseusLayer): width: int=18. Base channel number of HRNet. has_se: bool=False. If 'True', add se module to HRNet. class_num: int=1000. Output num of last fc layer. + Returns: + model: nn.Layer. Specific HRNet model depends on args. """ + def __init__(self, width=18, has_se=False, class_num=1000): - super(HRNet, self).__init__() + super().__init__() self.width = width self.has_se = has_se @@ -388,21 +376,23 @@ class HRNet(TheseusLayer): channels_2 = [self.width, self.width * 2] channels_3 = [self.width, self.width * 2, self.width * 4] - channels_4 = [self.width, self.width * 2, self.width * 4, self.width * 8] + channels_4 = [ + self.width, self.width * 2, self.width * 4, self.width * 8 + ] self.conv_layer1_1 = ConvBNLayer( num_channels=3, num_filters=64, filter_size=3, stride=2, - act='relu') + act="relu") self.conv_layer1_2 = ConvBNLayer( num_channels=64, num_filters=64, filter_size=3, stride=2, - act='relu') + act="relu") self.layer1 = nn.Sequential(*[ BottleneckBlock( @@ -410,48 +400,33 @@ class HRNet(TheseusLayer): num_filters=64, has_se=has_se, stride=1, - downsample=True if i == 0 else False) - for i in range(4) + downsample=True if i == 0 else False) for i in range(4) ]) self.conv_tr1_1 = ConvBNLayer( - num_channels=256, - num_filters=width, - filter_size=3) + num_channels=256, num_filters=width, filter_size=3) self.conv_tr1_2 = ConvBNLayer( - num_channels=256, - num_filters=width * 2, - filter_size=3, - stride=2 - ) + num_channels=256, num_filters=width * 2, filter_size=3, stride=2) self.st2 = Stage( - num_modules=1, - num_filters=channels_2, - has_se=self.has_se) + num_modules=1, num_filters=channels_2, has_se=self.has_se) self.conv_tr2 = ConvBNLayer( num_channels=width * 2, num_filters=width * 4, filter_size=3, - stride=2 - ) + stride=2) self.st3 = Stage( - num_modules=4, - num_filters=channels_3, - has_se=self.has_se) + num_modules=4, num_filters=channels_3, has_se=self.has_se) self.conv_tr3 = ConvBNLayer( num_channels=width * 4, num_filters=width * 8, filter_size=3, - stride=2 - ) + stride=2) self.st4 = Stage( - num_modules=3, - num_filters=channels_4, - has_se=self.has_se) + num_modules=3, num_filters=channels_4, has_se=self.has_se) # classification num_filters_list = [32, 64, 128, 256] @@ -464,17 +439,14 @@ class HRNet(TheseusLayer): self.cls_head_conv_list = nn.LayerList() for idx in range(3): self.cls_head_conv_list.append( - ConvBNLayer( - num_channels=num_filters_list[idx] * 4, - num_filters=last_num_filters[idx], - filter_size=3, - stride=2)) + ConvBNLayer( + num_channels=num_filters_list[idx] * 4, + num_filters=last_num_filters[idx], + filter_size=3, + stride=2)) self.conv_last = ConvBNLayer( - num_channels=1024, - num_filters=2048, - filter_size=1, - stride=1) + num_channels=1024, num_filters=2048, filter_size=1, stride=1) self.avg_pool = nn.AdaptiveAvgPool2D(1) @@ -516,81 +488,254 @@ class HRNet(TheseusLayer): return y -def HRNet_W18_C(**args): - model = HRNet(width=18, **args) +def _load_pretrained(pretrained, model, model_url, use_ssld): + if pretrained is False: + pass + elif pretrained is True: + load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld) + elif isinstance(pretrained, str): + load_dygraph_pretrain(model, pretrained) + else: + raise RuntimeError( + "pretrained type is not available. Please use `string` or `boolean` type." + ) + + +def HRNet_W18_C(pretrained=False, use_ssld=False, **kwargs): + """ + HRNet_W18_C + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `HRNet_W18_C` model depends on args. + """ + model = HRNet(width=18, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["HRNet_W18_C"], use_ssld) return model -def HRNet_W30_C(**args): - model = HRNet(width=30, **args) +def HRNet_W30_C(pretrained=False, use_ssld=False, **kwargs): + """ + HRNet_W30_C + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `HRNet_W30_C` model depends on args. + """ + model = HRNet(width=30, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["HRNet_W30_C"], use_ssld) return model -def HRNet_W32_C(**args): - model = HRNet(width=32, **args) +def HRNet_W32_C(pretrained=False, use_ssld=False, **kwargs): + """ + HRNet_W32_C + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `HRNet_W32_C` model depends on args. + """ + model = HRNet(width=32, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["HRNet_W32_C"], use_ssld) return model -def HRNet_W40_C(**args): - model = HRNet(width=40, **args) +def HRNet_W40_C(pretrained=False, use_ssld=False, **kwargs): + """ + HRNet_W40_C + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `HRNet_W40_C` model depends on args. + """ + model = HRNet(width=40, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["HRNet_W40_C"], use_ssld) return model -def HRNet_W44_C(**args): - model = HRNet(width=44, **args) +def HRNet_W44_C(pretrained=False, use_ssld=False, **kwargs): + """ + HRNet_W44_C + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `HRNet_W44_C` model depends on args. + """ + model = HRNet(width=44, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["HRNet_W44_C"], use_ssld) return model -def HRNet_W48_C(**args): - model = HRNet(width=48, **args) +def HRNet_W48_C(pretrained=False, use_ssld=False, **kwargs): + """ + HRNet_W48_C + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `HRNet_W48_C` model depends on args. + """ + model = HRNet(width=48, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["HRNet_W48_C"], use_ssld) return model -def HRNet_W60_C(**args): - model = HRNet(width=60, **args) +def HRNet_W60_C(pretrained=False, use_ssld=False, **kwargs): + """ + HRNet_W60_C + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `HRNet_W60_C` model depends on args. + """ + model = HRNet(width=60, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["HRNet_W60_C"], use_ssld) return model -def HRNet_W64_C(**args): - model = HRNet(width=64, **args) +def HRNet_W64_C(pretrained=False, use_ssld=False, **kwargs): + """ + HRNet_W64_C + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `HRNet_W64_C` model depends on args. + """ + model = HRNet(width=64, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["HRNet_W64_C"], use_ssld) return model -def SE_HRNet_W18_C(**args): - model = HRNet(width=18, has_se=True, **args) +def SE_HRNet_W18_C(pretrained=False, use_ssld=False, **kwargs): + """ + SE_HRNet_W18_C + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `SE_HRNet_W18_C` model depends on args. + """ + model = HRNet(width=18, has_se=True, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W18_C"], use_ssld) return model -def SE_HRNet_W30_C(**args): - model = HRNet(width=30, has_se=True, **args) +def SE_HRNet_W30_C(pretrained=False, use_ssld=False, **kwargs): + """ + SE_HRNet_W30_C + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `SE_HRNet_W30_C` model depends on args. + """ + model = HRNet(width=30, has_se=True, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W30_C"], use_ssld) return model -def SE_HRNet_W32_C(**args): - model = HRNet(width=32, has_se=True, **args) +def SE_HRNet_W32_C(pretrained=False, use_ssld=False, **kwargs): + """ + SE_HRNet_W32_C + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `SE_HRNet_W32_C` model depends on args. + """ + model = HRNet(width=32, has_se=True, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W32_C"], use_ssld) return model -def SE_HRNet_W40_C(**args): - model = HRNet(width=40, has_se=True, **args) +def SE_HRNet_W40_C(pretrained=False, use_ssld=False, **kwargs): + """ + SE_HRNet_W40_C + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `SE_HRNet_W40_C` model depends on args. + """ + model = HRNet(width=40, has_se=True, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W40_C"], use_ssld) return model -def SE_HRNet_W44_C(**args): - model = HRNet(width=44, has_se=True, **args) +def SE_HRNet_W44_C(pretrained=False, use_ssld=False, **kwargs): + """ + SE_HRNet_W44_C + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `SE_HRNet_W44_C` model depends on args. + """ + model = HRNet(width=44, has_se=True, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W44_C"], use_ssld) return model -def SE_HRNet_W48_C(**args): - model = HRNet(width=48, has_se=True, **args) +def SE_HRNet_W48_C(pretrained=False, use_ssld=False, **kwargs): + """ + SE_HRNet_W48_C + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `SE_HRNet_W48_C` model depends on args. + """ + model = HRNet(width=48, has_se=True, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W48_C"], use_ssld) return model -def SE_HRNet_W60_C(**args): - model = HRNet(width=60, has_se=True, **args) +def SE_HRNet_W60_C(pretrained=False, use_ssld=False, **kwargs): + """ + SE_HRNet_W60_C + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `SE_HRNet_W60_C` model depends on args. + """ + model = HRNet(width=60, has_se=True, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W60_C"], use_ssld) return model -def SE_HRNet_W64_C(**args): - model = HRNet(width=64, has_se=True, **args) +def SE_HRNet_W64_C(pretrained=False, use_ssld=False, **kwargs): + """ + SE_HRNet_W64_C + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `SE_HRNet_W64_C` model depends on args. + """ + model = HRNet(width=64, has_se=True, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W64_C"], use_ssld) return model diff --git a/ppcls/arch/backbone/legendary_models/inception_v3.py b/ppcls/arch/backbone/legendary_models/inception_v3.py index f06c265fe1820ba3f735b3752740c5eee18bd419..b6403bbe6af0ffa15eab6a6548e30398a8054a2d 100644 --- a/ppcls/arch/backbone/legendary_models/inception_v3.py +++ b/ppcls/arch/backbone/legendary_models/inception_v3.py @@ -13,39 +13,37 @@ # limitations under the License. from __future__ import absolute_import, division, print_function - +import math import paddle from paddle import ParamAttr import paddle.nn as nn from paddle.nn import Conv2D, BatchNorm, Linear, Dropout from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D from paddle.nn.initializer import Uniform -import math from ppcls.arch.backbone.base.theseus_layer import TheseusLayer from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url - MODEL_URLS = { - "InceptionV3": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/InceptionV3_pretrained.pdparams", + "InceptionV3": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/InceptionV3_pretrained.pdparams" } - __all__ = MODEL_URLS.keys() - ''' InceptionV3 config: dict. key: inception blocks of InceptionV3. values: conv num in different blocks. ''' NET_CONFIG = { - 'inception_a':[[192, 256, 288], [32, 64, 64]], - 'inception_b':[288], - 'inception_c':[[768, 768, 768, 768], [128, 160, 160, 192]], - 'inception_d':[768], - 'inception_e':[1280,2048] + "inception_a": [[192, 256, 288], [32, 64, 64]], + "inception_b": [288], + "inception_c": [[768, 768, 768, 768], [128, 160, 160, 192]], + "inception_d": [768], + "inception_e": [1280, 2048] } + class ConvBNLayer(TheseusLayer): def __init__(self, num_channels, @@ -55,7 +53,7 @@ class ConvBNLayer(TheseusLayer): padding=0, groups=1, act="relu"): - super(ConvBNLayer, self).__init__() + super().__init__() self.act = act self.conv = Conv2D( in_channels=num_channels, @@ -65,92 +63,100 @@ class ConvBNLayer(TheseusLayer): padding=padding, groups=groups, bias_attr=False) - self.batch_norm = BatchNorm( - num_filters) + self.bn = BatchNorm(num_filters) self.relu = nn.ReLU() def forward(self, x): x = self.conv(x) - x = self.batch_norm(x) + x = self.bn(x) if self.act: x = self.relu(x) return x + class InceptionStem(TheseusLayer): def __init__(self): - super(InceptionStem, self).__init__() - self.conv_1a_3x3 = ConvBNLayer(num_channels=3, - num_filters=32, - filter_size=3, - stride=2, - act="relu") - self.conv_2a_3x3 = ConvBNLayer(num_channels=32, - num_filters=32, - filter_size=3, - stride=1, - act="relu") - self.conv_2b_3x3 = ConvBNLayer(num_channels=32, - num_filters=64, - filter_size=3, - padding=1, - act="relu") - - self.maxpool = MaxPool2D(kernel_size=3, stride=2, padding=0) - self.conv_3b_1x1 = ConvBNLayer(num_channels=64, - num_filters=80, - filter_size=1, - act="relu") - self.conv_4a_3x3 = ConvBNLayer(num_channels=80, - num_filters=192, - filter_size=3, - act="relu") + super().__init__() + self.conv_1a_3x3 = ConvBNLayer( + num_channels=3, + num_filters=32, + filter_size=3, + stride=2, + act="relu") + self.conv_2a_3x3 = ConvBNLayer( + num_channels=32, + num_filters=32, + filter_size=3, + stride=1, + act="relu") + self.conv_2b_3x3 = ConvBNLayer( + num_channels=32, + num_filters=64, + filter_size=3, + padding=1, + act="relu") + + self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=0) + self.conv_3b_1x1 = ConvBNLayer( + num_channels=64, num_filters=80, filter_size=1, act="relu") + self.conv_4a_3x3 = ConvBNLayer( + num_channels=80, num_filters=192, filter_size=3, act="relu") + def forward(self, x): x = self.conv_1a_3x3(x) x = self.conv_2a_3x3(x) x = self.conv_2b_3x3(x) - x = self.maxpool(x) + x = self.max_pool(x) x = self.conv_3b_1x1(x) x = self.conv_4a_3x3(x) - x = self.maxpool(x) + x = self.max_pool(x) return x - + class InceptionA(TheseusLayer): def __init__(self, num_channels, pool_features): - super(InceptionA, self).__init__() - self.branch1x1 = ConvBNLayer(num_channels=num_channels, - num_filters=64, - filter_size=1, - act="relu") - self.branch5x5_1 = ConvBNLayer(num_channels=num_channels, - num_filters=48, - filter_size=1, - act="relu") - self.branch5x5_2 = ConvBNLayer(num_channels=48, - num_filters=64, - filter_size=5, - padding=2, - act="relu") - - self.branch3x3dbl_1 = ConvBNLayer(num_channels=num_channels, - num_filters=64, - filter_size=1, - act="relu") - self.branch3x3dbl_2 = ConvBNLayer(num_channels=64, - num_filters=96, - filter_size=3, - padding=1, - act="relu") - self.branch3x3dbl_3 = ConvBNLayer(num_channels=96, - num_filters=96, - filter_size=3, - padding=1, - act="relu") - self.branch_pool = AvgPool2D(kernel_size=3, stride=1, padding=1, exclusive=False) - self.branch_pool_conv = ConvBNLayer(num_channels=num_channels, - num_filters=pool_features, - filter_size=1, - act="relu") + super().__init__() + self.branch1x1 = ConvBNLayer( + num_channels=num_channels, + num_filters=64, + filter_size=1, + act="relu") + self.branch5x5_1 = ConvBNLayer( + num_channels=num_channels, + num_filters=48, + filter_size=1, + act="relu") + self.branch5x5_2 = ConvBNLayer( + num_channels=48, + num_filters=64, + filter_size=5, + padding=2, + act="relu") + + self.branch3x3dbl_1 = ConvBNLayer( + num_channels=num_channels, + num_filters=64, + filter_size=1, + act="relu") + self.branch3x3dbl_2 = ConvBNLayer( + num_channels=64, + num_filters=96, + filter_size=3, + padding=1, + act="relu") + self.branch3x3dbl_3 = ConvBNLayer( + num_channels=96, + num_filters=96, + filter_size=3, + padding=1, + act="relu") + self.branch_pool = AvgPool2D( + kernel_size=3, stride=1, padding=1, exclusive=False) + self.branch_pool_conv = ConvBNLayer( + num_channels=num_channels, + num_filters=pool_features, + filter_size=1, + act="relu") def forward(self, x): branch1x1 = self.branch1x1(x) @@ -163,34 +169,39 @@ class InceptionA(TheseusLayer): branch_pool = self.branch_pool(x) branch_pool = self.branch_pool_conv(branch_pool) - x = paddle.concat([branch1x1, branch5x5, branch3x3dbl, branch_pool], axis=1) + x = paddle.concat( + [branch1x1, branch5x5, branch3x3dbl, branch_pool], axis=1) return x - + class InceptionB(TheseusLayer): def __init__(self, num_channels): - super(InceptionB, self).__init__() - self.branch3x3 = ConvBNLayer(num_channels=num_channels, - num_filters=384, - filter_size=3, - stride=2, - act="relu") - self.branch3x3dbl_1 = ConvBNLayer(num_channels=num_channels, - num_filters=64, - filter_size=1, - act="relu") - self.branch3x3dbl_2 = ConvBNLayer(num_channels=64, - num_filters=96, - filter_size=3, - padding=1, - act="relu") - self.branch3x3dbl_3 = ConvBNLayer(num_channels=96, - num_filters=96, - filter_size=3, - stride=2, - act="relu") + super().__init__() + self.branch3x3 = ConvBNLayer( + num_channels=num_channels, + num_filters=384, + filter_size=3, + stride=2, + act="relu") + self.branch3x3dbl_1 = ConvBNLayer( + num_channels=num_channels, + num_filters=64, + filter_size=1, + act="relu") + self.branch3x3dbl_2 = ConvBNLayer( + num_channels=64, + num_filters=96, + filter_size=3, + padding=1, + act="relu") + self.branch3x3dbl_3 = ConvBNLayer( + num_channels=96, + num_filters=96, + filter_size=3, + stride=2, + act="relu") self.branch_pool = MaxPool2D(kernel_size=3, stride=2) - + def forward(self, x): branch3x3 = self.branch3x3(x) @@ -204,64 +215,75 @@ class InceptionB(TheseusLayer): return x + class InceptionC(TheseusLayer): def __init__(self, num_channels, channels_7x7): - super(InceptionC, self).__init__() - self.branch1x1 = ConvBNLayer(num_channels=num_channels, - num_filters=192, - filter_size=1, - act="relu") - - - self.branch7x7_1 = ConvBNLayer(num_channels=num_channels, - num_filters=channels_7x7, - filter_size=1, - stride=1, - act="relu") - self.branch7x7_2 = ConvBNLayer(num_channels=channels_7x7, - num_filters=channels_7x7, - filter_size=(1, 7), - stride=1, - padding=(0, 3), - act="relu") - self.branch7x7_3 = ConvBNLayer(num_channels=channels_7x7, - num_filters=192, - filter_size=(7, 1), - stride=1, - padding=(3, 0), - act="relu") - - self.branch7x7dbl_1 = ConvBNLayer(num_channels=num_channels, - num_filters=channels_7x7, - filter_size=1, - act="relu") - self.branch7x7dbl_2 = ConvBNLayer(num_channels=channels_7x7, - num_filters=channels_7x7, - filter_size=(7, 1), - padding = (3, 0), - act="relu") - self.branch7x7dbl_3 = ConvBNLayer(num_channels=channels_7x7, - num_filters=channels_7x7, - filter_size=(1, 7), - padding = (0, 3), - act="relu") - self.branch7x7dbl_4 = ConvBNLayer(num_channels=channels_7x7, - num_filters=channels_7x7, - filter_size=(7, 1), - padding = (3, 0), - act="relu") - self.branch7x7dbl_5 = ConvBNLayer(num_channels=channels_7x7, - num_filters=192, - filter_size=(1, 7), - padding = (0, 3), - act="relu") - - self.branch_pool = AvgPool2D(kernel_size=3, stride=1, padding=1, exclusive=False) - self.branch_pool_conv = ConvBNLayer(num_channels=num_channels, - num_filters=192, - filter_size=1, - act="relu") - + super().__init__() + self.branch1x1 = ConvBNLayer( + num_channels=num_channels, + num_filters=192, + filter_size=1, + act="relu") + + self.branch7x7_1 = ConvBNLayer( + num_channels=num_channels, + num_filters=channels_7x7, + filter_size=1, + stride=1, + act="relu") + self.branch7x7_2 = ConvBNLayer( + num_channels=channels_7x7, + num_filters=channels_7x7, + filter_size=(1, 7), + stride=1, + padding=(0, 3), + act="relu") + self.branch7x7_3 = ConvBNLayer( + num_channels=channels_7x7, + num_filters=192, + filter_size=(7, 1), + stride=1, + padding=(3, 0), + act="relu") + + self.branch7x7dbl_1 = ConvBNLayer( + num_channels=num_channels, + num_filters=channels_7x7, + filter_size=1, + act="relu") + self.branch7x7dbl_2 = ConvBNLayer( + num_channels=channels_7x7, + num_filters=channels_7x7, + filter_size=(7, 1), + padding=(3, 0), + act="relu") + self.branch7x7dbl_3 = ConvBNLayer( + num_channels=channels_7x7, + num_filters=channels_7x7, + filter_size=(1, 7), + padding=(0, 3), + act="relu") + self.branch7x7dbl_4 = ConvBNLayer( + num_channels=channels_7x7, + num_filters=channels_7x7, + filter_size=(7, 1), + padding=(3, 0), + act="relu") + self.branch7x7dbl_5 = ConvBNLayer( + num_channels=channels_7x7, + num_filters=192, + filter_size=(1, 7), + padding=(0, 3), + act="relu") + + self.branch_pool = AvgPool2D( + kernel_size=3, stride=1, padding=1, exclusive=False) + self.branch_pool_conv = ConvBNLayer( + num_channels=num_channels, + num_filters=192, + filter_size=1, + act="relu") + def forward(self, x): branch1x1 = self.branch1x1(x) @@ -278,41 +300,49 @@ class InceptionC(TheseusLayer): branch_pool = self.branch_pool(x) branch_pool = self.branch_pool_conv(branch_pool) - x = paddle.concat([branch1x1, branch7x7, branch7x7dbl, branch_pool], axis=1) - + x = paddle.concat( + [branch1x1, branch7x7, branch7x7dbl, branch_pool], axis=1) + return x - + + class InceptionD(TheseusLayer): def __init__(self, num_channels): - super(InceptionD, self).__init__() - self.branch3x3_1 = ConvBNLayer(num_channels=num_channels, - num_filters=192, - filter_size=1, - act="relu") - self.branch3x3_2 = ConvBNLayer(num_channels=192, - num_filters=320, - filter_size=3, - stride=2, - act="relu") - self.branch7x7x3_1 = ConvBNLayer(num_channels=num_channels, - num_filters=192, - filter_size=1, - act="relu") - self.branch7x7x3_2 = ConvBNLayer(num_channels=192, - num_filters=192, - filter_size=(1, 7), - padding=(0, 3), - act="relu") - self.branch7x7x3_3 = ConvBNLayer(num_channels=192, - num_filters=192, - filter_size=(7, 1), - padding=(3, 0), - act="relu") - self.branch7x7x3_4 = ConvBNLayer(num_channels=192, - num_filters=192, - filter_size=3, - stride=2, - act="relu") + super().__init__() + self.branch3x3_1 = ConvBNLayer( + num_channels=num_channels, + num_filters=192, + filter_size=1, + act="relu") + self.branch3x3_2 = ConvBNLayer( + num_channels=192, + num_filters=320, + filter_size=3, + stride=2, + act="relu") + self.branch7x7x3_1 = ConvBNLayer( + num_channels=num_channels, + num_filters=192, + filter_size=1, + act="relu") + self.branch7x7x3_2 = ConvBNLayer( + num_channels=192, + num_filters=192, + filter_size=(1, 7), + padding=(0, 3), + act="relu") + self.branch7x7x3_3 = ConvBNLayer( + num_channels=192, + num_filters=192, + filter_size=(7, 1), + padding=(3, 0), + act="relu") + self.branch7x7x3_4 = ConvBNLayer( + num_channels=192, + num_filters=192, + filter_size=3, + stride=2, + act="relu") self.branch_pool = MaxPool2D(kernel_size=3, stride=2) def forward(self, x): @@ -325,56 +355,68 @@ class InceptionD(TheseusLayer): branch7x7x3 = self.branch7x7x3_4(branch7x7x3) branch_pool = self.branch_pool(x) - + x = paddle.concat([branch3x3, branch7x7x3, branch_pool], axis=1) return x - + + class InceptionE(TheseusLayer): def __init__(self, num_channels): - super(InceptionE, self).__init__() - self.branch1x1 = ConvBNLayer(num_channels=num_channels, - num_filters=320, - filter_size=1, - act="relu") - self.branch3x3_1 = ConvBNLayer(num_channels=num_channels, - num_filters=384, - filter_size=1, - act="relu") - self.branch3x3_2a = ConvBNLayer(num_channels=384, - num_filters=384, - filter_size=(1, 3), - padding=(0, 1), - act="relu") - self.branch3x3_2b = ConvBNLayer(num_channels=384, - num_filters=384, - filter_size=(3, 1), - padding=(1, 0), - act="relu") - - self.branch3x3dbl_1 = ConvBNLayer(num_channels=num_channels, - num_filters=448, - filter_size=1, - act="relu") - self.branch3x3dbl_2 = ConvBNLayer(num_channels=448, - num_filters=384, - filter_size=3, - padding=1, - act="relu") - self.branch3x3dbl_3a = ConvBNLayer(num_channels=384, - num_filters=384, - filter_size=(1, 3), - padding=(0, 1), - act="relu") - self.branch3x3dbl_3b = ConvBNLayer(num_channels=384, - num_filters=384, - filter_size=(3, 1), - padding=(1, 0), - act="relu") - self.branch_pool = AvgPool2D(kernel_size=3, stride=1, padding=1, exclusive=False) - self.branch_pool_conv = ConvBNLayer(num_channels=num_channels, - num_filters=192, - filter_size=1, - act="relu") + super().__init__() + self.branch1x1 = ConvBNLayer( + num_channels=num_channels, + num_filters=320, + filter_size=1, + act="relu") + self.branch3x3_1 = ConvBNLayer( + num_channels=num_channels, + num_filters=384, + filter_size=1, + act="relu") + self.branch3x3_2a = ConvBNLayer( + num_channels=384, + num_filters=384, + filter_size=(1, 3), + padding=(0, 1), + act="relu") + self.branch3x3_2b = ConvBNLayer( + num_channels=384, + num_filters=384, + filter_size=(3, 1), + padding=(1, 0), + act="relu") + + self.branch3x3dbl_1 = ConvBNLayer( + num_channels=num_channels, + num_filters=448, + filter_size=1, + act="relu") + self.branch3x3dbl_2 = ConvBNLayer( + num_channels=448, + num_filters=384, + filter_size=3, + padding=1, + act="relu") + self.branch3x3dbl_3a = ConvBNLayer( + num_channels=384, + num_filters=384, + filter_size=(1, 3), + padding=(0, 1), + act="relu") + self.branch3x3dbl_3b = ConvBNLayer( + num_channels=384, + num_filters=384, + filter_size=(3, 1), + padding=(1, 0), + act="relu") + self.branch_pool = AvgPool2D( + kernel_size=3, stride=1, padding=1, exclusive=False) + self.branch_pool_conv = ConvBNLayer( + num_channels=num_channels, + num_filters=192, + filter_size=1, + act="relu") + def forward(self, x): branch1x1 = self.branch1x1(x) @@ -396,8 +438,9 @@ class InceptionE(TheseusLayer): branch_pool = self.branch_pool(x) branch_pool = self.branch_pool_conv(branch_pool) - x = paddle.concat([branch1x1, branch3x3, branch3x3dbl, branch_pool], axis=1) - return x + x = paddle.concat( + [branch1x1, branch3x3, branch3x3dbl, branch_pool], axis=1) + return x class Inception_V3(TheseusLayer): @@ -410,25 +453,21 @@ class Inception_V3(TheseusLayer): Returns: model: nn.Layer. Specific Inception_V3 model depends on args. """ - def __init__(self, - config, - class_num=1000, - pretrained=False, - **kwargs): - super(Inception_V3, self).__init__() - - self.inception_a_list = config['inception_a'] - self.inception_c_list = config['inception_c'] - self.inception_b_list = config['inception_b'] - self.inception_d_list = config['inception_d'] - self.inception_e_list = config ['inception_e'] - self.pretrained = pretrained + + def __init__(self, config, class_num=1000): + super().__init__() + + self.inception_a_list = config["inception_a"] + self.inception_c_list = config["inception_c"] + self.inception_b_list = config["inception_b"] + self.inception_d_list = config["inception_d"] + self.inception_e_list = config["inception_e"] self.inception_stem = InceptionStem() self.inception_block_list = nn.LayerList() for i in range(len(self.inception_a_list[0])): - inception_a = InceptionA(self.inception_a_list[0][i], + inception_a = InceptionA(self.inception_a_list[0][i], self.inception_a_list[1][i]) self.inception_block_list.append(inception_a) @@ -437,7 +476,7 @@ class Inception_V3(TheseusLayer): self.inception_block_list.append(inception_b) for i in range(len(self.inception_c_list[0])): - inception_c = InceptionC(self.inception_c_list[0][i], + inception_c = InceptionC(self.inception_c_list[0][i], self.inception_c_list[1][i]) self.inception_block_list.append(inception_c) @@ -448,21 +487,20 @@ class Inception_V3(TheseusLayer): for i in range(len(self.inception_e_list)): inception_e = InceptionE(self.inception_e_list[i]) self.inception_block_list.append(inception_e) - + self.avg_pool = AdaptiveAvgPool2D(1) self.dropout = Dropout(p=0.2, mode="downscale_in_infer") stdv = 1.0 / math.sqrt(2048 * 1.0) self.fc = Linear( 2048, class_num, - weight_attr=ParamAttr( - initializer=Uniform(-stdv, stdv)), + weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)), bias_attr=ParamAttr()) def forward(self, x): x = self.inception_stem(x) for inception_block in self.inception_block_list: - x = inception_block(x) + x = inception_block(x) x = self.avg_pool(x) x = paddle.reshape(x, shape=[-1, 2048]) x = self.dropout(x) @@ -470,25 +508,29 @@ class Inception_V3(TheseusLayer): return x -def InceptionV3(**kwargs): +def _load_pretrained(pretrained, model, model_url, use_ssld): + if pretrained is False: + pass + elif pretrained is True: + load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld) + elif isinstance(pretrained, str): + load_dygraph_pretrain(model, pretrained) + else: + raise RuntimeError( + "pretrained type is not available. Please use `string` or `boolean` type." + ) + + +def InceptionV3(pretrained=False, use_ssld=False, **kwargs): """ InceptionV3 Args: - kwargs: - class_num: int=1000. Output dim of last fc layer. - pretrained: bool or str, default: bool=False. Whether to load the pretrained model. + pretrained: bool=false or str. if `true` load pretrained parameters, `false` otherwise. + if str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. Returns: model: nn.Layer. Specific `InceptionV3` model """ model = Inception_V3(NET_CONFIG, **kwargs) - - if isinstance(model.pretrained, bool): - if model.pretrained is True: - load_dygraph_pretrain_from_url(model, MODEL_URLS["InceptionV3"]) - elif isinstance(model.pretrained, str): - load_dygraph_pretrain(model, model.pretrained) - else: - raise RuntimeError( - "pretrained type is not available. Please use `string` or `boolean` type") + _load_pretrained(pretrained, model, MODEL_URLS["InceptionV3"], use_ssld) return model - diff --git a/ppcls/arch/backbone/legendary_models/mobilenet_v1.py b/ppcls/arch/backbone/legendary_models/mobilenet_v1.py index cf57f9882745e438a6d31ff7bb70e5ea7c282725..3a14dc81d93d3d8e0a6bc3f21323bfe94f702d28 100644 --- a/ppcls/arch/backbone/legendary_models/mobilenet_v1.py +++ b/ppcls/arch/backbone/legendary_models/mobilenet_v1.py @@ -14,8 +14,6 @@ from __future__ import absolute_import, division, print_function -import numpy as np -import paddle from paddle import ParamAttr import paddle.nn as nn from paddle.nn import Conv2D, BatchNorm, Linear, ReLU, Flatten @@ -23,19 +21,22 @@ from paddle.nn import AdaptiveAvgPool2D from paddle.nn.initializer import KaimingNormal from ppcls.arch.backbone.base.theseus_layer import TheseusLayer -from ppcls.utils.save_load import load_dygraph_pretrain_from, load_dygraph_pretrain_from_url - +from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url MODEL_URLS = { - "MobileNetV1_x0_25": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV1_x0_25_pretrained.pdparams", - "MobileNetV1_x0_5": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV1_x0_5_pretrained.pdparams", - "MobileNetV1_x0_75": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV1_x0_75_pretrained.pdparams", - "MobileNetV1": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV1_pretrained.pdparams", + "MobileNetV1_x0_25": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV1_x0_25_pretrained.pdparams", + "MobileNetV1_x0_5": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV1_x0_5_pretrained.pdparams", + "MobileNetV1_x0_75": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV1_x0_75_pretrained.pdparams", + "MobileNetV1": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV1_pretrained.pdparams" } __all__ = MODEL_URLS.keys() - - + + class ConvBNLayer(TheseusLayer): def __init__(self, num_channels, @@ -44,7 +45,7 @@ class ConvBNLayer(TheseusLayer): stride, padding, num_groups=1): - super(ConvBNLayer, self).__init__() + super().__init__() self.conv = Conv2D( in_channels=num_channels, @@ -55,9 +56,7 @@ class ConvBNLayer(TheseusLayer): groups=num_groups, weight_attr=ParamAttr(initializer=KaimingNormal()), bias_attr=False) - self.bn = BatchNorm(num_filters) - self.relu = ReLU() def forward(self, x): @@ -68,14 +67,9 @@ class ConvBNLayer(TheseusLayer): class DepthwiseSeparable(TheseusLayer): - def __init__(self, - num_channels, - num_filters1, - num_filters2, - num_groups, - stride, - scale): - super(DepthwiseSeparable, self).__init__() + def __init__(self, num_channels, num_filters1, num_filters2, num_groups, + stride, scale): + super().__init__() self.depthwise_conv = ConvBNLayer( num_channels=num_channels, @@ -99,10 +93,18 @@ class DepthwiseSeparable(TheseusLayer): class MobileNet(TheseusLayer): - def __init__(self, scale=1.0, class_num=1000, pretrained=False): - super(MobileNet, self).__init__() + """ + MobileNet + Args: + scale: float=1.0. The coefficient that controls the size of network parameters. + class_num: int=1000. The number of classes. + Returns: + model: nn.Layer. Specific MobileNet model depends on args. + """ + + def __init__(self, scale=1.0, class_num=1000): + super().__init__() self.scale = scale - self.pretrained = pretrained self.conv = ConvBNLayer( num_channels=3, @@ -110,30 +112,31 @@ class MobileNet(TheseusLayer): num_filters=int(32 * scale), stride=2, padding=1) - + #num_channels, num_filters1, num_filters2, num_groups, stride - self.cfg = [[int(32 * scale), 32, 64, 32, 1], - [int(64 * scale), 64, 128, 64, 2], - [int(128 * scale), 128, 128, 128, 1], - [int(128 * scale), 128, 256, 128, 2], - [int(256 * scale), 256, 256, 256, 1], - [int(256 * scale), 256, 512, 256, 2], - [int(512 * scale), 512, 512, 512, 1], - [int(512 * scale), 512, 512, 512, 1], - [int(512 * scale), 512, 512, 512, 1], - [int(512 * scale), 512, 512, 512, 1], - [int(512 * scale), 512, 512, 512, 1], - [int(512 * scale), 512, 1024, 512, 2], + self.cfg = [[int(32 * scale), 32, 64, 32, 1], + [int(64 * scale), 64, 128, 64, 2], + [int(128 * scale), 128, 128, 128, 1], + [int(128 * scale), 128, 256, 128, 2], + [int(256 * scale), 256, 256, 256, 1], + [int(256 * scale), 256, 512, 256, 2], + [int(512 * scale), 512, 512, 512, 1], + [int(512 * scale), 512, 512, 512, 1], + [int(512 * scale), 512, 512, 512, 1], + [int(512 * scale), 512, 512, 512, 1], + [int(512 * scale), 512, 512, 512, 1], + [int(512 * scale), 512, 1024, 512, 2], [int(1024 * scale), 1024, 1024, 1024, 1]] - + self.blocks = nn.Sequential(*[ - DepthwiseSeparable( - num_channels=params[0], - num_filters1=params[1], - num_filters2=params[2], - num_groups=params[3], - stride=params[4], - scale=scale) for params in self.cfg]) + DepthwiseSeparable( + num_channels=params[0], + num_filters1=params[1], + num_filters2=params[2], + num_groups=params[3], + stride=params[4], + scale=scale) for params in self.cfg + ]) self.avg_pool = AdaptiveAvgPool2D(1) self.flatten = Flatten(start_axis=1, stop_axis=-1) @@ -142,7 +145,7 @@ class MobileNet(TheseusLayer): int(1024 * scale), class_num, weight_attr=ParamAttr(initializer=KaimingNormal())) - + def forward(self, x): x = self.conv(x) x = self.blocks(x) @@ -152,91 +155,77 @@ class MobileNet(TheseusLayer): return x -def MobileNetV1_x0_25(**args): - """ - MobileNetV1_x0_25 - Args: - pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise. - kwargs: - class_num: int=1000. Output dim of last fc layer. - Returns: - model: nn.Layer. Specific `MobileNetV1_x0_25` model depends on args. - """ - model = MobileNet(scale=0.25, **args) - if isinstance(model.pretrained, bool): - if model.pretrained is True: - load_dygraph_pretrain_from_url(model, MODEL_URLS["MobileNetV1_x0_25"]) - elif isinstance(model.pretrained, str): - load_dygraph_pretrain(model, model.pretrained) +def _load_pretrained(pretrained, model, model_url, use_ssld): + if pretrained is False: + pass + elif pretrained is True: + load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld) + elif isinstance(pretrained, str): + load_dygraph_pretrain(model, pretrained) else: raise RuntimeError( - "pretrained type is not available. Please use `string` or `boolean` type") - return model + "pretrained type is not available. Please use `string` or `boolean` type." + ) -def MobileNetV1_x0_5(**args): +def MobileNetV1_x0_25(pretrained=False, use_ssld=False, **kwargs): """ - MobileNetV1_x0_5 - Args: - pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise. - kwargs: - class_num: int=1000. Output dim of last fc layer. - Returns: - model: nn.Layer. Specific `MobileNetV1_x0_5` model depends on args. + MobileNetV1_x0_25 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `MobileNetV1_x0_25` model depends on args. """ - model = MobileNet(scale=0.5, **args) - if isinstance(model.pretrained, bool): - if model.pretrained is True: - load_dygraph_pretrain_from_url(model, MODEL_URLS["MobileNetV1_x0_5"]) - elif isinstance(model.pretrained, str): - load_dygraph_pretrain(model, model.pretrained) - else: - raise RuntimeError( - "pretrained type is not available. Please use `string` or `boolean` type") + model = MobileNet(scale=0.25, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV1_x0_25"], + use_ssld) return model -def MobileNetV1_x0_75(**args): +def MobileNetV1_x0_5(pretrained=False, use_ssld=False, **kwargs): """ - MobileNetV1_x0_75 - Args: - pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise. - kwargs: - class_num: int=1000. Output dim of last fc layer. - Returns: - model: nn.Layer. Specific `MobileNetV1_x0_75` model depends on args. + MobileNetV1_x0_5 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `MobileNetV1_x0_5` model depends on args. """ - model = MobileNet(scale=0.75, **args) - if isinstance(model.pretrained, bool): - if model.pretrained is True: - load_dygraph_pretrain_from_url(model, MODEL_URLS["MobileNetV1_x0_75"]) - elif isinstance(model.pretrained, str): - load_dygraph_pretrain(model, model.pretrained) - else: - raise RuntimeError( - "pretrained type is not available. Please use `string` or `boolean` type") + model = MobileNet(scale=0.5, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV1_x0_5"], + use_ssld) return model -def MobileNetV1(**args): +def MobileNetV1_x0_75(pretrained=False, use_ssld=False, **kwargs): """ - MobileNetV1 - Args: - pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise. - kwargs: - class_num: int=1000. Output dim of last fc layer. - Returns: - model: nn.Layer. Specific `MobileNetV1` model depends on args. + MobileNetV1_x0_75 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `MobileNetV1_x0_75` model depends on args. """ - model = MobileNet(scale=1.0, **args) - if isinstance(model.pretrained, bool): - if model.pretrained is True: - load_dygraph_pretrain_from_url(model, MODEL_URLS["MobileNetV1"]) - elif isinstance(model.pretrained, str): - load_dygraph_pretrain(model, model.pretrained) - else: - raise RuntimeError( - "pretrained type is not available. Please use `string` or `boolean` type") + model = MobileNet(scale=0.75, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV1_x0_75"], + use_ssld) return model +def MobileNetV1(pretrained=False, use_ssld=False, **kwargs): + """ + MobileNetV1 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `MobileNetV1` model depends on args. + """ + model = MobileNet(scale=1.0, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV1"], use_ssld) + return model diff --git a/ppcls/arch/backbone/legendary_models/mobilenet_v3.py b/ppcls/arch/backbone/legendary_models/mobilenet_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..aff69bcae1d5a67d5c14bb95c39af3ecca6e48a3 --- /dev/null +++ b/ppcls/arch/backbone/legendary_models/mobilenet_v3.py @@ -0,0 +1,557 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import, division, print_function + +import paddle +import paddle.nn as nn +from paddle import ParamAttr +from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Dropout, Linear +from paddle.regularizer import L2Decay +from ppcls.arch.backbone.base.theseus_layer import TheseusLayer +from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url + +MODEL_URLS = { + "MobileNetV3_small_x0_35": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x0_35_pretrained.pdparams", + "MobileNetV3_small_x0_5": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x0_5_pretrained.pdparams", + "MobileNetV3_small_x0_75": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x0_75_pretrained.pdparams", + "MobileNetV3_small_x1_0": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x1_0_pretrained.pdparams", + "MobileNetV3_small_x1_25": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x1_25_pretrained.pdparams", + "MobileNetV3_large_x0_35": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x0_35_pretrained.pdparams", + "MobileNetV3_large_x0_5": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x0_5_pretrained.pdparams", + "MobileNetV3_large_x0_75": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x0_75_pretrained.pdparams", + "MobileNetV3_large_x1_0": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x1_0_pretrained.pdparams", + "MobileNetV3_large_x1_25": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x1_25_pretrained.pdparams", +} + +__all__ = MODEL_URLS.keys() + +# "large", "small" is just for MobinetV3_large, MobileNetV3_small respectively. +# The type of "large" or "small" config is a list. Each element(list) represents a depthwise block, which is composed of k, exp, se, act, s. +# k: kernel_size +# exp: middle channel number in depthwise block +# c: output channel number in depthwise block +# se: whether to use SE block +# act: which activation to use +# s: stride in depthwise block +NET_CONFIG = { + "large": [ + # k, exp, c, se, act, s + [3, 16, 16, False, "relu", 1], + [3, 64, 24, False, "relu", 2], + [3, 72, 24, False, "relu", 1], + [5, 72, 40, True, "relu", 2], + [5, 120, 40, True, "relu", 1], + [5, 120, 40, True, "relu", 1], + [3, 240, 80, False, "hardswish", 2], + [3, 200, 80, False, "hardswish", 1], + [3, 184, 80, False, "hardswish", 1], + [3, 184, 80, False, "hardswish", 1], + [3, 480, 112, True, "hardswish", 1], + [3, 672, 112, True, "hardswish", 1], + [5, 672, 160, True, "hardswish", 2], + [5, 960, 160, True, "hardswish", 1], + [5, 960, 160, True, "hardswish", 1], + ], + "small": [ + # k, exp, c, se, act, s + [3, 16, 16, True, "relu", 2], + [3, 72, 24, False, "relu", 2], + [3, 88, 24, False, "relu", 1], + [5, 96, 40, True, "hardswish", 2], + [5, 240, 40, True, "hardswish", 1], + [5, 240, 40, True, "hardswish", 1], + [5, 120, 48, True, "hardswish", 1], + [5, 144, 48, True, "hardswish", 1], + [5, 288, 96, True, "hardswish", 2], + [5, 576, 96, True, "hardswish", 1], + [5, 576, 96, True, "hardswish", 1], + ] +} +# first conv output channel number in MobileNetV3 +STEM_CONV_NUMBER = 16 +# last second conv output channel for "small" +LAST_SECOND_CONV_SMALL = 576 +# last second conv output channel for "large" +LAST_SECOND_CONV_LARGE = 960 +# last conv output channel number for "large" and "small" +LAST_CONV = 1280 + + +def _make_divisible(v, divisor=8, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def _create_act(act): + if act == "hardswish": + return nn.Hardswish() + elif act == "relu": + return nn.ReLU() + elif act is None: + return None + else: + raise RuntimeError( + "The activation function is not supported: {}".format(act)) + + +class MobileNetV3(TheseusLayer): + """ + MobileNetV3 + Args: + config: list. MobileNetV3 depthwise blocks config. + scale: float=1.0. The coefficient that controls the size of network parameters. + class_num: int=1000. The number of classes. + inplanes: int=16. The output channel number of first convolution layer. + class_squeeze: int=960. The output channel number of penultimate convolution layer. + class_expand: int=1280. The output channel number of last convolution layer. + dropout_prob: float=0.2. Probability of setting units to zero. + Returns: + model: nn.Layer. Specific MobileNetV3 model depends on args. + """ + + def __init__(self, + config, + scale=1.0, + class_num=1000, + inplanes=STEM_CONV_NUMBER, + class_squeeze=LAST_SECOND_CONV_LARGE, + class_expand=LAST_CONV, + dropout_prob=0.2): + super().__init__() + + self.cfg = config + self.scale = scale + self.inplanes = inplanes + self.class_squeeze = class_squeeze + self.class_expand = class_expand + self.class_num = class_num + + self.conv = ConvBNLayer( + in_c=3, + out_c=_make_divisible(self.inplanes * self.scale), + filter_size=3, + stride=2, + padding=1, + num_groups=1, + if_act=True, + act="hardswish") + + self.blocks = nn.Sequential(*[ + ResidualUnit( + in_c=_make_divisible(self.inplanes * self.scale if i == 0 else + self.cfg[i - 1][2] * self.scale), + mid_c=_make_divisible(self.scale * exp), + out_c=_make_divisible(self.scale * c), + filter_size=k, + stride=s, + use_se=se, + act=act) for i, (k, exp, c, se, act, s) in enumerate(self.cfg) + ]) + + self.last_second_conv = ConvBNLayer( + in_c=_make_divisible(self.cfg[-1][2] * self.scale), + out_c=_make_divisible(self.scale * self.class_squeeze), + filter_size=1, + stride=1, + padding=0, + num_groups=1, + if_act=True, + act="hardswish") + + self.avg_pool = AdaptiveAvgPool2D(1) + + self.last_conv = Conv2D( + in_channels=_make_divisible(self.scale * self.class_squeeze), + out_channels=self.class_expand, + kernel_size=1, + stride=1, + padding=0, + bias_attr=False) + + self.hardswish = nn.Hardswish() + self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer") + self.flatten = nn.Flatten(start_axis=1, stop_axis=-1) + + self.fc = Linear(self.class_expand, class_num) + + def forward(self, x): + x = self.conv(x) + x = self.blocks(x) + x = self.last_second_conv(x) + x = self.avg_pool(x) + x = self.last_conv(x) + x = self.hardswish(x) + x = self.dropout(x) + x = self.flatten(x) + x = self.fc(x) + + return x + + +class ConvBNLayer(TheseusLayer): + def __init__(self, + in_c, + out_c, + filter_size, + stride, + padding, + num_groups=1, + if_act=True, + act=None): + super().__init__() + + self.conv = Conv2D( + in_channels=in_c, + out_channels=out_c, + kernel_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + bias_attr=False) + self.bn = BatchNorm( + num_channels=out_c, + act=None, + param_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + self.if_act = if_act + self.act = _create_act(act) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.if_act: + x = self.act(x) + return x + + +class ResidualUnit(TheseusLayer): + def __init__(self, + in_c, + mid_c, + out_c, + filter_size, + stride, + use_se, + act=None): + super().__init__() + self.if_shortcut = stride == 1 and in_c == out_c + self.if_se = use_se + + self.expand_conv = ConvBNLayer( + in_c=in_c, + out_c=mid_c, + filter_size=1, + stride=1, + padding=0, + if_act=True, + act=act) + self.bottleneck_conv = ConvBNLayer( + in_c=mid_c, + out_c=mid_c, + filter_size=filter_size, + stride=stride, + padding=int((filter_size - 1) // 2), + num_groups=mid_c, + if_act=True, + act=act) + if self.if_se: + self.mid_se = SEModule(mid_c) + self.linear_conv = ConvBNLayer( + in_c=mid_c, + out_c=out_c, + filter_size=1, + stride=1, + padding=0, + if_act=False, + act=None) + + def forward(self, x): + identity = x + x = self.expand_conv(x) + x = self.bottleneck_conv(x) + if self.if_se: + x = self.mid_se(x) + x = self.linear_conv(x) + if self.if_shortcut: + x = paddle.add(identity, x) + return x + + +# nn.Hardsigmoid can't transfer "slope" and "offset" in nn.functional.hardsigmoid +class Hardsigmoid(TheseusLayer): + def __init__(self, slope=0.2, offset=0.5): + super().__init__() + self.slope = slope + self.offset = offset + + def forward(self, x): + return nn.functional.hardsigmoid( + x, slope=self.slope, offset=self.offset) + + +class SEModule(TheseusLayer): + def __init__(self, channel, reduction=4): + super().__init__() + self.avg_pool = AdaptiveAvgPool2D(1) + self.conv1 = Conv2D( + in_channels=channel, + out_channels=channel // reduction, + kernel_size=1, + stride=1, + padding=0) + self.relu = nn.ReLU() + self.conv2 = Conv2D( + in_channels=channel // reduction, + out_channels=channel, + kernel_size=1, + stride=1, + padding=0) + self.hardsigmoid = Hardsigmoid(slope=0.2, offset=0.5) + + def forward(self, x): + identity = x + x = self.avg_pool(x) + x = self.conv1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.hardsigmoid(x) + return paddle.multiply(x=identity, y=x) + + +def _load_pretrained(pretrained, model, model_url, use_ssld): + if pretrained is False: + pass + elif pretrained is True: + load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld) + elif isinstance(pretrained, str): + load_dygraph_pretrain(model, pretrained) + else: + raise RuntimeError( + "pretrained type is not available. Please use `string` or `boolean` type." + ) + + +def MobileNetV3_small_x0_35(pretrained=False, use_ssld=False, **kwargs): + """ + MobileNetV3_small_x0_35 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `MobileNetV3_small_x0_35` model depends on args. + """ + model = MobileNetV3( + config=NET_CONFIG["small"], + scale=0.35, + class_squeeze=LAST_SECOND_CONV_SMALL, + **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x0_35"], + use_ssld) + return model + + +def MobileNetV3_small_x0_5(pretrained=False, use_ssld=False, **kwargs): + """ + MobileNetV3_small_x0_5 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `MobileNetV3_small_x0_5` model depends on args. + """ + model = MobileNetV3( + config=NET_CONFIG["small"], + scale=0.5, + class_squeeze=LAST_SECOND_CONV_SMALL, + **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x0_5"], + use_ssld) + return model + + +def MobileNetV3_small_x0_75(pretrained=False, use_ssld=False, **kwargs): + """ + MobileNetV3_small_x0_75 + Args: + pretrained: bool=false or str. if `true` load pretrained parameters, `false` otherwise. + if str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `MobileNetV3_small_x0_75` model depends on args. + """ + model = MobileNetV3( + config=NET_CONFIG["small"], + scale=0.75, + class_squeeze=LAST_SECOND_CONV_SMALL, + **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x0_75"], + use_ssld) + return model + + +def MobileNetV3_small_x1_0(pretrained=False, use_ssld=False, **kwargs): + """ + MobileNetV3_small_x1_0 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `MobileNetV3_small_x1_0` model depends on args. + """ + model = MobileNetV3( + config=NET_CONFIG["small"], + scale=1.0, + class_squeeze=LAST_SECOND_CONV_SMALL, + **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x1_0"], + use_ssld) + return model + + +def MobileNetV3_small_x1_25(pretrained=False, use_ssld=False, **kwargs): + """ + MobileNetV3_small_x1_25 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `MobileNetV3_small_x1_25` model depends on args. + """ + model = MobileNetV3( + config=NET_CONFIG["small"], + scale=1.25, + class_squeeze=LAST_SECOND_CONV_SMALL, + **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x1_25"], + use_ssld) + return model + + +def MobileNetV3_large_x0_35(pretrained=False, use_ssld=False, **kwargs): + """ + MobileNetV3_large_x0_35 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `MobileNetV3_large_x0_35` model depends on args. + """ + model = MobileNetV3( + config=NET_CONFIG["large"], + scale=0.35, + class_squeeze=LAST_SECOND_CONV_LARGE, + **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x0_35"], + use_ssld) + return model + + +def MobileNetV3_large_x0_5(pretrained=False, use_ssld=False, **kwargs): + """ + MobileNetV3_large_x0_5 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `MobileNetV3_large_x0_5` model depends on args. + """ + model = MobileNetV3( + config=NET_CONFIG["large"], + scale=0.5, + class_squeeze=LAST_SECOND_CONV_LARGE, + **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x0_5"], + use_ssld) + return model + + +def MobileNetV3_large_x0_75(pretrained=False, use_ssld=False, **kwargs): + """ + MobileNetV3_large_x0_75 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `MobileNetV3_large_x0_75` model depends on args. + """ + model = MobileNetV3( + config=NET_CONFIG["large"], + scale=0.75, + class_squeeze=LAST_SECOND_CONV_LARGE, + **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x0_75"], + use_ssld) + return model + + +def MobileNetV3_large_x1_0(pretrained=False, use_ssld=False, **kwargs): + """ + MobileNetV3_large_x1_0 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `MobileNetV3_large_x1_0` model depends on args. + """ + model = MobileNetV3( + config=NET_CONFIG["large"], + scale=1.0, + class_squeeze=LAST_SECOND_CONV_LARGE, + **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x1_0"], + use_ssld) + return model + + +def MobileNetV3_large_x1_25(pretrained=False, use_ssld=False, **kwargs): + """ + MobileNetV3_large_x1_25 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `MobileNetV3_large_x1_25` model depends on args. + """ + model = MobileNetV3( + config=NET_CONFIG["large"], + scale=1.25, + class_squeeze=LAST_SECOND_CONV_LARGE, + **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x1_25"], + use_ssld) + return model diff --git a/ppcls/arch/backbone/legendary_models/resnet.py b/ppcls/arch/backbone/legendary_models/resnet.py index 1992504a2f7b7b466aa6b7d3013ca8c6e17d80bd..5d107fe242039565ebfa1b21940779d8dd8a26af 100644 --- a/ppcls/arch/backbone/legendary_models/resnet.py +++ b/ppcls/arch/backbone/legendary_models/resnet.py @@ -24,26 +24,34 @@ from paddle.nn.initializer import Uniform import math from ppcls.arch.backbone.base.theseus_layer import TheseusLayer -from ppcls.utils.save_load import load_dygraph_pretrain_from, load_dygraph_pretrain_from_url - +from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url MODEL_URLS = { - "ResNet18": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_pretrained.pdparams", - "ResNet18_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams", - "ResNet34": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet34_pretrained.pdparams", - "ResNet34_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet34_vd_pretrained.pdparams", - "ResNet50": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_pretrained.pdparams", - "ResNet50_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_pretrained.pdparams", - "ResNet101": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet101_pretrained.pdparams", - "ResNet101_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet101_vd_pretrained.pdparams", - "ResNet152": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet152_pretrained.pdparams", - "ResNet152_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet152_vd_pretrained.pdparams", - "ResNet200_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet200_vd_pretrained.pdparams", + "ResNet18": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams", + "ResNet18_vd": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_vd_pretrained.pdparams", + "ResNet34": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_pretrained.pdparams", + "ResNet34_vd": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_vd_pretrained.pdparams", + "ResNet50": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_pretrained.pdparams", + "ResNet50_vd": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_vd_pretrained.pdparams", + "ResNet101": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_pretrained.pdparams", + "ResNet101_vd": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_vd_pretrained.pdparams", + "ResNet152": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_pretrained.pdparams", + "ResNet152_vd": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_vd_pretrained.pdparams", + "ResNet200_vd": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet200_vd_pretrained.pdparams", } __all__ = MODEL_URLS.keys() - - ''' ResNet config: dict. key: depth of ResNet. @@ -55,17 +63,35 @@ ResNet config: dict. ''' NET_CONFIG = { "18": { - "block_type": "BasicBlock", "block_depth": [2, 2, 2, 2], "num_channels": [64, 64, 128, 256]}, + "block_type": "BasicBlock", + "block_depth": [2, 2, 2, 2], + "num_channels": [64, 64, 128, 256] + }, "34": { - "block_type": "BasicBlock", "block_depth": [3, 4, 6, 3], "num_channels": [64, 64, 128, 256]}, + "block_type": "BasicBlock", + "block_depth": [3, 4, 6, 3], + "num_channels": [64, 64, 128, 256] + }, "50": { - "block_type": "BottleneckBlock", "block_depth": [3, 4, 6, 3], "num_channels": [64, 256, 512, 1024]}, + "block_type": "BottleneckBlock", + "block_depth": [3, 4, 6, 3], + "num_channels": [64, 256, 512, 1024] + }, "101": { - "block_type": "BottleneckBlock", "block_depth": [3, 4, 23, 3], "num_channels": [64, 256, 512, 1024]}, + "block_type": "BottleneckBlock", + "block_depth": [3, 4, 23, 3], + "num_channels": [64, 256, 512, 1024] + }, "152": { - "block_type": "BottleneckBlock", "block_depth": [3, 8, 36, 3], "num_channels": [64, 256, 512, 1024]}, + "block_type": "BottleneckBlock", + "block_depth": [3, 8, 36, 3], + "num_channels": [64, 256, 512, 1024] + }, "200": { - "block_type": "BottleneckBlock", "block_depth": [3, 12, 48, 3], "num_channels": [64, 256, 512, 1024]}, + "block_type": "BottleneckBlock", + "block_depth": [3, 12, 48, 3], + "num_channels": [64, 256, 512, 1024] + }, } @@ -110,14 +136,14 @@ class ConvBNLayer(TheseusLayer): class BottleneckBlock(TheseusLayer): - def __init__(self, - num_channels, - num_filters, - stride, - shortcut=True, - if_first=False, - lr_mult=1.0, - ): + def __init__( + self, + num_channels, + num_filters, + stride, + shortcut=True, + if_first=False, + lr_mult=1.0, ): super().__init__() self.conv0 = ConvBNLayer( @@ -222,16 +248,15 @@ class ResNet(TheseusLayer): version: str="vb". Different version of ResNet, version vd can perform better. class_num: int=1000. The number of classes. lr_mult_list: list. Control the learning rate of different stages. - pretrained: (True or False) or path of pretrained_model. Whether to load the pretrained model. Returns: model: nn.Layer. Specific ResNet model depends on args. """ + def __init__(self, config, version="vb", class_num=1000, - lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0], - pretrained=False): + lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0]): super().__init__() self.cfg = config @@ -243,51 +268,46 @@ class ResNet(TheseusLayer): self.block_type = self.cfg["block_type"] self.num_channels = self.cfg["num_channels"] self.channels_mult = 1 if self.num_channels[-1] == 256 else 4 - self.pretrained = pretrained - + assert isinstance(self.lr_mult_list, ( list, tuple )), "lr_mult_list should be in (list, tuple) but got {}".format( type(self.lr_mult_list)) - assert len( - self.lr_mult_list - ) == 5, "lr_mult_list length should be 5 but got {}".format( - len(self.lr_mult_list)) - + assert len(self.lr_mult_list + ) == 5, "lr_mult_list length should be 5 but got {}".format( + len(self.lr_mult_list)) self.stem_cfg = { #num_channels, num_filters, filter_size, stride "vb": [[3, 64, 7, 2]], - "vd": [[3, 32, 3, 2], - [32, 32, 3, 1], - [32, 64, 3, 1]]} - + "vd": [[3, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]] + } + self.stem = nn.Sequential(*[ ConvBNLayer( - num_channels=in_c, - num_filters=out_c, - filter_size=k, - stride=s, - act="relu", - lr_mult=self.lr_mult_list[0]) + num_channels=in_c, + num_filters=out_c, + filter_size=k, + stride=s, + act="relu", + lr_mult=self.lr_mult_list[0]) for in_c, out_c, k, s in self.stem_cfg[version] ]) - + self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1) block_list = [] for block_idx in range(len(self.block_depth)): shortcut = False for i in range(self.block_depth[block_idx]): - block_list.append( - globals()[self.block_type]( - num_channels=self.num_channels[block_idx] - if i == 0 else self.num_filters[block_idx] * self.channels_mult, + block_list.append(globals()[self.block_type]( + num_channels=self.num_channels[block_idx] if i == 0 else + self.num_filters[block_idx] * self.channels_mult, num_filters=self.num_filters[block_idx], stride=2 if i == 0 and block_idx != 0 else 1, shortcut=shortcut, if_first=block_idx == i == 0 if version == "vd" else True, lr_mult=self.lr_mult_list[block_idx + 1])) - shortcut = True + shortcut = True self.blocks = nn.Sequential(*block_list) self.avg_pool = AdaptiveAvgPool2D(1) @@ -297,8 +317,7 @@ class ResNet(TheseusLayer): self.fc = Linear( self.avg_pool_channels, self.class_num, - weight_attr=ParamAttr( - initializer=Uniform(-stdv, stdv))) + weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv))) def forward(self, x): x = self.stem(x) @@ -310,254 +329,179 @@ class ResNet(TheseusLayer): return x -def ResNet18(**args): +def _load_pretrained(pretrained, model, model_url, use_ssld): + if pretrained is False: + pass + elif pretrained is True: + load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld) + elif isinstance(pretrained, str): + load_dygraph_pretrain(model, pretrained) + else: + raise RuntimeError( + "pretrained type is not available. Please use `string` or `boolean` type." + ) + + +def ResNet18(pretrained=False, use_ssld=False, **kwargs): """ ResNet18 Args: - kwargs: - class_num: int=1000. Output dim of last fc layer. - lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages. - pretrained: bool or str, default: bool=False. Whether to load the pretrained model. + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. Returns: model: nn.Layer. Specific `ResNet18` model depends on args. """ - model = ResNet(config=NET_CONFIG["18"], version="vb", **args) - if isinstance(model.pretrained, bool): - if model.pretrained is True: - load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet18"]) - elif isinstance(model.pretrained, str): - load_dygraph_pretrain(model, model.pretrained) - else: - raise RuntimeError( - "pretrained type is not available. Please use `string` or `boolean` type") + model = ResNet(config=NET_CONFIG["18"], version="vb", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet18"], use_ssld) return model -def ResNet18_vd(**args): +def ResNet18_vd(pretrained=False, use_ssld=False, **kwargs): """ ResNet18_vd Args: - kwargs: - class_num: int=1000. Output dim of last fc layer. - lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages. - pretrained: bool or str, default: bool=False. Whether to load the pretrained model. + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. Returns: model: nn.Layer. Specific `ResNet18_vd` model depends on args. """ - model = ResNet(config=NET_CONFIG["18"], version="vd", **args) - if isinstance(model.pretrained, bool): - if model.pretrained is True: - load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet18_vd"]) - elif isinstance(model.pretrained, str): - load_dygraph_pretrain(model, model.pretrained) - else: - raise RuntimeError( - "pretrained type is not available. Please use `string` or `boolean` type") + model = ResNet(config=NET_CONFIG["18"], version="vd", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet18_vd"], use_ssld) return model -def ResNet34(**args): +def ResNet34(pretrained=False, use_ssld=False, **kwargs): """ ResNet34 Args: - kwargs: - class_num: int=1000. Output dim of last fc layer. - lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages. - pretrained: bool or str, default: bool=False. Whether to load the pretrained model. + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. Returns: model: nn.Layer. Specific `ResNet34` model depends on args. """ - model = ResNet(config=NET_CONFIG["34"], version="vb", **args) - if isinstance(model.pretrained, bool): - if model.pretrained is True: - load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet34"]) - elif isinstance(model.pretrained, str): - load_dygraph_pretrain(model, model.pretrained) - else: - raise RuntimeError( - "pretrained type is not available. Please use `string` or `boolean` type") + model = ResNet(config=NET_CONFIG["34"], version="vb", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet34"], use_ssld) return model -def ResNet34_vd(**args): +def ResNet34_vd(pretrained=False, use_ssld=False, **kwargs): """ ResNet34_vd Args: - kwargs: - class_num: int=1000. Output dim of last fc layer. - lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages. - pretrained: bool or str, default: bool=False. Whether to load the pretrained model. + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. Returns: model: nn.Layer. Specific `ResNet34_vd` model depends on args. """ - model = ResNet(config=NET_CONFIG["34"], version="vd", **args) - if isinstance(model.pretrained, bool): - if model.pretrained is True: - load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet34_vd"], use_ssld=True) - elif isinstance(model.pretrained, str): - load_dygraph_pretrain(model, model.pretrained) - else: - raise RuntimeError( - "pretrained type is not available. Please use `string` or `boolean` type") + model = ResNet(config=NET_CONFIG["34"], version="vd", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet34_vd"], use_ssld) return model -def ResNet50(**args): +def ResNet50(pretrained=False, use_ssld=False, **kwargs): """ ResNet50 Args: - kwargs: - class_num: int=1000. Output dim of last fc layer. - lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages. - pretrained: bool or str, default: bool=False. Whether to load the pretrained model. + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. Returns: model: nn.Layer. Specific `ResNet50` model depends on args. """ - model = ResNet(config=NET_CONFIG["50"], version="vb", **args) - if isinstance(model.pretrained, bool): - if model.pretrained is True: - load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet50"]) - elif isinstance(model.pretrained, str): - load_dygraph_pretrain(model, model.pretrained) - else: - raise RuntimeError( - "pretrained type is not available. Please use `string` or `boolean` type") + model = ResNet(config=NET_CONFIG["50"], version="vb", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld) return model -def ResNet50_vd(**args): +def ResNet50_vd(pretrained=False, use_ssld=False, **kwargs): """ ResNet50_vd Args: - kwargs: - class_num: int=1000. Output dim of last fc layer. - lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages. - pretrained: bool or str, default: bool=False. Whether to load the pretrained model. + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. Returns: model: nn.Layer. Specific `ResNet50_vd` model depends on args. """ - model = ResNet(config=NET_CONFIG["50"], version="vd", **args) - if isinstance(model.pretrained, bool): - if model.pretrained is True: - load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet50_vd"], use_ssld=True) - elif isinstance(model.pretrained, str): - load_dygraph_pretrain(model, model.pretrained) - else: - raise RuntimeError( - "pretrained type is not available. Please use `string` or `boolean` type") + model = ResNet(config=NET_CONFIG["50"], version="vd", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet50_vd"], use_ssld) return model -def ResNet101(**args): +def ResNet101(pretrained=False, use_ssld=False, **kwargs): """ ResNet101 Args: - kwargs: - class_num: int=1000. Output dim of last fc layer. - lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages. - pretrained: bool=False. Whether to load the pretrained model. + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. Returns: model: nn.Layer. Specific `ResNet101` model depends on args. """ - model = ResNet(config=NET_CONFIG["101"], version="vb", **args) - if isinstance(model.pretrained, bool): - if model.pretrained is True: - load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet101"]) - elif isinstance(model.pretrained, str): - load_dygraph_pretrain(model, model.pretrained) - else: - raise RuntimeError( - "pretrained type is not available. Please use `string` or `boolean` type") + model = ResNet(config=NET_CONFIG["101"], version="vb", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet101"], use_ssld) return model -def ResNet101_vd(**args): +def ResNet101_vd(pretrained=False, use_ssld=False, **kwargs): """ ResNet101_vd Args: - kwargs: - class_num: int=1000. Output dim of last fc layer. - lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages. - pretrained: bool or str, default: bool=False. Whether to load the pretrained model. + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. Returns: model: nn.Layer. Specific `ResNet101_vd` model depends on args. """ - model = ResNet(config=NET_CONFIG["101"], version="vd", **args) - if isinstance(model.pretrained, bool): - if model.pretrained is True: - load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet101_vd"], use_ssld=True) - elif isinstance(model.pretrained, str): - load_dygraph_pretrain(model, model.pretrained) - else: - raise RuntimeError( - "pretrained type is not available. Please use `string` or `boolean` type") + model = ResNet(config=NET_CONFIG["101"], version="vd", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet101_vd"], use_ssld) return model -def ResNet152(**args): +def ResNet152(pretrained=False, use_ssld=False, **kwargs): """ ResNet152 Args: - kwargs: - class_num: int=1000. Output dim of last fc layer. - lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages. - pretrained: bool or str, default: bool=False. Whether to load the pretrained model. + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. Returns: model: nn.Layer. Specific `ResNet152` model depends on args. """ - model = ResNet(config=NET_CONFIG["152"], version="vb", **args) - if isinstance(model.pretrained, bool): - if model.pretrained is True: - load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet152"]) - elif isinstance(model.pretrained, str): - load_dygraph_pretrain(model, model.pretrained) - else: - raise RuntimeError( - "pretrained type is not available. Please use `string` or `boolean` type") + model = ResNet(config=NET_CONFIG["152"], version="vb", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet152"], use_ssld) return model -def ResNet152_vd(**args): +def ResNet152_vd(pretrained=False, use_ssld=False, **kwargs): """ ResNet152_vd Args: - kwargs: - class_num: int=1000. Output dim of last fc layer. - lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages. - pretrained: bool or str, default: bool=False. Whether to load the pretrained model. + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. Returns: model: nn.Layer. Specific `ResNet152_vd` model depends on args. """ - model = ResNet(config=NET_CONFIG["152"], version="vd", **args) - if isinstance(model.pretrained, bool): - if model.pretrained is True: - load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet152_vd"]) - elif isinstance(model.pretrained, str): - load_dygraph_pretrain(model, model.pretrained) - else: - raise RuntimeError( - "pretrained type is not available. Please use `string` or `boolean` type") + model = ResNet(config=NET_CONFIG["152"], version="vd", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet152_vd"], use_ssld) return model -def ResNet200_vd(**args): +def ResNet200_vd(pretrained=False, use_ssld=False, **kwargs): """ ResNet200_vd Args: - kwargs: - class_num: int=1000. Output dim of last fc layer. - lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages. - pretrained: bool or str, default: bool=False. Whether to load the pretrained model. + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. Returns: model: nn.Layer. Specific `ResNet200_vd` model depends on args. """ - model = ResNet(config=NET_CONFIG["200"], version="vd", **args) - if isinstance(model.pretrained, bool): - if model.pretrained is True: - load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet200_vd"], use_ssld=True) - elif isinstance(model.pretrained, str): - load_dygraph_pretrain(model, model.pretrained) - else: - raise RuntimeError( - "pretrained type is not available. Please use `string` or `boolean` type") + model = ResNet(config=NET_CONFIG["200"], version="vd", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet200_vd"], use_ssld) return model diff --git a/ppcls/arch/backbone/legendary_models/vgg.py b/ppcls/arch/backbone/legendary_models/vgg.py index b127dd2747f35b39ec485403edc6f18ab869da59..7868b51eafce4f0bd383ad66199e50f2a05c1832 100644 --- a/ppcls/arch/backbone/legendary_models/vgg.py +++ b/ppcls/arch/backbone/legendary_models/vgg.py @@ -14,16 +14,24 @@ from __future__ import absolute_import, division, print_function -import paddle -from paddle import ParamAttr import paddle.nn as nn from paddle.nn import Conv2D, BatchNorm, Linear, Dropout from paddle.nn import MaxPool2D from ppcls.arch.backbone.base.theseus_layer import TheseusLayer -from ppcls.utils.save_load import load_dygraph_pretrain - -__all__ = ["VGG11", "VGG13", "VGG16", "VGG19"] +from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url + +MODEL_URLS = { + "VGG11": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/VGG11_pretrained.pdparams", + "VGG13": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/VGG13_pretrained.pdparams", + "VGG16": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/VGG16_pretrained.pdparams", + "VGG19": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/VGG19_pretrained.pdparams", +} +__all__ = MODEL_URLS.keys() # VGG config # key: VGG network depth @@ -36,68 +44,12 @@ NET_CONFIG = { } -def VGG11(**args): - """ - VGG11 - Args: - kwargs: - class_num: int=1000. Output dim of last fc layer. - stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False` - Returns: - model: nn.Layer. Specific `VGG11` model depends on args. - """ - model = VGGNet(config=NET_CONFIG[11], **args) - return model - - -def VGG13(**args): - """ - VGG13 - Args: - kwargs: - class_num: int=1000. Output dim of last fc layer. - stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False` - Returns: - model: nn.Layer. Specific `VGG11` model depends on args. - """ - model = VGGNet(config=NET_CONFIG[13], **args) - return model - - -def VGG16(**args): - """ - VGG16 - Args: - kwargs: - class_num: int=1000. Output dim of last fc layer. - stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False` - Returns: - model: nn.Layer. Specific `VGG11` model depends on args. - """ - model = VGGNet(config=NET_CONFIG[16], **args) - return model - - -def VGG19(**args): - """ - VGG19 - Args: - kwargs: - class_num: int=1000. Output dim of last fc layer. - stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False` - Returns: - model: nn.Layer. Specific `VGG11` model depends on args. - """ - model = VGGNet(config=NET_CONFIG[19], **args) - return model - - class ConvBlock(TheseusLayer): def __init__(self, input_channels, output_channels, groups): - super(ConvBlock, self).__init__() + super().__init__() self.groups = groups - self._conv_1 = Conv2D( + self.conv1 = Conv2D( in_channels=input_channels, out_channels=output_channels, kernel_size=3, @@ -105,7 +57,7 @@ class ConvBlock(TheseusLayer): padding=1, bias_attr=False) if groups == 2 or groups == 3 or groups == 4: - self._conv_2 = Conv2D( + self.conv2 = Conv2D( in_channels=output_channels, out_channels=output_channels, kernel_size=3, @@ -113,7 +65,7 @@ class ConvBlock(TheseusLayer): padding=1, bias_attr=False) if groups == 3 or groups == 4: - self._conv_3 = Conv2D( + self.conv3 = Conv2D( in_channels=output_channels, out_channels=output_channels, kernel_size=3, @@ -121,7 +73,7 @@ class ConvBlock(TheseusLayer): padding=1, bias_attr=False) if groups == 4: - self._conv_4 = Conv2D( + self.conv4 = Conv2D( in_channels=output_channels, out_channels=output_channels, kernel_size=3, @@ -129,73 +81,148 @@ class ConvBlock(TheseusLayer): padding=1, bias_attr=False) - self._pool = MaxPool2D(kernel_size=2, stride=2, padding=0) - self._relu = nn.ReLU() + self.max_pool = MaxPool2D(kernel_size=2, stride=2, padding=0) + self.relu = nn.ReLU() def forward(self, inputs): - x = self._conv_1(inputs) - x = self._relu(x) + x = self.conv1(inputs) + x = self.relu(x) if self.groups == 2 or self.groups == 3 or self.groups == 4: - x = self._conv_2(x) - x = self._relu(x) + x = self.conv2(x) + x = self.relu(x) if self.groups == 3 or self.groups == 4: - x = self._conv_3(x) - x = self._relu(x) + x = self.conv3(x) + x = self.relu(x) if self.groups == 4: - x = self._conv_4(x) - x = self._relu(x) - x = self._pool(x) + x = self.conv4(x) + x = self.relu(x) + x = self.max_pool(x) return x class VGGNet(TheseusLayer): - def __init__(self, - config, - stop_grad_layers=0, - class_num=1000, - pretrained=False, - **args): + """ + VGGNet + Args: + config: list. VGGNet config. + stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False` + class_num: int=1000. The number of classes. + Returns: + model: nn.Layer. Specific VGG model depends on args. + """ + + def __init__(self, config, stop_grad_layers=0, class_num=1000): super().__init__() self.stop_grad_layers = stop_grad_layers - self._conv_block_1 = ConvBlock(3, 64, config[0]) - self._conv_block_2 = ConvBlock(64, 128, config[1]) - self._conv_block_3 = ConvBlock(128, 256, config[2]) - self._conv_block_4 = ConvBlock(256, 512, config[3]) - self._conv_block_5 = ConvBlock(512, 512, config[4]) + self.conv_block_1 = ConvBlock(3, 64, config[0]) + self.conv_block_2 = ConvBlock(64, 128, config[1]) + self.conv_block_3 = ConvBlock(128, 256, config[2]) + self.conv_block_4 = ConvBlock(256, 512, config[3]) + self.conv_block_5 = ConvBlock(512, 512, config[4]) - self._relu = nn.ReLU() - self._flatten = nn.Flatten(start_axis=1, stop_axis=-1) + self.relu = nn.ReLU() + self.flatten = nn.Flatten(start_axis=1, stop_axis=-1) for idx, block in enumerate([ - self._conv_block_1, self._conv_block_2, self._conv_block_3, - self._conv_block_4, self._conv_block_5 + self.conv_block_1, self.conv_block_2, self.conv_block_3, + self.conv_block_4, self.conv_block_5 ]): if self.stop_grad_layers >= idx + 1: for param in block.parameters(): param.trainable = False - self._drop = Dropout(p=0.5, mode="downscale_in_infer") - self._fc1 = Linear(7 * 7 * 512, 4096) - self._fc2 = Linear(4096, 4096) - self._out = Linear(4096, class_num) - - if pretrained is not None: - load_dygraph_pretrain(self, pretrained) + self.drop = Dropout(p=0.5, mode="downscale_in_infer") + self.fc1 = Linear(7 * 7 * 512, 4096) + self.fc2 = Linear(4096, 4096) + self.fc3 = Linear(4096, class_num) def forward(self, inputs): - x = self._conv_block_1(inputs) - x = self._conv_block_2(x) - x = self._conv_block_3(x) - x = self._conv_block_4(x) - x = self._conv_block_5(x) - x = self._flatten(x) - x = self._fc1(x) - x = self._relu(x) - x = self._drop(x) - x = self._fc2(x) - x = self._relu(x) - x = self._drop(x) - x = self._out(x) + x = self.conv_block_1(inputs) + x = self.conv_block_2(x) + x = self.conv_block_3(x) + x = self.conv_block_4(x) + x = self.conv_block_5(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + x = self.drop(x) + x = self.fc2(x) + x = self.relu(x) + x = self.drop(x) + x = self.fc3(x) return x + + +def _load_pretrained(pretrained, model, model_url, use_ssld): + if pretrained is False: + pass + elif pretrained is True: + load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld) + elif isinstance(pretrained, str): + load_dygraph_pretrain(model, pretrained) + else: + raise RuntimeError( + "pretrained type is not available. Please use `string` or `boolean` type." + ) + + +def VGG11(pretrained=False, use_ssld=False, **kwargs): + """ + VGG11 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `VGG11` model depends on args. + """ + model = VGGNet(config=NET_CONFIG[11], **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["VGG11"], use_ssld) + return model + + +def VGG13(pretrained=False, use_ssld=False, **kwargs): + """ + VGG13 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `VGG13` model depends on args. + """ + model = VGGNet(config=NET_CONFIG[13], **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["VGG13"], use_ssld) + return model + + +def VGG16(pretrained=False, use_ssld=False, **kwargs): + """ + VGG16 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `VGG16` model depends on args. + """ + model = VGGNet(config=NET_CONFIG[16], **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["VGG16"], use_ssld) + return model + + +def VGG19(pretrained=False, use_ssld=False, **kwargs): + """ + VGG19 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `VGG19` model depends on args. + """ + model = VGGNet(config=NET_CONFIG[19], **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["VGG19"], use_ssld) + return model