diff --git a/docs/zh_CN/models/LeViT.md b/docs/zh_CN/models/LeViT.md new file mode 100644 index 0000000000000000000000000000000000000000..d668a27b09f3151f2a7b2f428f92a3e4c8a3d97e --- /dev/null +++ b/docs/zh_CN/models/LeViT.md @@ -0,0 +1,17 @@ +# LeViT + +## 概述 +LeViT是一种快速推理的、用于图像分类任务的混合神经网络。其设计之初考虑了网络模型在不同的硬件平台上的性能,因此能够更好地反映普遍应用的真实场景。通过大量实验,作者找到了卷积神经网络与Transformer体系更好的结合方式,并且提出了attention-based方法,用于整合Transformer中的位置信息编码。[论文地址](https://arxiv.org/abs/2104.01136)。 + +## 精度、FLOPS和参数量 + +| Models | Top1 | Top5 | Reference
top1 | Reference
top5 | FLOPS
(M) | Params
(M) | +|:--:|:--:|:--:|:--:|:--:|:--:|:--:| +| LeViT-128S | 0.7621 | 0.9277 | 0.766 | 0.929 | 305 | 7.8 | +| LeViT-128 | 0.7833 | 0.9378 | 0.786 | 0.940 | 406 | 9.2 | +| LeViT-192 | 0.7963 | 0.9460 | 0.800 | 0.947 | 658 | 11 | +| LeViT-256 | 0.8085 | 0.9497 | 0.816 | 0.954 | 1120 | 19 | +| LeViT-384 | 0.8234 | 0.9587 | 0.826 | 0.960 | 2353 | 39 | + + +**注**:与Reference的精度差异源于数据预处理不同。 diff --git a/docs/zh_CN/models/Twins.md b/docs/zh_CN/models/Twins.md new file mode 100644 index 0000000000000000000000000000000000000000..424f3985df00216c048e026632c43f9e720f4542 --- /dev/null +++ b/docs/zh_CN/models/Twins.md @@ -0,0 +1,17 @@ +# Twins + +## 概述 +Twins网络包括Twins-PCPVT和Twins-SVT,其重点对空间注意力机制进行了精心设计,得到了简单却更为有效的方案。由于该体系结构仅涉及矩阵乘法,而目前的深度学习框架中对矩阵乘法有较高的优化程度,因此该体系结构十分高效且易于实现。并且,该体系结构在图像分类、目标检测和语义分割等多种下游视觉任务中都能够取得优异的性能。[论文地址](https://arxiv.org/abs/2104.13840)。 + +## 精度、FLOPS和参数量 + +| Models | Top1 | Top5 | Reference
top1 | Reference
top5 | FLOPS
(G) | Params
(M) | +|:--:|:--:|:--:|:--:|:--:|:--:|:--:| +| pcpvt_small | 0.8082 | 0.9552 | 0.812 | - | 3.7 | 24.1 | +| pcpvt_base | 0.8242 | 0.9619 | 0.827 | - | 6.4 | 43.8 | +| pcpvt_large | 0.8273 | 0.9650 | 0.831 | - | 9.5 | 60.9 | +| alt_gvt_small | 0.8140 | 0.9546 | 0.817 | - | 2.8 | 24 | +| alt_gvt_base | 0.8294 | 0.9621 | 0.832 | - | 8.3 | 56 | +| alt_gvt_large | 0.8331 | 0.9642 | 0.837 | - | 14.8 | 99.2 | + +**注**:与Reference的精度差异源于数据预处理不同。 diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py index 256a9950b70e00952d7ae4527ed26afe49325e71..4b5b08af22151266d12ed712d40abc3fc503be10 100644 --- a/ppcls/arch/backbone/__init__.py +++ b/ppcls/arch/backbone/__init__.py @@ -1,4 +1,4 @@ -# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -47,6 +47,8 @@ from ppcls.arch.backbone.model_zoo.distillation_models import ResNet50_vd_distil from ppcls.arch.backbone.model_zoo.swin_transformer import SwinTransformer_tiny_patch4_window7_224, SwinTransformer_small_patch4_window7_224, SwinTransformer_base_patch4_window7_224, SwinTransformer_base_patch4_window12_384, SwinTransformer_large_patch4_window7_224, SwinTransformer_large_patch4_window12_384 from ppcls.arch.backbone.model_zoo.mixnet import MixNet_S, MixNet_M, MixNet_L from ppcls.arch.backbone.model_zoo.rexnet import ReXNet_1_0, ReXNet_1_3, ReXNet_1_5, ReXNet_2_0, ReXNet_3_0 +from ppcls.arch.backbone.model_zoo.gvt import pcpvt_small, pcpvt_base, pcpvt_large, alt_gvt_small, alt_gvt_base, alt_gvt_large +from ppcls.arch.backbone.model_zoo.levit import LeViT_128S, LeViT_128, LeViT_192, LeViT_256, LeViT_384 from ppcls.arch.backbone.model_zoo.dla import DLA34, DLA46_c, DLA46x_c, DLA60, DLA60x, DLA60x_c, DLA102, DLA102x, DLA102x2, DLA169 from ppcls.arch.backbone.model_zoo.rednet import RedNet26, RedNet38, RedNet50, RedNet101, RedNet152 from ppcls.arch.backbone.model_zoo.tnt import TNT_small diff --git a/ppcls/arch/backbone/model_zoo/gvt.py b/ppcls/arch/backbone/model_zoo/gvt.py new file mode 100644 index 0000000000000000000000000000000000000000..bcfdfead63c86ad4e6029edd51c2b955c540a58c --- /dev/null +++ b/ppcls/arch/backbone/model_zoo/gvt.py @@ -0,0 +1,659 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.regularizer import L2Decay + +from .vision_transformer import trunc_normal_, normal_, zeros_, ones_, to_2tuple, DropPath, Identity, Mlp +from .vision_transformer import Block as ViTBlock + +__all__ = [ + "CPVTV2", "PCPVT", "ALTGVT", "pcpvt_small", "pcpvt_base", "pcpvt_large", + "alt_gvt_small", "alt_gvt_base", "alt_gvt_large" +] + + +class GroupAttention(nn.Layer): + """LSA: self attention within a group. + """ + + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + ws=1): + super().__init__() + if ws == 1: + raise Exception(f"ws {ws} should not be 1") + if dim % num_heads != 0: + raise Exception( + f"dim {dim} should be divided by num_heads {num_heads}.") + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, H, W): + B, N, C = x.shape + h_group, w_group = H // self.ws, W // self.ws + total_groups = h_group * w_group + x = x.reshape([B, h_group, self.ws, w_group, self.ws, C]).transpose( + [0, 1, 3, 2, 4, 5]) + qkv = self.qkv(x).reshape( + [B, total_groups, -1, 3, self.num_heads, + C // self.num_heads]).transpose([3, 0, 1, 4, 2, 5]) + q, k, v = qkv[0], qkv[1], qkv[2] + attn = (q @k.transpose([0, 1, 2, 4, 3])) * self.scale + + attn = nn.Softmax(axis=-1)(attn) + attn = self.attn_drop(attn) + attn = (attn @v).transpose([0, 1, 3, 2, 4]).reshape( + [B, h_group, w_group, self.ws, self.ws, C]) + + x = attn.transpose([0, 1, 3, 2, 4, 5]).reshape([B, N, C]) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Attention(nn.Layer): + """GSA: using a key to summarize the information for a group to be efficient. + """ + + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias_attr=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias_attr=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2D( + dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + + def forward(self, x, H, W): + B, N, C = x.shape + q = self.q(x).reshape( + [B, N, self.num_heads, C // self.num_heads]).transpose( + [0, 2, 1, 3]) + + if self.sr_ratio > 1: + x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W]) + x_ = self.sr(x_).reshape([B, C, -1]).transpose([0, 2, 1]) + x_ = self.norm(x_) + kv = self.kv(x_).reshape( + [B, -1, 2, self.num_heads, C // self.num_heads]).transpose( + [2, 0, 3, 1, 4]) + else: + kv = self.kv(x).reshape( + [B, -1, 2, self.num_heads, C // self.num_heads]).transpose( + [2, 0, 3, 1, 4]) + k, v = kv[0], kv[1] + + attn = (q @k.transpose([0, 1, 3, 2])) * self.scale + attn = nn.Softmax(axis=-1)(attn) + attn = self.attn_drop(attn) + + x = (attn @v).transpose([0, 2, 1, 3]).reshape([B, N, C]) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Layer): + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + sr_ratio=sr_ratio) + self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward(self, x, H, W): + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class SBlock(ViTBlock): + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + sr_ratio=1): + super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, + attn_drop, drop_path, act_layer, norm_layer) + + def forward(self, x, H, W): + return super().forward(x) + + +class GroupBlock(ViTBlock): + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + sr_ratio=1, + ws=1): + super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, + attn_drop, drop_path, act_layer, norm_layer) + del self.attn + if ws == 1: + self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, + attn_drop, drop, sr_ratio) + else: + self.attn = GroupAttention(dim, num_heads, qkv_bias, qk_scale, + attn_drop, drop, ws) + + def forward(self, x, H, W): + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Layer): + """ Image to Patch Embedding. + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + if img_size % patch_size != 0: + raise Exception( + f"img_size {img_size} should be divided by patch_size {patch_size}." + ) + + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[ + 1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = nn.Conv2D( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = nn.LayerNorm(embed_dim) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose([0, 2, 1]) + x = self.norm(x) + H, W = H // self.patch_size[0], W // self.patch_size[1] + return x, (H, W) + + +# borrow from PVT https://github.com/whai362/PVT.git +class PyramidVisionTransformer(nn.Layer): + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=False, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=nn.LayerNorm, + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + block_cls=Block): + super().__init__() + self.num_classes = num_classes + self.depths = depths + + # patch_embed + self.patch_embeds = nn.LayerList() + self.pos_embeds = nn.ParameterList() + self.pos_drops = nn.LayerList() + self.blocks = nn.LayerList() + + for i in range(len(depths)): + if i == 0: + self.patch_embeds.append( + PatchEmbed(img_size, patch_size, in_chans, embed_dims[i])) + else: + self.patch_embeds.append( + PatchEmbed(img_size // patch_size // 2**(i - 1), 2, + embed_dims[i - 1], embed_dims[i])) + patch_num = self.patch_embeds[i].num_patches + 1 if i == len( + embed_dims) - 1 else self.patch_embeds[i].num_patches + self.pos_embeds.append( + self.create_parameter( + shape=[1, patch_num, embed_dims[i]], + default_initializer=zeros_)) + self.add_parameter(f"pos_embeds_{i}", self.pos_embeds[i]) + self.pos_drops.append(nn.Dropout(p=drop_rate)) + + dpr = [ + x.numpy()[0] + for x in paddle.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + cur = 0 + for k in range(len(depths)): + _block = nn.LayerList([ + block_cls( + dim=embed_dims[k], + num_heads=num_heads[k], + mlp_ratio=mlp_ratios[k], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[k]) for i in range(depths[k]) + ]) + self.blocks.append(_block) + cur += depths[k] + + self.norm = norm_layer(embed_dims[-1]) + + # cls_token + self.cls_token = self.create_parameter( + shape=[1, 1, embed_dims[-1]], + default_initializer=zeros_, + attr=paddle.ParamAttr(regularizer=L2Decay(0.0))) + self.add_parameter("cls_token", self.cls_token) + + # classification head + self.head = nn.Linear(embed_dims[-1], + num_classes) if num_classes > 0 else Identity() + + # init weights + for pos_emb in self.pos_embeds: + trunc_normal_(pos_emb) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + zeros_(m.bias) + ones_(m.weight) + + def forward_features(self, x): + B = x.shape[0] + for i in range(len(self.depths)): + x, (H, W) = self.patch_embeds[i](x) + if i == len(self.depths) - 1: + cls_tokens = self.cls_token.expand([B, -1, -1]) + x = paddle.concat([cls_tokens, x], dim=1) + x = x + self.pos_embeds[i] + x = self.pos_drops[i](x) + for blk in self.blocks[i]: + x = blk(x, H, W) + if i < len(self.depths) - 1: + x = x.reshape([B, H, W, -1]).transpose( + [0, 3, 1, 2]).contiguous() + x = self.norm(x) + return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +# PEG from https://arxiv.org/abs/2102.10882 +class PosCNN(nn.Layer): + def __init__(self, in_chans, embed_dim=768, s=1): + super().__init__() + self.proj = nn.Sequential( + nn.Conv2D( + in_chans, + embed_dim, + 3, + s, + 1, + bias_attr=paddle.ParamAttr(regularizer=L2Decay(0.0)), + groups=embed_dim, + weight_attr=paddle.ParamAttr(regularizer=L2Decay(0.0)), )) + self.s = s + + def forward(self, x, H, W): + B, N, C = x.shape + feat_token = x + cnn_feat = feat_token.transpose([0, 2, 1]).reshape([B, C, H, W]) + if self.s == 1: + x = self.proj(cnn_feat) + cnn_feat + else: + x = self.proj(cnn_feat) + x = x.flatten(2).transpose([0, 2, 1]) + return x + + +class CPVTV2(PyramidVisionTransformer): + """ + Use useful results from CPVT. PEG and GAP. + Therefore, cls token is no longer required. + PEG is used to encode the absolute position on the fly, which greatly affects the performance when input resolution + changes during the training (such as segmentation, detection) + """ + + def __init__(self, + img_size=224, + patch_size=4, + in_chans=3, + num_classes=1000, + embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=False, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=nn.LayerNorm, + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + block_cls=Block): + super().__init__(img_size, patch_size, in_chans, num_classes, + embed_dims, num_heads, mlp_ratios, qkv_bias, qk_scale, + drop_rate, attn_drop_rate, drop_path_rate, norm_layer, + depths, sr_ratios, block_cls) + del self.pos_embeds + del self.cls_token + self.pos_block = nn.LayerList( + [PosCNN(embed_dim, embed_dim) for embed_dim in embed_dims]) + self.apply(self._init_weights) + + def _init_weights(self, m): + import math + if isinstance(m, nn.Linear): + trunc_normal_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + zeros_(m.bias) + ones_(m.weight) + elif isinstance(m, nn.Conv2D): + fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels + fan_out //= m._groups + normal_(0, math.sqrt(2.0 / fan_out))(m.weight) + if m.bias is not None: + zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2D): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward_features(self, x): + B = x.shape[0] + + for i in range(len(self.depths)): + x, (H, W) = self.patch_embeds[i](x) + x = self.pos_drops[i](x) + + for j, blk in enumerate(self.blocks[i]): + x = blk(x, H, W) + if j == 0: + x = self.pos_block[i](x, H, W) # PEG here + + if i < len(self.depths) - 1: + x = x.reshape([B, H, W, -1]).transpose([0, 3, 1, 2]) + + x = self.norm(x) + return x.mean(axis=1) # GAP here + + +class PCPVT(CPVTV2): + def __init__(self, + img_size=224, + patch_size=4, + in_chans=3, + num_classes=1000, + embed_dims=[64, 128, 256], + num_heads=[1, 2, 4], + mlp_ratios=[4, 4, 4], + qkv_bias=False, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=nn.LayerNorm, + depths=[4, 4, 4], + sr_ratios=[4, 2, 1], + block_cls=SBlock): + super().__init__(img_size, patch_size, in_chans, num_classes, + embed_dims, num_heads, mlp_ratios, qkv_bias, qk_scale, + drop_rate, attn_drop_rate, drop_path_rate, norm_layer, + depths, sr_ratios, block_cls) + + +class ALTGVT(PCPVT): + """ + alias Twins-SVT + """ + + def __init__(self, + img_size=224, + patch_size=4, + in_chans=3, + class_dim=1000, + embed_dims=[64, 128, 256], + num_heads=[1, 2, 4], + mlp_ratios=[4, 4, 4], + qkv_bias=False, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=nn.LayerNorm, + depths=[4, 4, 4], + sr_ratios=[4, 2, 1], + block_cls=GroupBlock, + wss=[7, 7, 7]): + super().__init__(img_size, patch_size, in_chans, class_dim, embed_dims, + num_heads, mlp_ratios, qkv_bias, qk_scale, drop_rate, + attn_drop_rate, drop_path_rate, norm_layer, depths, + sr_ratios, block_cls) + del self.blocks + self.wss = wss + # transformer encoder + dpr = [ + x.numpy()[0] + for x in paddle.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + cur = 0 + self.blocks = nn.LayerList() + for k in range(len(depths)): + _block = nn.LayerList([ + block_cls( + dim=embed_dims[k], + num_heads=num_heads[k], + mlp_ratio=mlp_ratios[k], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[k], + ws=1 if i % 2 == 1 else wss[k]) for i in range(depths[k]) + ]) + self.blocks.append(_block) + cur += depths[k] + self.apply(self._init_weights) + + +def pcpvt_small(pretrained=False, **kwargs): + model = CPVTV2( + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[8, 8, 4, 4], + qkv_bias=True, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + **kwargs) + + return model + + +def pcpvt_base(pretrained=False, **kwargs): + model = CPVTV2( + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[8, 8, 4, 4], + qkv_bias=True, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + depths=[3, 4, 18, 3], + sr_ratios=[8, 4, 2, 1], + **kwargs) + + return model + + +def pcpvt_large(pretrained=False, **kwargs): + model = CPVTV2( + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[8, 8, 4, 4], + qkv_bias=True, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + depths=[3, 8, 27, 3], + sr_ratios=[8, 4, 2, 1], + **kwargs) + + return model + + +def alt_gvt_small(pretrained=False, **kwargs): + model = ALTGVT( + patch_size=4, + embed_dims=[64, 128, 256, 512], + num_heads=[2, 4, 8, 16], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + depths=[2, 2, 10, 4], + wss=[7, 7, 7, 7], + sr_ratios=[8, 4, 2, 1], + **kwargs) + + return model + + +def alt_gvt_base(pretrained=False, **args): + model = ALTGVT( + patch_size=4, + embed_dims=[96, 192, 384, 768], + num_heads=[3, 6, 12, 24], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + depths=[2, 2, 18, 2], + wss=[7, 7, 7, 7], + sr_ratios=[8, 4, 2, 1], + **args) + + return model + + +def alt_gvt_large(pretrained=False, **kwargs): + model = ALTGVT( + patch_size=4, + embed_dims=[128, 256, 512, 1024], + num_heads=[4, 8, 16, 32], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + depths=[2, 2, 18, 2], + wss=[7, 7, 7, 7], + sr_ratios=[8, 4, 2, 1], + **kwargs) + + return model diff --git a/ppcls/arch/backbone/model_zoo/levit.py b/ppcls/arch/backbone/model_zoo/levit.py new file mode 100644 index 0000000000000000000000000000000000000000..459afcd6f13c43b87156a81fe0913efa437b565c --- /dev/null +++ b/ppcls/arch/backbone/model_zoo/levit.py @@ -0,0 +1,515 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import math +import warnings + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn.initializer import TruncatedNormal, Constant +from paddle.regularizer import L2Decay + +from .vision_transformer import trunc_normal_, zeros_, ones_, Identity + +__all__ = ["LeViT_128S", "LeViT_128", "LeViT_192", "LeViT_256", "LeViT_384"] + + +def cal_attention_biases(attention_biases, attention_bias_idxs): + gather_list = [] + attention_bias_t = paddle.transpose(attention_biases, (1, 0)) + for idx in attention_bias_idxs: + gather = paddle.gather(attention_bias_t, idx) + gather_list.append(gather) + shape0, shape1 = attention_bias_idxs.shape + return paddle.transpose(paddle.concat(gather_list), (1, 0)).reshape( + (0, shape0, shape1)) + + +class Conv2d_BN(nn.Sequential): + def __init__(self, + a, + b, + ks=1, + stride=1, + pad=0, + dilation=1, + groups=1, + bn_weight_init=1, + resolution=-10000): + super().__init__() + self.add_sublayer( + 'c', + nn.Conv2D( + a, b, ks, stride, pad, dilation, groups, bias_attr=False)) + bn = nn.BatchNorm2D(b) + ones_(bn.weight) + zeros_(bn.bias) + self.add_sublayer('bn', bn) + + +class Linear_BN(nn.Sequential): + def __init__(self, a, b, bn_weight_init=1): + super().__init__() + self.add_sublayer('c', nn.Linear(a, b, bias_attr=False)) + bn = nn.BatchNorm1D(b) + ones_(bn.weight) + zeros_(bn.bias) + self.add_sublayer('bn', bn) + + def forward(self, x): + l, bn = self._sub_layers.values() + x = l(x) + return paddle.reshape(bn(x.flatten(0, 1)), x.shape) + + +class BN_Linear(nn.Sequential): + def __init__(self, a, b, bias=True, std=0.02): + super().__init__() + self.add_sublayer('bn', nn.BatchNorm1D(a)) + l = nn.Linear(a, b, bias_attr=bias) + trunc_normal_(l.weight) + if bias: + zeros_(l.bias) + self.add_sublayer('l', l) + + +def b16(n, activation, resolution=224): + return nn.Sequential( + Conv2d_BN( + 3, n // 8, 3, 2, 1, resolution=resolution), + activation(), + Conv2d_BN( + n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), + activation(), + Conv2d_BN( + n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), + activation(), + Conv2d_BN( + n // 2, n, 3, 2, 1, resolution=resolution // 8)) + + +class Residual(nn.Layer): + def __init__(self, m, drop): + super().__init__() + self.m = m + self.drop = drop + + def forward(self, x): + if self.training and self.drop > 0: + return x + self.m(x) * paddle.rand( + x.size(0), 1, 1, + device=x.device).ge_(self.drop).div(1 - self.drop).detach() + else: + return x + self.m(x) + + +class Attention(nn.Layer): + def __init__(self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + activation=None, + resolution=14): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + self.h = self.dh + nh_kd * 2 + self.qkv = Linear_BN(dim, self.h) + self.proj = nn.Sequential( + activation(), Linear_BN( + self.dh, dim, bn_weight_init=0)) + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = self.create_parameter( + shape=(num_heads, len(attention_offsets)), + default_initializer=zeros_, + attr=paddle.ParamAttr(regularizer=L2Decay(0.0))) + tensor_idxs = paddle.to_tensor(idxs, dtype='int64') + self.register_buffer('attention_bias_idxs', + paddle.reshape(tensor_idxs, [N, N])) + + @paddle.no_grad() + def train(self, mode=True): + if mode: + super().train() + else: + super().eval() + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = cal_attention_biases(self.attention_biases, + self.attention_bias_idxs) + + def forward(self, x): + self.training = True + B, N, C = x.shape + qkv = self.qkv(x) + qkv = paddle.reshape(qkv, + [B, N, self.num_heads, self.h // self.num_heads]) + q, k, v = paddle.split( + qkv, [self.key_dim, self.key_dim, self.d], axis=3) + q = paddle.transpose(q, perm=[0, 2, 1, 3]) + k = paddle.transpose(k, perm=[0, 2, 1, 3]) + v = paddle.transpose(v, perm=[0, 2, 1, 3]) + k_transpose = paddle.transpose(k, perm=[0, 1, 3, 2]) + + if self.training: + attention_biases = cal_attention_biases(self.attention_biases, + self.attention_bias_idxs) + else: + attention_biases = self.ab + attn = ((q @k_transpose) * self.scale + attention_biases) + attn = F.softmax(attn) + x = paddle.transpose(attn @v, perm=[0, 2, 1, 3]) + x = paddle.reshape(x, [B, N, self.dh]) + x = self.proj(x) + return x + + +class Subsample(nn.Layer): + def __init__(self, stride, resolution): + super().__init__() + self.stride = stride + self.resolution = resolution + + def forward(self, x): + B, N, C = x.shape + x = paddle.reshape(x, [B, self.resolution, self.resolution, + C])[:, ::self.stride, ::self.stride] + x = paddle.reshape(x, [B, -1, C]) + return x + + +class AttentionSubsample(nn.Layer): + def __init__(self, + in_dim, + out_dim, + key_dim, + num_heads=8, + attn_ratio=2, + activation=None, + stride=2, + resolution=14, + resolution_=7): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * self.num_heads + self.attn_ratio = attn_ratio + self.resolution_ = resolution_ + self.resolution_2 = resolution_**2 + self.training = True + h = self.dh + nh_kd + self.kv = Linear_BN(in_dim, h) + + self.q = nn.Sequential( + Subsample(stride, resolution), Linear_BN(in_dim, nh_kd)) + self.proj = nn.Sequential(activation(), Linear_BN(self.dh, out_dim)) + + self.stride = stride + self.resolution = resolution + points = list(itertools.product(range(resolution), range(resolution))) + points_ = list( + itertools.product(range(resolution_), range(resolution_))) + + N = len(points) + N_ = len(points_) + attention_offsets = {} + idxs = [] + i = 0 + j = 0 + for p1 in points_: + i += 1 + for p2 in points: + j += 1 + size = 1 + offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2), + abs(p1[1] * stride - p2[1] + (size - 1) / 2)) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = self.create_parameter( + shape=(num_heads, len(attention_offsets)), + default_initializer=zeros_, + attr=paddle.ParamAttr(regularizer=L2Decay(0.0))) + + tensor_idxs_ = paddle.to_tensor(idxs, dtype='int64') + self.register_buffer('attention_bias_idxs', + paddle.reshape(tensor_idxs_, [N_, N])) + + @paddle.no_grad() + def train(self, mode=True): + if mode: + super().train() + else: + super().eval() + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = cal_attention_biases(self.attention_biases, + self.attention_bias_idxs) + + def forward(self, x): + self.training = True + B, N, C = x.shape + kv = self.kv(x) + kv = paddle.reshape(kv, [B, N, self.num_heads, -1]) + k, v = paddle.split(kv, [self.key_dim, self.d], axis=3) + k = paddle.transpose(k, perm=[0, 2, 1, 3]) # BHNC + v = paddle.transpose(v, perm=[0, 2, 1, 3]) + q = paddle.reshape( + self.q(x), [B, self.resolution_2, self.num_heads, self.key_dim]) + q = paddle.transpose(q, perm=[0, 2, 1, 3]) + + if self.training: + attention_biases = cal_attention_biases(self.attention_biases, + self.attention_bias_idxs) + else: + attention_biases = self.ab + + attn = (q @paddle.transpose( + k, perm=[0, 1, 3, 2])) * self.scale + attention_biases + attn = F.softmax(attn) + + x = paddle.reshape( + paddle.transpose( + (attn @v), perm=[0, 2, 1, 3]), [B, -1, self.dh]) + x = self.proj(x) + return x + + +class LeViT(nn.Layer): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + class_dim=1000, + embed_dim=[192], + key_dim=[64], + depth=[12], + num_heads=[3], + attn_ratio=[2], + mlp_ratio=[2], + hybrid_backbone=None, + down_ops=[], + attention_activation=nn.Hardswish, + mlp_activation=nn.Hardswish, + distillation=True, + drop_path=0): + super().__init__() + + self.class_dim = class_dim + self.num_features = embed_dim[-1] + self.embed_dim = embed_dim + self.distillation = distillation + + self.patch_embed = hybrid_backbone + + self.blocks = [] + down_ops.append(['']) + resolution = img_size // patch_size + for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( + zip(embed_dim, key_dim, depth, num_heads, attn_ratio, + mlp_ratio, down_ops)): + for _ in range(dpth): + self.blocks.append( + Residual( + Attention( + ed, + kd, + nh, + attn_ratio=ar, + activation=attention_activation, + resolution=resolution, ), + drop_path)) + if mr > 0: + h = int(ed * mr) + self.blocks.append( + Residual( + nn.Sequential( + Linear_BN(ed, h), + mlp_activation(), + Linear_BN( + h, ed, bn_weight_init=0), ), + drop_path)) + if do[0] == 'Subsample': + #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + resolution_ = (resolution - 1) // do[5] + 1 + self.blocks.append( + AttentionSubsample( + *embed_dim[i:i + 2], + key_dim=do[1], + num_heads=do[2], + attn_ratio=do[3], + activation=attention_activation, + stride=do[5], + resolution=resolution, + resolution_=resolution_)) + resolution = resolution_ + if do[4] > 0: # mlp_ratio + h = int(embed_dim[i + 1] * do[4]) + self.blocks.append( + Residual( + nn.Sequential( + Linear_BN(embed_dim[i + 1], h), + mlp_activation(), + Linear_BN( + h, embed_dim[i + 1], bn_weight_init=0), ), + drop_path)) + self.blocks = nn.Sequential(*self.blocks) + + # Classifier head + self.head = BN_Linear(embed_dim[-1], + class_dim) if class_dim > 0 else Identity() + if distillation: + self.head_dist = BN_Linear( + embed_dim[-1], class_dim) if class_dim > 0 else Identity() + + def forward(self, x): + x = self.patch_embed(x) + x = x.flatten(2) + x = paddle.transpose(x, perm=[0, 2, 1]) + x = self.blocks(x) + x = x.mean(1) + if self.distillation: + x = self.head(x), self.head_dist(x) + if not self.training: + x = (x[0] + x[1]) / 2 + else: + x = self.head(x) + return x + + +def model_factory(C, D, X, N, drop_path, class_dim, distillation): + embed_dim = [int(x) for x in C.split('_')] + num_heads = [int(x) for x in N.split('_')] + depth = [int(x) for x in X.split('_')] + act = nn.Hardswish + model = LeViT( + patch_size=16, + embed_dim=embed_dim, + num_heads=num_heads, + key_dim=[D] * 3, + depth=depth, + attn_ratio=[2, 2, 2], + mlp_ratio=[2, 2, 2], + down_ops=[ + #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + ['Subsample', D, embed_dim[0] // D, 4, 2, 2], + ['Subsample', D, embed_dim[1] // D, 4, 2, 2], + ], + attention_activation=act, + mlp_activation=act, + hybrid_backbone=b16(embed_dim[0], activation=act), + class_dim=class_dim, + drop_path=drop_path, + distillation=distillation) + + return model + + +specification = { + 'LeViT_128S': { + 'C': '128_256_384', + 'D': 16, + 'N': '4_6_8', + 'X': '2_3_4', + 'drop_path': 0 + }, + 'LeViT_128': { + 'C': '128_256_384', + 'D': 16, + 'N': '4_8_12', + 'X': '4_4_4', + 'drop_path': 0 + }, + 'LeViT_192': { + 'C': '192_288_384', + 'D': 32, + 'N': '3_5_6', + 'X': '4_4_4', + 'drop_path': 0 + }, + 'LeViT_256': { + 'C': '256_384_512', + 'D': 32, + 'N': '4_6_8', + 'X': '4_4_4', + 'drop_path': 0 + }, + 'LeViT_384': { + 'C': '384_512_768', + 'D': 32, + 'N': '6_9_12', + 'X': '4_4_4', + 'drop_path': 0.1 + }, +} + + +def LeViT_128S(class_dim=1000, distillation=True, pretrained=False): + return model_factory( + **specification['LeViT_128S'], + class_dim=class_dim, + distillation=distillation) + + +def LeViT_128(class_dim=1000, distillation=True): + return model_factory( + **specification['LeViT_128'], + class_dim=class_dim, + distillation=distillation) + + +def LeViT_192(class_dim=1000, distillation=True): + return model_factory( + **specification['LeViT_192'], + class_dim=class_dim, + distillation=distillation) + + +def LeViT_256(class_dim=1000, distillation=False): + return model_factory( + **specification['LeViT_256'], + class_dim=class_dim, + distillation=distillation) + + +def LeViT_384(class_dim=1000, distillation=True): + return model_factory( + **specification['LeViT_384'], + class_dim=class_dim, + distillation=distillation) diff --git a/ppcls/arch/backbone/model_zoo/vision_transformer.py b/ppcls/arch/backbone/model_zoo/vision_transformer.py index 32f198913d59014326957cc6fe7f9b325a59ef28..22cc0ad2790a15aeef4a34c107b8eb27308aebb4 100644 --- a/ppcls/arch/backbone/model_zoo/vision_transformer.py +++ b/ppcls/arch/backbone/model_zoo/vision_transformer.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import Callable + import numpy as np import paddle import paddle.nn as nn -from paddle.nn.initializer import TruncatedNormal, Constant +from paddle.nn.initializer import TruncatedNormal, Constant, Normal __all__ = [ "VisionTransformer", "ViT_small_patch16_224", "ViT_base_patch16_224", @@ -25,6 +27,7 @@ __all__ = [ ] trunc_normal_ = TruncatedNormal(std=.02) +normal_ = Normal zeros_ = Constant(value=0.) ones_ = Constant(value=1.) @@ -141,7 +144,13 @@ class Block(nn.Layer): norm_layer='nn.LayerNorm', epsilon=1e-5): super().__init__() - self.norm1 = eval(norm_layer)(dim, epsilon=epsilon) + if isinstance(norm_layer, str): + self.norm1 = eval(norm_layer)(dim, epsilon=epsilon) + elif isinstance(norm_layer, Callable): + self.norm1 = norm_layer(dim) + else: + raise TypeError( + "The norm_layer must be str or paddle.nn.layer.Layer class") self.attn = Attention( dim, num_heads=num_heads, @@ -151,7 +160,13 @@ class Block(nn.Layer): proj_drop=drop) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity() - self.norm2 = eval(norm_layer)(dim, epsilon=epsilon) + if isinstance(norm_layer, str): + self.norm2 = eval(norm_layer)(dim, epsilon=epsilon) + elif isinstance(norm_layer, Callable): + self.norm2 = norm_layer(dim) + else: + raise TypeError( + "The norm_layer must be str or paddle.nn.layer.Layer class") mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,