diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py
index ea0c791e25e7169f86c709d863f12238a6f1db4f..b43f38f5868cd78a9ba821154c7c6247b7572a0d 100644
--- a/timm/models/efficientnet_blocks.py
+++ b/timm/models/efficientnet_blocks.py
@@ -7,7 +7,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
 
-from .layers import create_conv2d, drop_path, make_divisible, get_act_fn, create_act_layer
+from .layers import create_conv2d, drop_path, make_divisible, create_act_layer
 from .layers.activations import sigmoid
 
 __all__ = [
@@ -19,31 +19,32 @@ class SqueezeExcite(nn.Module):
 
     Args:
         in_chs (int): input channels to layer
-        se_ratio (float): ratio of squeeze reduction
+        rd_ratio (float): ratio of squeeze reduction
         act_layer (nn.Module): activation layer of containing block
-        gate_fn (Callable): attention gate function
+        gate_layer (Callable): attention gate function
         force_act_layer (nn.Module): override block's activation fn if this is set/bound
-        round_chs_fn (Callable): specify a fn to calculate rounding of reduced chs
+        rd_round_fn (Callable): specify a fn to calculate rounding of reduced chs
     """
 
     def __init__(
-            self, in_chs, se_ratio=0.25, act_layer=nn.ReLU, gate_fn=sigmoid,
-            force_act_layer=None, round_chs_fn=None):
+            self, in_chs, rd_ratio=0.25, rd_channels=None, act_layer=nn.ReLU,
+            gate_layer=nn.Sigmoid, force_act_layer=None, rd_round_fn=None):
         super(SqueezeExcite, self).__init__()
-        round_chs_fn = round_chs_fn or round
-        reduced_chs = round_chs_fn(in_chs * se_ratio)
+        if rd_channels is None:
+            rd_round_fn = rd_round_fn or round
+            rd_channels = rd_round_fn(in_chs * rd_ratio)
         act_layer = force_act_layer or act_layer
-        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
+        self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True)
         self.act1 = create_act_layer(act_layer, inplace=True)
-        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
-        self.gate_fn = get_act_fn(gate_fn)
+        self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True)
+        self.gate = create_act_layer(gate_layer)
 
     def forward(self, x):
         x_se = x.mean((2, 3), keepdim=True)
         x_se = self.conv_reduce(x_se)
         x_se = self.act1(x_se)
         x_se = self.conv_expand(x_se)
-        return x * self.gate_fn(x_se)
+        return x * self.gate(x_se)
 
 
 class ConvBnAct(nn.Module):
@@ -85,10 +86,9 @@ class DepthwiseSeparableConv(nn.Module):
     """
    def __init__(
             self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
-            noskip=False, pw_kernel_size=1, pw_act=False, se_ratio=0.,
-            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
+            noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
+            se_layer=None, drop_path_rate=0.):
         super(DepthwiseSeparableConv, self).__init__()
-        has_se = se_layer is not None and se_ratio > 0.
         self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
         self.has_pw_act = pw_act  # activation after point-wise conv
         self.drop_path_rate = drop_path_rate
@@ -99,7 +99,7 @@ class DepthwiseSeparableConv(nn.Module):
         self.act1 = act_layer(inplace=True)
 
         # Squeeze-and-excitation
-        self.se = se_layer(in_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()
+        self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity()
 
         self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
         self.bn2 = norm_layer(out_chs)
@@ -144,12 +144,11 @@ class InvertedResidual(nn.Module):
 
     def __init__(
             self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
-            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
-            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
+            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
         super(InvertedResidual, self).__init__()
         conv_kwargs = conv_kwargs or {}
         mid_chs = make_divisible(in_chs * exp_ratio)
-        has_se = se_layer is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
         self.drop_path_rate = drop_path_rate
 
@@ -166,7 +165,7 @@ class InvertedResidual(nn.Module):
         self.act2 = act_layer(inplace=True)
 
         # Squeeze-and-excitation
-        self.se = se_layer(mid_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()
+        self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
 
         # Point-wise linear projection
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
@@ -212,8 +211,8 @@ class CondConvResidual(InvertedResidual):
 
     def __init__(
             self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
-            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
-            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
+            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
 
         self.num_experts = num_experts
         conv_kwargs = dict(num_experts=self.num_experts)
@@ -221,8 +220,8 @@ class CondConvResidual(InvertedResidual):
 
         super(CondConvResidual, self).__init__(
             in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
             act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
-            pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_layer=se_layer,
-            norm_layer=norm_layer, conv_kwargs=conv_kwargs, drop_path_rate=drop_path_rate)
+            pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs,
+            drop_path_rate=drop_path_rate)
         self.routing_fn = nn.Linear(in_chs, self.num_experts)
 
@@ -271,8 +270,8 @@ class EdgeResidual(nn.Module):
 
     def __init__(
             self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, pad_type='',
-            force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, se_ratio=0.,
-            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
+            force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
         super(EdgeResidual, self).__init__()
         if force_in_chs > 0:
             mid_chs = make_divisible(force_in_chs * exp_ratio)
@@ -289,7 +288,7 @@ class EdgeResidual(nn.Module):
         self.act1 = act_layer(inplace=True)
 
         # Squeeze-and-excitation
-        self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()
+        self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
 
         # Point-wise linear projection
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py
index 35019747a4426c856a43181a1709ae39e4fb482d..f44cf15820cf0b9b617309799dd33fc4e1983888 100644
--- a/timm/models/efficientnet_builder.py
+++ b/timm/models/efficientnet_builder.py
@@ -10,11 +10,12 @@
 import logging
 import math
 import re
 from copy import deepcopy
+from functools import partial
 
 import torch.nn as nn
 
 from .efficientnet_blocks import *
-from .layers import CondConv2d, get_condconv_initializer, get_act_layer, make_divisible
+from .layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible
 
 __all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
            'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
@@ -120,7 +121,9 @@ def _decode_block_str(block_str):
             elif v == 'hs':
                 value = get_act_layer('hard_swish')
             elif v == 'sw':
-                value = get_act_layer('swish')
+                value = get_act_layer('swish')  # aka SiLU
+            elif v == 'mi':
+                value = get_act_layer('mish')
             else:
                 continue
             options[key] = value
@@ -273,7 +276,12 @@ class EfficientNetBuilder:
         self.se_from_exp = se_from_exp  # calculate se channel reduction from expanded (mid) chs
         self.act_layer = act_layer
         self.norm_layer = norm_layer
-        self.se_layer = se_layer
+        self.se_layer = get_attn(se_layer)
+        try:
+            self.se_layer(8, rd_ratio=1.0)
+            self.se_has_ratio = True
+        except TypeError:
+            self.se_has_ratio = False
         self.drop_path_rate = drop_path_rate
         if feature_location == 'depthwise':
             # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
@@ -300,18 +308,21 @@ class EfficientNetBuilder:
         ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
         assert ba['act_layer'] is not None
         ba['norm_layer'] = self.norm_layer
+        ba['drop_path_rate'] = drop_path_rate
         if bt != 'cn':
-            ba['se_layer'] = self.se_layer
-            if not self.se_from_exp and ba['se_ratio']:
-                ba['se_ratio'] /= ba.get('exp_ratio', 1.0)
-            ba['drop_path_rate'] = drop_path_rate
+            se_ratio = ba.pop('se_ratio')
+            if se_ratio and self.se_layer is not None:
+                if not self.se_from_exp:
+                    # adjust se_ratio by expansion ratio if calculating se channels from block input
+                    se_ratio /= ba.get('exp_ratio', 1.0)
+                if self.se_has_ratio:
+                    ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio)
+                else:
+                    ba['se_layer'] = self.se_layer
 
         if bt == 'ir':
             _log_info_if('  InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
-            if ba.get('num_experts', 0) > 0:
-                block = CondConvResidual(**ba)
-            else:
-                block = InvertedResidual(**ba)
+            block = CondConvResidual(**ba) if ba.get('num_experts', 0) else InvertedResidual(**ba)
         elif bt == 'ds' or bt == 'dsa':
             _log_info_if('  DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
             block = DepthwiseSeparableConv(**ba)
diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py
index d82a91b49871efdd0a261ca043d3fadc70136a7a..48dee6ecd121c50c4636b6f54c8d82bbf4c901d0 100644
--- a/timm/models/ghostnet.py
+++ b/timm/models/ghostnet.py
@@ -40,7 +40,7 @@ default_cfgs = {
 }
 
 
-_SE_LAYER = partial(SqueezeExcite, gate_fn='hard_sigmoid', round_chs_fn=partial(make_divisible, divisor=4))
+_SE_LAYER = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=partial(make_divisible, divisor=4))
 
 
 class GhostModule(nn.Module):
@@ -92,7 +92,7 @@ class GhostBottleneck(nn.Module):
             self.bn_dw = None
 
         # Squeeze-and-excitation
-        self.se = _SE_LAYER(mid_chs, se_ratio=se_ratio) if has_se else None
+        self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio) if has_se else None
 
         # Point-wise linear projection
         self.ghost2 = GhostModule(mid_chs, out_chs, relu=False)
diff --git a/timm/models/hardcorenas.py b/timm/models/hardcorenas.py
index 16b9c4bca90f627c0c67352fa47b2df6612079a5..9988a0444558d9e7f4b640ff468cc63b1dc1d7f4 100644
--- a/timm/models/hardcorenas.py
+++ b/timm/models/hardcorenas.py
@@ -39,8 +39,7 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
     """
     num_features = 1280
 
-    se_layer = partial(
-        SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, round_chs_fn=round_channels)
+    se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
     model_kwargs = dict(
         block_args=decode_arch_def(arch_def),
         num_features=num_features,
diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py
index fad88aa742d3938994ad0a42bbc968765581ca69..e85112e6c285bea3d9000b46ff4a8d0f7bfe6f98 100644
--- a/timm/models/mobilenetv3.py
+++ b/timm/models/mobilenetv3.py
@@ -266,7 +266,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
         round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
         norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
         act_layer=resolve_act_layer(kwargs, 'hard_swish'),
-        se_layer=partial(SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid')),
+        se_layer=partial(SqueezeExcite, gate_layer='hard_sigmoid'),
         **kwargs,
     )
     model = _create_mnv3(variant, pretrained, **model_kwargs)
@@ -354,8 +354,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
             # stage 6, 7x7 in
             ['cn_r1_k1_s1_c960'],  # hard-swish
         ]
-    se_layer = partial(
-        SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, round_chs_fn=round_channels)
+    se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
     model_kwargs = dict(
         block_args=decode_arch_def(arch_def),
         num_features=num_features,
@@ -372,67 +371,48 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
 def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
     """ FBNetV3
+    Paper: `FBNetV3: Joint Architecture-Recipe Search using Predictor Pretraining`
+        - https://arxiv.org/abs/2006.02049
 
     FIXME untested, this is a preliminary impl of some FBNet-V3 variants.
""" vl = variant.split('_')[-1] if vl in ('a', 'b'): stem_size = 16 arch_def = [ - # stage 0, 112x112 in ['ds_r2_k3_s1_e1_c16'], - # stage 1, 112x112 in ['ir_r1_k5_s2_e4_c24', 'ir_r3_k5_s1_e2_c24'], - # stage 2, 56x56 in ['ir_r1_k5_s2_e5_c40_se0.25', 'ir_r4_k5_s1_e3_c40_se0.25'], - # stage 3, 28x28 in ['ir_r1_k5_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'], - # stage 4, 14x14in ['ir_r1_k3_s1_e5_c120_se0.25', 'ir_r5_k5_s1_e3_c120_se0.25'], - # stage 5, 14x14in ['ir_r1_k3_s2_e6_c184_se0.25', 'ir_r5_k5_s1_e4_c184_se0.25', 'ir_r1_k5_s1_e6_c224_se0.25'], - # stage 6, 7x7 in ['cn_r1_k1_s1_c1344'], ] elif vl == 'd': stem_size = 24 arch_def = [ - # stage 0, 112x112 in ['ds_r2_k3_s1_e1_c16'], - # stage 1, 112x112 in ['ir_r1_k3_s2_e5_c24', 'ir_r5_k3_s1_e2_c24'], - # stage 2, 56x56 in ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r4_k3_s1_e3_c40_se0.25'], - # stage 3, 28x28 in ['ir_r1_k3_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'], - # stage 4, 14x14in ['ir_r1_k3_s1_e5_c128_se0.25', 'ir_r6_k5_s1_e3_c128_se0.25'], - # stage 5, 14x14in ['ir_r1_k3_s2_e6_c208_se0.25', 'ir_r5_k5_s1_e5_c208_se0.25', 'ir_r1_k5_s1_e6_c240_se0.25'], - # stage 6, 7x7 in ['cn_r1_k1_s1_c1440'], ] elif vl == 'g': stem_size = 32 arch_def = [ - # stage 0, 112x112 in ['ds_r3_k3_s1_e1_c24'], - # stage 1, 112x112 in ['ir_r1_k5_s2_e4_c40', 'ir_r4_k5_s1_e2_c40'], - # stage 2, 56x56 in ['ir_r1_k5_s2_e4_c56_se0.25', 'ir_r4_k5_s1_e3_c56_se0.25'], - # stage 3, 28x28 in ['ir_r1_k5_s2_e5_c104', 'ir_r4_k3_s1_e3_c104'], - # stage 4, 14x14in ['ir_r1_k3_s1_e5_c160_se0.25', 'ir_r8_k5_s1_e3_c160_se0.25'], - # stage 5, 14x14in ['ir_r1_k3_s2_e6_c264_se0.25', 'ir_r6_k5_s1_e5_c264_se0.25', 'ir_r2_k5_s1_e6_c288_se0.25'], - # stage 6, 7x7 in - ['cn_r1_k1_s1_c1728'], # hard-swish + ['cn_r1_k1_s1_c1728'], ] else: raise NotImplemented round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.95) - se_layer = partial(SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), round_chs_fn=round_chs_fn) + se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=round_chs_fn) act_layer = resolve_act_layer(kwargs, 'hard_swish') model_kwargs = dict( block_args=decode_arch_def(arch_def),