diff --git a/docs/zh_CN/models/LeViT.md b/docs/zh_CN/models/LeViT.md
new file mode 100644
index 0000000000000000000000000000000000000000..d668a27b09f3151f2a7b2f428f92a3e4c8a3d97e
--- /dev/null
+++ b/docs/zh_CN/models/LeViT.md
@@ -0,0 +1,17 @@
+# LeViT
+
+## Overview
+LeViT is a hybrid neural network for fast-inference image classification. Its design accounts for the model's performance on different hardware platforms, so it better reflects real deployment scenarios. Through extensive experiments, the authors found a better way to combine convolutional networks with the Transformer architecture, and proposed an attention-based method for integrating positional information into the Transformer. [Paper link](https://arxiv.org/abs/2104.01136).
+
+## Accuracy, FLOPs and Parameters
+
+| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPs<br>(M) | Params<br>(M) |
+|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
+| LeViT-128S | 0.7621 | 0.9277 | 0.766 | 0.929 | 305 | 7.8 |
+| LeViT-128 | 0.7833 | 0.9378 | 0.786 | 0.940 | 406 | 9.2 |
+| LeViT-192 | 0.7963 | 0.9460 | 0.800 | 0.947 | 658 | 11 |
+| LeViT-256 | 0.8085 | 0.9497 | 0.816 | 0.954 | 1120 | 19 |
+| LeViT-384 | 0.8234 | 0.9587 | 0.826 | 0.960 | 2353 | 39 |
+
+
+**Note**: The accuracy gap with the Reference results stems from differences in data preprocessing.
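+
+## Usage
+
+A minimal inference sketch, assuming this PR is merged so that the LeViT backbones are exported from `ppcls.arch.backbone` (the default input size is 224, and `class_dim` is the head size parameter used in this PR):
+
+```python
+import paddle
+
+from ppcls.arch.backbone import LeViT_128S
+
+# build the smallest LeViT variant with a 1000-class head
+model = LeViT_128S(class_dim=1000)
+model.eval()
+
+x = paddle.randn([1, 3, 224, 224])
+# in eval mode the classification and distillation heads are averaged
+out = model(x)
+print(out.shape)  # [1, 1000]
+```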
diff --git a/docs/zh_CN/models/Twins.md b/docs/zh_CN/models/Twins.md
new file mode 100644
index 0000000000000000000000000000000000000000..424f3985df00216c048e026632c43f9e720f4542
--- /dev/null
+++ b/docs/zh_CN/models/Twins.md
@@ -0,0 +1,17 @@
+# Twins
+
+## Overview
+The Twins family includes Twins-PCPVT and Twins-SVT. It focuses on a carefully designed spatial attention mechanism that yields a simple yet more effective scheme. Because the architecture only involves matrix multiplications, which are highly optimized in current deep learning frameworks, it is both efficient and easy to implement. Moreover, it achieves excellent performance on a variety of downstream vision tasks such as image classification, object detection and semantic segmentation. [Paper link](https://arxiv.org/abs/2104.13840).
+
+## Accuracy, FLOPs and Parameters
+
+| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPs<br>(G) | Params<br>(M) |
+|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
+| pcpvt_small | 0.8082 | 0.9552 | 0.812 | - | 3.7 | 24.1 |
+| pcpvt_base | 0.8242 | 0.9619 | 0.827 | - | 6.4 | 43.8 |
+| pcpvt_large | 0.8273 | 0.9650 | 0.831 | - | 9.5 | 60.9 |
+| alt_gvt_small | 0.8140 | 0.9546 | 0.817 | - | 2.8 | 24 |
+| alt_gvt_base | 0.8294 | 0.9621 | 0.832 | - | 8.3 | 56 |
+| alt_gvt_large | 0.8331 | 0.9642 | 0.837 | - | 14.8 | 99.2 |
+
+**Note**: The accuracy gap with the Reference results stems from differences in data preprocessing.
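+
+## Usage
+
+A minimal inference sketch, assuming this PR is merged so that the Twins backbones are exported from `ppcls.arch.backbone`:
+
+```python
+import paddle
+
+from ppcls.arch.backbone import alt_gvt_small
+
+# Twins-SVT small; use pcpvt_small for the Twins-PCPVT counterpart
+model = alt_gvt_small()
+model.eval()
+
+x = paddle.randn([1, 3, 224, 224])
+out = model(x)
+print(out.shape)  # [1, 1000]
+```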
diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py
index 256a9950b70e00952d7ae4527ed26afe49325e71..4b5b08af22151266d12ed712d40abc3fc503be10 100644
--- a/ppcls/arch/backbone/__init__.py
+++ b/ppcls/arch/backbone/__init__.py
@@ -1,4 +1,4 @@
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -47,6 +47,8 @@ from ppcls.arch.backbone.model_zoo.distillation_models import ResNet50_vd_distil
from ppcls.arch.backbone.model_zoo.swin_transformer import SwinTransformer_tiny_patch4_window7_224, SwinTransformer_small_patch4_window7_224, SwinTransformer_base_patch4_window7_224, SwinTransformer_base_patch4_window12_384, SwinTransformer_large_patch4_window7_224, SwinTransformer_large_patch4_window12_384
from ppcls.arch.backbone.model_zoo.mixnet import MixNet_S, MixNet_M, MixNet_L
from ppcls.arch.backbone.model_zoo.rexnet import ReXNet_1_0, ReXNet_1_3, ReXNet_1_5, ReXNet_2_0, ReXNet_3_0
+from ppcls.arch.backbone.model_zoo.gvt import pcpvt_small, pcpvt_base, pcpvt_large, alt_gvt_small, alt_gvt_base, alt_gvt_large
+from ppcls.arch.backbone.model_zoo.levit import LeViT_128S, LeViT_128, LeViT_192, LeViT_256, LeViT_384
from ppcls.arch.backbone.model_zoo.dla import DLA34, DLA46_c, DLA46x_c, DLA60, DLA60x, DLA60x_c, DLA102, DLA102x, DLA102x2, DLA169
from ppcls.arch.backbone.model_zoo.rednet import RedNet26, RedNet38, RedNet50, RedNet101, RedNet152
from ppcls.arch.backbone.model_zoo.tnt import TNT_small
diff --git a/ppcls/arch/backbone/model_zoo/gvt.py b/ppcls/arch/backbone/model_zoo/gvt.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcfdfead63c86ad4e6029edd51c2b955c540a58c
--- /dev/null
+++ b/ppcls/arch/backbone/model_zoo/gvt.py
@@ -0,0 +1,659 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.regularizer import L2Decay
+
+from .vision_transformer import trunc_normal_, normal_, zeros_, ones_, to_2tuple, DropPath, Identity, Mlp
+from .vision_transformer import Block as ViTBlock
+
+__all__ = [
+ "CPVTV2", "PCPVT", "ALTGVT", "pcpvt_small", "pcpvt_base", "pcpvt_large",
+ "alt_gvt_small", "alt_gvt_base", "alt_gvt_large"
+]
+
+
+class GroupAttention(nn.Layer):
+ """LSA: self attention within a group.
+ """
+
+ def __init__(self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.,
+ proj_drop=0.,
+ ws=1):
+ super().__init__()
+ if ws == 1:
+ raise Exception(f"ws {ws} should not be 1")
+ if dim % num_heads != 0:
+ raise Exception(
+ f"dim {dim} should be divided by num_heads {num_heads}.")
+
+ self.dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.ws = ws
+
+ def forward(self, x, H, W):
+ B, N, C = x.shape
+ h_group, w_group = H // self.ws, W // self.ws
+ total_groups = h_group * w_group
+ x = x.reshape([B, h_group, self.ws, w_group, self.ws, C]).transpose(
+ [0, 1, 3, 2, 4, 5])
+ qkv = self.qkv(x).reshape(
+ [B, total_groups, -1, 3, self.num_heads,
+ C // self.num_heads]).transpose([3, 0, 1, 4, 2, 5])
+ q, k, v = qkv[0], qkv[1], qkv[2]
+ attn = (q @k.transpose([0, 1, 2, 4, 3])) * self.scale
+
+ attn = nn.Softmax(axis=-1)(attn)
+ attn = self.attn_drop(attn)
+ attn = (attn @v).transpose([0, 1, 3, 2, 4]).reshape(
+ [B, h_group, w_group, self.ws, self.ws, C])
+
+ x = attn.transpose([0, 1, 3, 2, 4, 5]).reshape([B, N, C])
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Attention(nn.Layer):
+ """GSA: using a key to summarize the information for a group to be efficient.
+ """
+
+ def __init__(self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.,
+ proj_drop=0.,
+ sr_ratio=1):
+ super().__init__()
+ assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
+
+ self.dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.q = nn.Linear(dim, dim, bias_attr=qkv_bias)
+ self.kv = nn.Linear(dim, dim * 2, bias_attr=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.sr_ratio = sr_ratio
+ if sr_ratio > 1:
+ self.sr = nn.Conv2D(
+ dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
+ self.norm = nn.LayerNorm(dim)
+
+ def forward(self, x, H, W):
+ B, N, C = x.shape
+ q = self.q(x).reshape(
+ [B, N, self.num_heads, C // self.num_heads]).transpose(
+ [0, 2, 1, 3])
+
+ if self.sr_ratio > 1:
+ x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])
+ x_ = self.sr(x_).reshape([B, C, -1]).transpose([0, 2, 1])
+ x_ = self.norm(x_)
+ kv = self.kv(x_).reshape(
+ [B, -1, 2, self.num_heads, C // self.num_heads]).transpose(
+ [2, 0, 3, 1, 4])
+ else:
+ kv = self.kv(x).reshape(
+ [B, -1, 2, self.num_heads, C // self.num_heads]).transpose(
+ [2, 0, 3, 1, 4])
+ k, v = kv[0], kv[1]
+
+ attn = (q @k.transpose([0, 1, 3, 2])) * self.scale
+ attn = nn.Softmax(axis=-1)(attn)
+ attn = self.attn_drop(attn)
+
+ x = (attn @v).transpose([0, 2, 1, 3]).reshape([B, N, C])
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Layer):
+ def __init__(self,
+ dim,
+ num_heads,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ sr_ratio=1):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ sr_ratio=sr_ratio)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop)
+
+ def forward(self, x, H, W):
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class SBlock(ViTBlock):
+ def __init__(self,
+ dim,
+ num_heads,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ sr_ratio=1):
+ super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
+ attn_drop, drop_path, act_layer, norm_layer)
+
+ def forward(self, x, H, W):
+ return super().forward(x)
+
+
+class GroupBlock(ViTBlock):
+ def __init__(self,
+ dim,
+ num_heads,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ sr_ratio=1,
+ ws=1):
+ super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
+ attn_drop, drop_path, act_layer, norm_layer)
+ del self.attn
+ if ws == 1:
+ self.attn = Attention(dim, num_heads, qkv_bias, qk_scale,
+ attn_drop, drop, sr_ratio)
+ else:
+ self.attn = GroupAttention(dim, num_heads, qkv_bias, qk_scale,
+ attn_drop, drop, ws)
+
+ def forward(self, x, H, W):
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Layer):
+ """ Image to Patch Embedding.
+ """
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ if img_size % patch_size != 0:
+ raise Exception(
+ f"img_size {img_size} should be divided by patch_size {patch_size}."
+ )
+
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.H, self.W = img_size[0] // patch_size[0], img_size[
+ 1] // patch_size[1]
+ self.num_patches = self.H * self.W
+ self.proj = nn.Conv2D(
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ self.norm = nn.LayerNorm(embed_dim)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ x = self.proj(x).flatten(2).transpose([0, 2, 1])
+ x = self.norm(x)
+ H, W = H // self.patch_size[0], W // self.patch_size[1]
+ return x, (H, W)
+
+
+# borrow from PVT https://github.com/whai362/PVT.git
+class PyramidVisionTransformer(nn.Layer):
+ def __init__(self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ num_classes=1000,
+ embed_dims=[64, 128, 256, 512],
+ num_heads=[1, 2, 4, 8],
+ mlp_ratios=[4, 4, 4, 4],
+ qkv_bias=False,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ norm_layer=nn.LayerNorm,
+ depths=[3, 4, 6, 3],
+ sr_ratios=[8, 4, 2, 1],
+ block_cls=Block):
+ super().__init__()
+ self.num_classes = num_classes
+ self.depths = depths
+
+ # patch_embed
+ self.patch_embeds = nn.LayerList()
+ self.pos_embeds = nn.ParameterList()
+ self.pos_drops = nn.LayerList()
+ self.blocks = nn.LayerList()
+
+ for i in range(len(depths)):
+ if i == 0:
+ self.patch_embeds.append(
+ PatchEmbed(img_size, patch_size, in_chans, embed_dims[i]))
+ else:
+ self.patch_embeds.append(
+ PatchEmbed(img_size // patch_size // 2**(i - 1), 2,
+ embed_dims[i - 1], embed_dims[i]))
+ patch_num = self.patch_embeds[i].num_patches + 1 if i == len(
+ embed_dims) - 1 else self.patch_embeds[i].num_patches
+ self.pos_embeds.append(
+ self.create_parameter(
+ shape=[1, patch_num, embed_dims[i]],
+ default_initializer=zeros_))
+ self.add_parameter(f"pos_embeds_{i}", self.pos_embeds[i])
+ self.pos_drops.append(nn.Dropout(p=drop_rate))
+
+ dpr = [
+ x.numpy()[0]
+ for x in paddle.linspace(0, drop_path_rate, sum(depths))
+ ] # stochastic depth decay rule
+
+ cur = 0
+ for k in range(len(depths)):
+ _block = nn.LayerList([
+ block_cls(
+ dim=embed_dims[k],
+ num_heads=num_heads[k],
+ mlp_ratio=mlp_ratios[k],
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[cur + i],
+ norm_layer=norm_layer,
+ sr_ratio=sr_ratios[k]) for i in range(depths[k])
+ ])
+ self.blocks.append(_block)
+ cur += depths[k]
+
+ self.norm = norm_layer(embed_dims[-1])
+
+ # cls_token
+ self.cls_token = self.create_parameter(
+ shape=[1, 1, embed_dims[-1]],
+ default_initializer=zeros_,
+ attr=paddle.ParamAttr(regularizer=L2Decay(0.0)))
+ self.add_parameter("cls_token", self.cls_token)
+
+ # classification head
+ self.head = nn.Linear(embed_dims[-1],
+ num_classes) if num_classes > 0 else Identity()
+
+ # init weights
+ for pos_emb in self.pos_embeds:
+ trunc_normal_(pos_emb)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+
+ def forward_features(self, x):
+ B = x.shape[0]
+ for i in range(len(self.depths)):
+ x, (H, W) = self.patch_embeds[i](x)
+ if i == len(self.depths) - 1:
+ cls_tokens = self.cls_token.expand([B, -1, -1])
+                x = paddle.concat([cls_tokens, x], axis=1)
+ x = x + self.pos_embeds[i]
+ x = self.pos_drops[i](x)
+ for blk in self.blocks[i]:
+ x = blk(x, H, W)
+ if i < len(self.depths) - 1:
+                # paddle tensors have no .contiguous(); transpose is enough
+                x = x.reshape([B, H, W, -1]).transpose([0, 3, 1, 2])
+ x = self.norm(x)
+ return x[:, 0]
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+
+# PEG from https://arxiv.org/abs/2102.10882
+class PosCNN(nn.Layer):
+ def __init__(self, in_chans, embed_dim=768, s=1):
+ super().__init__()
+ self.proj = nn.Sequential(
+ nn.Conv2D(
+ in_chans,
+ embed_dim,
+ 3,
+ s,
+ 1,
+ bias_attr=paddle.ParamAttr(regularizer=L2Decay(0.0)),
+ groups=embed_dim,
+ weight_attr=paddle.ParamAttr(regularizer=L2Decay(0.0)), ))
+ self.s = s
+
+ def forward(self, x, H, W):
+ B, N, C = x.shape
+ feat_token = x
+ cnn_feat = feat_token.transpose([0, 2, 1]).reshape([B, C, H, W])
+ if self.s == 1:
+ x = self.proj(cnn_feat) + cnn_feat
+ else:
+ x = self.proj(cnn_feat)
+ x = x.flatten(2).transpose([0, 2, 1])
+ return x
+
+
+class CPVTV2(PyramidVisionTransformer):
+ """
+ Use useful results from CPVT. PEG and GAP.
+ Therefore, cls token is no longer required.
+ PEG is used to encode the absolute position on the fly, which greatly affects the performance when input resolution
+ changes during the training (such as segmentation, detection)
+ """
+
+ def __init__(self,
+ img_size=224,
+ patch_size=4,
+ in_chans=3,
+ num_classes=1000,
+ embed_dims=[64, 128, 256, 512],
+ num_heads=[1, 2, 4, 8],
+ mlp_ratios=[4, 4, 4, 4],
+ qkv_bias=False,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ norm_layer=nn.LayerNorm,
+ depths=[3, 4, 6, 3],
+ sr_ratios=[8, 4, 2, 1],
+ block_cls=Block):
+ super().__init__(img_size, patch_size, in_chans, num_classes,
+ embed_dims, num_heads, mlp_ratios, qkv_bias, qk_scale,
+ drop_rate, attn_drop_rate, drop_path_rate, norm_layer,
+ depths, sr_ratios, block_cls)
+ del self.pos_embeds
+ del self.cls_token
+ self.pos_block = nn.LayerList(
+ [PosCNN(embed_dim, embed_dim) for embed_dim in embed_dims])
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ import math
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+ elif isinstance(m, nn.Conv2D):
+ fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
+ fan_out //= m._groups
+ normal_(0, math.sqrt(2.0 / fan_out))(m.weight)
+ if m.bias is not None:
+ zeros_(m.bias)
+ elif isinstance(m, nn.BatchNorm2D):
+            ones_(m.weight)
+            zeros_(m.bias)
+
+ def forward_features(self, x):
+ B = x.shape[0]
+
+ for i in range(len(self.depths)):
+ x, (H, W) = self.patch_embeds[i](x)
+ x = self.pos_drops[i](x)
+
+ for j, blk in enumerate(self.blocks[i]):
+ x = blk(x, H, W)
+ if j == 0:
+ x = self.pos_block[i](x, H, W) # PEG here
+
+ if i < len(self.depths) - 1:
+ x = x.reshape([B, H, W, -1]).transpose([0, 3, 1, 2])
+
+ x = self.norm(x)
+ return x.mean(axis=1) # GAP here
+
+
+class PCPVT(CPVTV2):
+ def __init__(self,
+ img_size=224,
+ patch_size=4,
+ in_chans=3,
+ num_classes=1000,
+ embed_dims=[64, 128, 256],
+ num_heads=[1, 2, 4],
+ mlp_ratios=[4, 4, 4],
+ qkv_bias=False,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ norm_layer=nn.LayerNorm,
+ depths=[4, 4, 4],
+ sr_ratios=[4, 2, 1],
+ block_cls=SBlock):
+ super().__init__(img_size, patch_size, in_chans, num_classes,
+ embed_dims, num_heads, mlp_ratios, qkv_bias, qk_scale,
+ drop_rate, attn_drop_rate, drop_path_rate, norm_layer,
+ depths, sr_ratios, block_cls)
+
+
+class ALTGVT(PCPVT):
+ """
+ alias Twins-SVT
+ """
+
+ def __init__(self,
+ img_size=224,
+ patch_size=4,
+ in_chans=3,
+ class_dim=1000,
+ embed_dims=[64, 128, 256],
+ num_heads=[1, 2, 4],
+ mlp_ratios=[4, 4, 4],
+ qkv_bias=False,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ norm_layer=nn.LayerNorm,
+ depths=[4, 4, 4],
+ sr_ratios=[4, 2, 1],
+ block_cls=GroupBlock,
+ wss=[7, 7, 7]):
+ super().__init__(img_size, patch_size, in_chans, class_dim, embed_dims,
+ num_heads, mlp_ratios, qkv_bias, qk_scale, drop_rate,
+ attn_drop_rate, drop_path_rate, norm_layer, depths,
+ sr_ratios, block_cls)
+ del self.blocks
+ self.wss = wss
+ # transformer encoder
+ dpr = [
+ x.numpy()[0]
+ for x in paddle.linspace(0, drop_path_rate, sum(depths))
+ ] # stochastic depth decay rule
+ cur = 0
+ self.blocks = nn.LayerList()
+ for k in range(len(depths)):
+ _block = nn.LayerList([
+ block_cls(
+ dim=embed_dims[k],
+ num_heads=num_heads[k],
+ mlp_ratio=mlp_ratios[k],
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[cur + i],
+ norm_layer=norm_layer,
+ sr_ratio=sr_ratios[k],
+ ws=1 if i % 2 == 1 else wss[k]) for i in range(depths[k])
+ ])
+ self.blocks.append(_block)
+ cur += depths[k]
+ self.apply(self._init_weights)
+
+
+def pcpvt_small(pretrained=False, **kwargs):
+ model = CPVTV2(
+ patch_size=4,
+ embed_dims=[64, 128, 320, 512],
+ num_heads=[1, 2, 5, 8],
+ mlp_ratios=[8, 8, 4, 4],
+ qkv_bias=True,
+ norm_layer=partial(
+ nn.LayerNorm, epsilon=1e-6),
+ depths=[3, 4, 6, 3],
+ sr_ratios=[8, 4, 2, 1],
+ **kwargs)
+
+ return model
+
+
+def pcpvt_base(pretrained=False, **kwargs):
+ model = CPVTV2(
+ patch_size=4,
+ embed_dims=[64, 128, 320, 512],
+ num_heads=[1, 2, 5, 8],
+ mlp_ratios=[8, 8, 4, 4],
+ qkv_bias=True,
+ norm_layer=partial(
+ nn.LayerNorm, epsilon=1e-6),
+ depths=[3, 4, 18, 3],
+ sr_ratios=[8, 4, 2, 1],
+ **kwargs)
+
+ return model
+
+
+def pcpvt_large(pretrained=False, **kwargs):
+ model = CPVTV2(
+ patch_size=4,
+ embed_dims=[64, 128, 320, 512],
+ num_heads=[1, 2, 5, 8],
+ mlp_ratios=[8, 8, 4, 4],
+ qkv_bias=True,
+ norm_layer=partial(
+ nn.LayerNorm, epsilon=1e-6),
+ depths=[3, 8, 27, 3],
+ sr_ratios=[8, 4, 2, 1],
+ **kwargs)
+
+ return model
+
+
+def alt_gvt_small(pretrained=False, **kwargs):
+ model = ALTGVT(
+ patch_size=4,
+ embed_dims=[64, 128, 256, 512],
+ num_heads=[2, 4, 8, 16],
+ mlp_ratios=[4, 4, 4, 4],
+ qkv_bias=True,
+ norm_layer=partial(
+ nn.LayerNorm, epsilon=1e-6),
+ depths=[2, 2, 10, 4],
+ wss=[7, 7, 7, 7],
+ sr_ratios=[8, 4, 2, 1],
+ **kwargs)
+
+ return model
+
+
+def alt_gvt_base(pretrained=False, **kwargs):
+ model = ALTGVT(
+ patch_size=4,
+ embed_dims=[96, 192, 384, 768],
+ num_heads=[3, 6, 12, 24],
+ mlp_ratios=[4, 4, 4, 4],
+ qkv_bias=True,
+ norm_layer=partial(
+ nn.LayerNorm, epsilon=1e-6),
+ depths=[2, 2, 18, 2],
+ wss=[7, 7, 7, 7],
+ sr_ratios=[8, 4, 2, 1],
+        **kwargs)
+
+ return model
+
+
+def alt_gvt_large(pretrained=False, **kwargs):
+ model = ALTGVT(
+ patch_size=4,
+ embed_dims=[128, 256, 512, 1024],
+ num_heads=[4, 8, 16, 32],
+ mlp_ratios=[4, 4, 4, 4],
+ qkv_bias=True,
+ norm_layer=partial(
+ nn.LayerNorm, epsilon=1e-6),
+ depths=[2, 2, 18, 2],
+ wss=[7, 7, 7, 7],
+ sr_ratios=[8, 4, 2, 1],
+ **kwargs)
+
+ return model
diff --git a/ppcls/arch/backbone/model_zoo/levit.py b/ppcls/arch/backbone/model_zoo/levit.py
new file mode 100644
index 0000000000000000000000000000000000000000..459afcd6f13c43b87156a81fe0913efa437b565c
--- /dev/null
+++ b/ppcls/arch/backbone/model_zoo/levit.py
@@ -0,0 +1,515 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import math
+import warnings
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn.initializer import TruncatedNormal, Constant
+from paddle.regularizer import L2Decay
+
+from .vision_transformer import trunc_normal_, zeros_, ones_, Identity
+
+__all__ = ["LeViT_128S", "LeViT_128", "LeViT_192", "LeViT_256", "LeViT_384"]
+
+
+def cal_attention_biases(attention_biases, attention_bias_idxs):
+ gather_list = []
+ attention_bias_t = paddle.transpose(attention_biases, (1, 0))
+ for idx in attention_bias_idxs:
+ gather = paddle.gather(attention_bias_t, idx)
+ gather_list.append(gather)
+ shape0, shape1 = attention_bias_idxs.shape
+ return paddle.transpose(paddle.concat(gather_list), (1, 0)).reshape(
+ (0, shape0, shape1))
+
+
+class Conv2d_BN(nn.Sequential):
+ def __init__(self,
+ a,
+ b,
+ ks=1,
+ stride=1,
+ pad=0,
+ dilation=1,
+ groups=1,
+ bn_weight_init=1,
+ resolution=-10000):
+ super().__init__()
+ self.add_sublayer(
+ 'c',
+ nn.Conv2D(
+ a, b, ks, stride, pad, dilation, groups, bias_attr=False))
+        bn = nn.BatchNorm2D(b)
+        # honor bn_weight_init instead of hard-coding 1
+        Constant(value=float(bn_weight_init))(bn.weight)
+        zeros_(bn.bias)
+ self.add_sublayer('bn', bn)
+
+
+class Linear_BN(nn.Sequential):
+ def __init__(self, a, b, bn_weight_init=1):
+ super().__init__()
+ self.add_sublayer('c', nn.Linear(a, b, bias_attr=False))
+        bn = nn.BatchNorm1D(b)
+        # honor bn_weight_init (0 lets residual branches start as identity)
+        Constant(value=float(bn_weight_init))(bn.weight)
+        zeros_(bn.bias)
+ self.add_sublayer('bn', bn)
+
+ def forward(self, x):
+ l, bn = self._sub_layers.values()
+ x = l(x)
+ return paddle.reshape(bn(x.flatten(0, 1)), x.shape)
+
+
+class BN_Linear(nn.Sequential):
+ def __init__(self, a, b, bias=True, std=0.02):
+ super().__init__()
+ self.add_sublayer('bn', nn.BatchNorm1D(a))
+ l = nn.Linear(a, b, bias_attr=bias)
+ trunc_normal_(l.weight)
+ if bias:
+ zeros_(l.bias)
+ self.add_sublayer('l', l)
+
+
+def b16(n, activation, resolution=224):
+ return nn.Sequential(
+ Conv2d_BN(
+ 3, n // 8, 3, 2, 1, resolution=resolution),
+ activation(),
+ Conv2d_BN(
+ n // 8, n // 4, 3, 2, 1, resolution=resolution // 2),
+ activation(),
+ Conv2d_BN(
+ n // 4, n // 2, 3, 2, 1, resolution=resolution // 4),
+ activation(),
+ Conv2d_BN(
+ n // 2, n, 3, 2, 1, resolution=resolution // 8))
+
+
+class Residual(nn.Layer):
+ def __init__(self, m, drop):
+ super().__init__()
+ self.m = m
+ self.drop = drop
+
+ def forward(self, x):
+        if self.training and self.drop > 0:
+            # drop-path: randomly drop the residual branch per sample and
+            # rescale the kept samples (written with Paddle ops, not torch)
+            keep_mask = (paddle.rand([x.shape[0], 1, 1]) >=
+                         self.drop).astype(x.dtype) / (1 - self.drop)
+            return x + self.m(x) * keep_mask.detach()
+ else:
+ return x + self.m(x)
+
+
+class Attention(nn.Layer):
+ def __init__(self,
+ dim,
+ key_dim,
+ num_heads=8,
+ attn_ratio=4,
+ activation=None,
+ resolution=14):
+ super().__init__()
+ self.num_heads = num_heads
+ self.scale = key_dim**-0.5
+ self.key_dim = key_dim
+ self.nh_kd = nh_kd = key_dim * num_heads
+ self.d = int(attn_ratio * key_dim)
+ self.dh = int(attn_ratio * key_dim) * num_heads
+ self.attn_ratio = attn_ratio
+ self.h = self.dh + nh_kd * 2
+ self.qkv = Linear_BN(dim, self.h)
+ self.proj = nn.Sequential(
+ activation(), Linear_BN(
+ self.dh, dim, bn_weight_init=0))
+ points = list(itertools.product(range(resolution), range(resolution)))
+ N = len(points)
+ attention_offsets = {}
+ idxs = []
+ for p1 in points:
+ for p2 in points:
+ offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
+ if offset not in attention_offsets:
+ attention_offsets[offset] = len(attention_offsets)
+ idxs.append(attention_offsets[offset])
+ self.attention_biases = self.create_parameter(
+ shape=(num_heads, len(attention_offsets)),
+ default_initializer=zeros_,
+ attr=paddle.ParamAttr(regularizer=L2Decay(0.0)))
+ tensor_idxs = paddle.to_tensor(idxs, dtype='int64')
+ self.register_buffer('attention_bias_idxs',
+ paddle.reshape(tensor_idxs, [N, N]))
+
+ @paddle.no_grad()
+ def train(self, mode=True):
+ if mode:
+ super().train()
+ else:
+ super().eval()
+ if mode and hasattr(self, 'ab'):
+ del self.ab
+ else:
+ self.ab = cal_attention_biases(self.attention_biases,
+ self.attention_bias_idxs)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x)
+ qkv = paddle.reshape(qkv,
+ [B, N, self.num_heads, self.h // self.num_heads])
+ q, k, v = paddle.split(
+ qkv, [self.key_dim, self.key_dim, self.d], axis=3)
+ q = paddle.transpose(q, perm=[0, 2, 1, 3])
+ k = paddle.transpose(k, perm=[0, 2, 1, 3])
+ v = paddle.transpose(v, perm=[0, 2, 1, 3])
+ k_transpose = paddle.transpose(k, perm=[0, 1, 3, 2])
+
+        # use the bias table cached by train(False) when available,
+        # otherwise recompute it from the learned offsets
+        if self.training or not hasattr(self, 'ab'):
+            attention_biases = cal_attention_biases(self.attention_biases,
+                                                    self.attention_bias_idxs)
+        else:
+            attention_biases = self.ab
+ attn = ((q @k_transpose) * self.scale + attention_biases)
+ attn = F.softmax(attn)
+ x = paddle.transpose(attn @v, perm=[0, 2, 1, 3])
+ x = paddle.reshape(x, [B, N, self.dh])
+ x = self.proj(x)
+ return x
+
+
+class Subsample(nn.Layer):
+ def __init__(self, stride, resolution):
+ super().__init__()
+ self.stride = stride
+ self.resolution = resolution
+
+ def forward(self, x):
+ B, N, C = x.shape
+ x = paddle.reshape(x, [B, self.resolution, self.resolution,
+ C])[:, ::self.stride, ::self.stride]
+ x = paddle.reshape(x, [B, -1, C])
+ return x
+
+
+class AttentionSubsample(nn.Layer):
+ def __init__(self,
+ in_dim,
+ out_dim,
+ key_dim,
+ num_heads=8,
+ attn_ratio=2,
+ activation=None,
+ stride=2,
+ resolution=14,
+ resolution_=7):
+ super().__init__()
+ self.num_heads = num_heads
+ self.scale = key_dim**-0.5
+ self.key_dim = key_dim
+ self.nh_kd = nh_kd = key_dim * num_heads
+ self.d = int(attn_ratio * key_dim)
+ self.dh = int(attn_ratio * key_dim) * self.num_heads
+ self.attn_ratio = attn_ratio
+ self.resolution_ = resolution_
+ self.resolution_2 = resolution_**2
+ self.training = True
+ h = self.dh + nh_kd
+ self.kv = Linear_BN(in_dim, h)
+
+ self.q = nn.Sequential(
+ Subsample(stride, resolution), Linear_BN(in_dim, nh_kd))
+ self.proj = nn.Sequential(activation(), Linear_BN(self.dh, out_dim))
+
+ self.stride = stride
+ self.resolution = resolution
+ points = list(itertools.product(range(resolution), range(resolution)))
+ points_ = list(
+ itertools.product(range(resolution_), range(resolution_)))
+
+ N = len(points)
+ N_ = len(points_)
+ attention_offsets = {}
+ idxs = []
+ i = 0
+ j = 0
+ for p1 in points_:
+ i += 1
+ for p2 in points:
+ j += 1
+ size = 1
+ offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2),
+ abs(p1[1] * stride - p2[1] + (size - 1) / 2))
+ if offset not in attention_offsets:
+ attention_offsets[offset] = len(attention_offsets)
+ idxs.append(attention_offsets[offset])
+ self.attention_biases = self.create_parameter(
+ shape=(num_heads, len(attention_offsets)),
+ default_initializer=zeros_,
+ attr=paddle.ParamAttr(regularizer=L2Decay(0.0)))
+
+ tensor_idxs_ = paddle.to_tensor(idxs, dtype='int64')
+ self.register_buffer('attention_bias_idxs',
+ paddle.reshape(tensor_idxs_, [N_, N]))
+
+ @paddle.no_grad()
+ def train(self, mode=True):
+ if mode:
+ super().train()
+ else:
+ super().eval()
+ if mode and hasattr(self, 'ab'):
+ del self.ab
+ else:
+ self.ab = cal_attention_biases(self.attention_biases,
+ self.attention_bias_idxs)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ kv = self.kv(x)
+ kv = paddle.reshape(kv, [B, N, self.num_heads, -1])
+ k, v = paddle.split(kv, [self.key_dim, self.d], axis=3)
+ k = paddle.transpose(k, perm=[0, 2, 1, 3]) # BHNC
+ v = paddle.transpose(v, perm=[0, 2, 1, 3])
+ q = paddle.reshape(
+ self.q(x), [B, self.resolution_2, self.num_heads, self.key_dim])
+ q = paddle.transpose(q, perm=[0, 2, 1, 3])
+
+        # use the bias table cached by train(False) when available,
+        # otherwise recompute it from the learned offsets
+        if self.training or not hasattr(self, 'ab'):
+            attention_biases = cal_attention_biases(self.attention_biases,
+                                                    self.attention_bias_idxs)
+        else:
+            attention_biases = self.ab
+
+ attn = (q @paddle.transpose(
+ k, perm=[0, 1, 3, 2])) * self.scale + attention_biases
+ attn = F.softmax(attn)
+
+ x = paddle.reshape(
+ paddle.transpose(
+ (attn @v), perm=[0, 2, 1, 3]), [B, -1, self.dh])
+ x = self.proj(x)
+ return x
+
+
+class LeViT(nn.Layer):
+ """ Vision Transformer with support for patch or hybrid CNN input stage
+ """
+
+ def __init__(self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ class_dim=1000,
+ embed_dim=[192],
+ key_dim=[64],
+ depth=[12],
+ num_heads=[3],
+ attn_ratio=[2],
+ mlp_ratio=[2],
+ hybrid_backbone=None,
+ down_ops=[],
+ attention_activation=nn.Hardswish,
+ mlp_activation=nn.Hardswish,
+ distillation=True,
+ drop_path=0):
+ super().__init__()
+
+ self.class_dim = class_dim
+ self.num_features = embed_dim[-1]
+ self.embed_dim = embed_dim
+ self.distillation = distillation
+
+ self.patch_embed = hybrid_backbone
+
+ self.blocks = []
+ down_ops.append([''])
+ resolution = img_size // patch_size
+ for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
+ zip(embed_dim, key_dim, depth, num_heads, attn_ratio,
+ mlp_ratio, down_ops)):
+ for _ in range(dpth):
+ self.blocks.append(
+ Residual(
+ Attention(
+ ed,
+ kd,
+ nh,
+ attn_ratio=ar,
+ activation=attention_activation,
+ resolution=resolution, ),
+ drop_path))
+ if mr > 0:
+ h = int(ed * mr)
+ self.blocks.append(
+ Residual(
+ nn.Sequential(
+ Linear_BN(ed, h),
+ mlp_activation(),
+ Linear_BN(
+ h, ed, bn_weight_init=0), ),
+ drop_path))
+ if do[0] == 'Subsample':
+ #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
+ resolution_ = (resolution - 1) // do[5] + 1
+ self.blocks.append(
+ AttentionSubsample(
+ *embed_dim[i:i + 2],
+ key_dim=do[1],
+ num_heads=do[2],
+ attn_ratio=do[3],
+ activation=attention_activation,
+ stride=do[5],
+ resolution=resolution,
+ resolution_=resolution_))
+ resolution = resolution_
+ if do[4] > 0: # mlp_ratio
+ h = int(embed_dim[i + 1] * do[4])
+ self.blocks.append(
+ Residual(
+ nn.Sequential(
+ Linear_BN(embed_dim[i + 1], h),
+ mlp_activation(),
+ Linear_BN(
+ h, embed_dim[i + 1], bn_weight_init=0), ),
+ drop_path))
+ self.blocks = nn.Sequential(*self.blocks)
+
+ # Classifier head
+ self.head = BN_Linear(embed_dim[-1],
+ class_dim) if class_dim > 0 else Identity()
+ if distillation:
+ self.head_dist = BN_Linear(
+ embed_dim[-1], class_dim) if class_dim > 0 else Identity()
+
+ def forward(self, x):
+ x = self.patch_embed(x)
+ x = x.flatten(2)
+ x = paddle.transpose(x, perm=[0, 2, 1])
+ x = self.blocks(x)
+ x = x.mean(1)
+ if self.distillation:
+ x = self.head(x), self.head_dist(x)
+ if not self.training:
+ x = (x[0] + x[1]) / 2
+ else:
+ x = self.head(x)
+ return x
+
+
+def model_factory(C, D, X, N, drop_path, class_dim, distillation):
+ embed_dim = [int(x) for x in C.split('_')]
+ num_heads = [int(x) for x in N.split('_')]
+ depth = [int(x) for x in X.split('_')]
+ act = nn.Hardswish
+ model = LeViT(
+ patch_size=16,
+ embed_dim=embed_dim,
+ num_heads=num_heads,
+ key_dim=[D] * 3,
+ depth=depth,
+ attn_ratio=[2, 2, 2],
+ mlp_ratio=[2, 2, 2],
+ down_ops=[
+ #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
+ ['Subsample', D, embed_dim[0] // D, 4, 2, 2],
+ ['Subsample', D, embed_dim[1] // D, 4, 2, 2],
+ ],
+ attention_activation=act,
+ mlp_activation=act,
+ hybrid_backbone=b16(embed_dim[0], activation=act),
+ class_dim=class_dim,
+ drop_path=drop_path,
+ distillation=distillation)
+
+ return model
+
+
+specification = {
+ 'LeViT_128S': {
+ 'C': '128_256_384',
+ 'D': 16,
+ 'N': '4_6_8',
+ 'X': '2_3_4',
+ 'drop_path': 0
+ },
+ 'LeViT_128': {
+ 'C': '128_256_384',
+ 'D': 16,
+ 'N': '4_8_12',
+ 'X': '4_4_4',
+ 'drop_path': 0
+ },
+ 'LeViT_192': {
+ 'C': '192_288_384',
+ 'D': 32,
+ 'N': '3_5_6',
+ 'X': '4_4_4',
+ 'drop_path': 0
+ },
+ 'LeViT_256': {
+ 'C': '256_384_512',
+ 'D': 32,
+ 'N': '4_6_8',
+ 'X': '4_4_4',
+ 'drop_path': 0
+ },
+ 'LeViT_384': {
+ 'C': '384_512_768',
+ 'D': 32,
+ 'N': '6_9_12',
+ 'X': '4_4_4',
+ 'drop_path': 0.1
+ },
+}
+
+
+def LeViT_128S(class_dim=1000, distillation=True, pretrained=False):
+ return model_factory(
+ **specification['LeViT_128S'],
+ class_dim=class_dim,
+ distillation=distillation)
+
+
+def LeViT_128(class_dim=1000, distillation=True):
+ return model_factory(
+ **specification['LeViT_128'],
+ class_dim=class_dim,
+ distillation=distillation)
+
+
+def LeViT_192(class_dim=1000, distillation=True):
+ return model_factory(
+ **specification['LeViT_192'],
+ class_dim=class_dim,
+ distillation=distillation)
+
+
+def LeViT_256(class_dim=1000, distillation=False):
+ return model_factory(
+ **specification['LeViT_256'],
+ class_dim=class_dim,
+ distillation=distillation)
+
+
+def LeViT_384(class_dim=1000, distillation=True):
+ return model_factory(
+ **specification['LeViT_384'],
+ class_dim=class_dim,
+ distillation=distillation)
diff --git a/ppcls/arch/backbone/model_zoo/vision_transformer.py b/ppcls/arch/backbone/model_zoo/vision_transformer.py
index 32f198913d59014326957cc6fe7f9b325a59ef28..22cc0ad2790a15aeef4a34c107b8eb27308aebb4 100644
--- a/ppcls/arch/backbone/model_zoo/vision_transformer.py
+++ b/ppcls/arch/backbone/model_zoo/vision_transformer.py
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from collections.abc import Callable
+
import numpy as np
import paddle
import paddle.nn as nn
-from paddle.nn.initializer import TruncatedNormal, Constant
+from paddle.nn.initializer import TruncatedNormal, Constant, Normal
__all__ = [
"VisionTransformer", "ViT_small_patch16_224", "ViT_base_patch16_224",
@@ -25,6 +27,7 @@ __all__ = [
]
trunc_normal_ = TruncatedNormal(std=.02)
+normal_ = Normal
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)
@@ -141,7 +144,13 @@ class Block(nn.Layer):
norm_layer='nn.LayerNorm',
epsilon=1e-5):
super().__init__()
- self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
+ if isinstance(norm_layer, str):
+ self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
+ elif isinstance(norm_layer, Callable):
+ self.norm1 = norm_layer(dim)
+ else:
+ raise TypeError(
+ "The norm_layer must be str or paddle.nn.layer.Layer class")
self.attn = Attention(
dim,
num_heads=num_heads,
@@ -151,7 +160,13 @@ class Block(nn.Layer):
proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
- self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
+ if isinstance(norm_layer, str):
+ self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
+ elif isinstance(norm_layer, Callable):
+ self.norm2 = norm_layer(dim)
+ else:
+ raise TypeError(
+ "The norm_layer must be str or paddle.nn.layer.Layer class")
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim,
hidden_features=mlp_hidden_dim,