Unverified commit ef08cc04 authored by cuicheng01, committed by GitHub

Merge pull request #825 from TingquanGao/reg/add_twins_levit

Add LeViT and Twins
# LeViT
## Overview
LeViT is a hybrid neural network for fast-inference image classification. It was designed with the model's performance on different hardware platforms in mind, so it better reflects real-world deployment scenarios. Through extensive experiments, the authors found a better way to combine convolutional networks with the Transformer architecture, and proposed an attention-based method for integrating positional information into the Transformer. [Paper](https://arxiv.org/abs/2104.01136)
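Below is a minimal inference sketch. It assumes a PaddleClas source checkout on the `PYTHONPATH` (the import path matches the one added in this PR) and uses randomly initialized weights, so the output is illustrative only.

```python
import paddle
from ppcls.arch.backbone.model_zoo.levit import LeViT_128S

# Build the smallest variant; weights here are random, not the pretrained ones.
model = LeViT_128S(class_dim=1000)
model.eval()

x = paddle.randn([1, 3, 224, 224])  # NCHW input batch
with paddle.no_grad():
    logits = model(x)
print(logits.shape)  # [1, 1000]
```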
## Accuracy, FLOPs and Parameters
| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPs<br>(M) | Params<br>(M) |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| LeViT-128S | 0.7621 | 0.9277 | 0.766 | 0.929 | 305 | 7.8 |
| LeViT-128 | 0.7833 | 0.9378 | 0.786 | 0.940 | 406 | 9.2 |
| LeViT-192 | 0.7963 | 0.9460 | 0.800 | 0.947 | 658 | 11 |
| LeViT-256 | 0.8085 | 0.9497 | 0.816 | 0.954 | 1120 | 19 |
| LeViT-384 | 0.8234 | 0.9587 | 0.826 | 0.960 | 2353 | 39 |
**Note**: The accuracy differences from the Reference results come from differences in data preprocessing.
# Twins
## Overview
The Twins family includes Twins-PCPVT and Twins-SVT. Its core contribution is a carefully redesigned spatial attention mechanism, yielding a scheme that is simple yet more effective. Because the architecture involves only matrix multiplications, for which current deep learning frameworks are highly optimized, it is both efficient and easy to implement. Moreover, it achieves excellent performance across a range of downstream vision tasks, including image classification, object detection, and semantic segmentation. [Paper](https://arxiv.org/abs/2104.13840)
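As with LeViT above, a minimal sketch of building and running a Twins backbone, under the same assumptions (PaddleClas checkout on the `PYTHONPATH`, random weights):

```python
import paddle
from ppcls.arch.backbone.model_zoo.gvt import pcpvt_small

model = pcpvt_small()  # Twins-PCPVT small, 1000 classes by default
model.eval()

x = paddle.randn([1, 3, 224, 224])
with paddle.no_grad():
    logits = model(x)
print(logits.shape)  # [1, 1000]
```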
## Accuracy, FLOPs and Parameters
| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPs<br>(G) | Params<br>(M) |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| pcpvt_small | 0.8082 | 0.9552 | 0.812 | - | 3.7 | 24.1 |
| pcpvt_base | 0.8242 | 0.9619 | 0.827 | - | 6.4 | 43.8 |
| pcpvt_large | 0.8273 | 0.9650 | 0.831 | - | 9.5 | 60.9 |
| alt_gvt_small | 0.8140 | 0.9546 | 0.817 | - | 2.8 | 24 |
| alt_gvt_base | 0.8294 | 0.9621 | 0.832 | - | 8.3 | 56 |
| alt_gvt_large | 0.8331 | 0.9642 | 0.837 | - | 14.8 | 99.2 |
**Note**: The accuracy differences from the Reference results come from differences in data preprocessing.
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -47,6 +47,8 @@ from ppcls.arch.backbone.model_zoo.distillation_models import ResNet50_vd_distil
from ppcls.arch.backbone.model_zoo.swin_transformer import SwinTransformer_tiny_patch4_window7_224, SwinTransformer_small_patch4_window7_224, SwinTransformer_base_patch4_window7_224, SwinTransformer_base_patch4_window12_384, SwinTransformer_large_patch4_window7_224, SwinTransformer_large_patch4_window12_384
from ppcls.arch.backbone.model_zoo.mixnet import MixNet_S, MixNet_M, MixNet_L
from ppcls.arch.backbone.model_zoo.rexnet import ReXNet_1_0, ReXNet_1_3, ReXNet_1_5, ReXNet_2_0, ReXNet_3_0
from ppcls.arch.backbone.model_zoo.gvt import pcpvt_small, pcpvt_base, pcpvt_large, alt_gvt_small, alt_gvt_base, alt_gvt_large
from ppcls.arch.backbone.model_zoo.levit import LeViT_128S, LeViT_128, LeViT_192, LeViT_256, LeViT_384
from ppcls.arch.backbone.model_zoo.dla import DLA34, DLA46_c, DLA46x_c, DLA60, DLA60x, DLA60x_c, DLA102, DLA102x, DLA102x2, DLA169
from ppcls.arch.backbone.model_zoo.rednet import RedNet26, RedNet38, RedNet50, RedNet101, RedNet152
from ppcls.arch.backbone.model_zoo.tnt import TNT_small
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.regularizer import L2Decay
from .vision_transformer import trunc_normal_, normal_, zeros_, ones_, to_2tuple, DropPath, Identity, Mlp
from .vision_transformer import Block as ViTBlock
__all__ = [
"CPVTV2", "PCPVT", "ALTGVT", "pcpvt_small", "pcpvt_base", "pcpvt_large",
"alt_gvt_small", "alt_gvt_base", "alt_gvt_large"
]
class GroupAttention(nn.Layer):
"""LSA: self attention within a group.
"""
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
ws=1):
super().__init__()
        if ws == 1:
            raise ValueError(f"ws {ws} should not be 1")
        if dim % num_heads != 0:
            raise ValueError(
                f"dim {dim} should be divisible by num_heads {num_heads}.")
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.ws = ws
def forward(self, x, H, W):
B, N, C = x.shape
h_group, w_group = H // self.ws, W // self.ws
total_groups = h_group * w_group
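        # partition the H x W token map into non-overlapping ws x ws windows and
        # attend within each window independently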
x = x.reshape([B, h_group, self.ws, w_group, self.ws, C]).transpose(
[0, 1, 3, 2, 4, 5])
qkv = self.qkv(x).reshape(
[B, total_groups, -1, 3, self.num_heads,
C // self.num_heads]).transpose([3, 0, 1, 4, 2, 5])
q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose([0, 1, 2, 4, 3])) * self.scale
        attn = F.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)
        attn = (attn @ v).transpose([0, 1, 3, 2, 4]).reshape(
            [B, h_group, w_group, self.ws, self.ws, C])
x = attn.transpose([0, 1, 3, 2, 4, 5]).reshape([B, N, C])
x = self.proj(x)
x = self.proj_drop(x)
return x
class Attention(nn.Layer):
    """GSA: global sub-sampled attention. Keys and values are computed on a
    spatially reduced map, so each key summarizes a group of tokens.
    """
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
sr_ratio=1):
super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divisible by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.q = nn.Linear(dim, dim, bias_attr=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias_attr=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.sr = nn.Conv2D(
dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = nn.LayerNorm(dim)
def forward(self, x, H, W):
B, N, C = x.shape
q = self.q(x).reshape(
[B, N, self.num_heads, C // self.num_heads]).transpose(
[0, 2, 1, 3])
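        # when sr_ratio > 1, a strided conv downsamples the map so each key/value
        # summarizes an sr_ratio x sr_ratio region before attention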
if self.sr_ratio > 1:
x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])
x_ = self.sr(x_).reshape([B, C, -1]).transpose([0, 2, 1])
x_ = self.norm(x_)
kv = self.kv(x_).reshape(
[B, -1, 2, self.num_heads, C // self.num_heads]).transpose(
[2, 0, 3, 1, 4])
else:
kv = self.kv(x).reshape(
[B, -1, 2, self.num_heads, C // self.num_heads]).transpose(
[2, 0, 3, 1, 4])
k, v = kv[0], kv[1]
        attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale
        attn = F.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose([0, 2, 1, 3]).reshape([B, N, C])
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Layer):
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
sr_ratio=1):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
sr_ratio=sr_ratio)
self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
def forward(self, x, H, W):
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class SBlock(ViTBlock):
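    """Plain ViT block exposed through the (x, H, W) interface used by the
    pyramid stages; H and W are accepted but unused.
    """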
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
sr_ratio=1):
super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
attn_drop, drop_path, act_layer, norm_layer)
def forward(self, x, H, W):
return super().forward(x)
class GroupBlock(ViTBlock):
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
sr_ratio=1,
ws=1):
super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
attn_drop, drop_path, act_layer, norm_layer)
del self.attn
if ws == 1:
self.attn = Attention(dim, num_heads, qkv_bias, qk_scale,
attn_drop, drop, sr_ratio)
else:
self.attn = GroupAttention(dim, num_heads, qkv_bias, qk_scale,
attn_drop, drop, ws)
def forward(self, x, H, W):
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Layer):
""" Image to Patch Embedding.
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
        if img_size % patch_size != 0:
            raise ValueError(
                f"img_size {img_size} should be divisible by patch_size {patch_size}."
            )
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.H, self.W = img_size[0] // patch_size[0], img_size[
1] // patch_size[1]
self.num_patches = self.H * self.W
self.proj = nn.Conv2D(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x):
B, C, H, W = x.shape
x = self.proj(x).flatten(2).transpose([0, 2, 1])
x = self.norm(x)
H, W = H // self.patch_size[0], W // self.patch_size[1]
return x, (H, W)
# borrow from PVT https://github.com/whai362/PVT.git
class PyramidVisionTransformer(nn.Layer):
def __init__(self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dims=[64, 128, 256, 512],
num_heads=[1, 2, 4, 8],
mlp_ratios=[4, 4, 4, 4],
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
depths=[3, 4, 6, 3],
sr_ratios=[8, 4, 2, 1],
block_cls=Block):
super().__init__()
self.num_classes = num_classes
self.depths = depths
# patch_embed
self.patch_embeds = nn.LayerList()
self.pos_embeds = nn.ParameterList()
self.pos_drops = nn.LayerList()
self.blocks = nn.LayerList()
for i in range(len(depths)):
if i == 0:
self.patch_embeds.append(
PatchEmbed(img_size, patch_size, in_chans, embed_dims[i]))
else:
self.patch_embeds.append(
PatchEmbed(img_size // patch_size // 2**(i - 1), 2,
embed_dims[i - 1], embed_dims[i]))
patch_num = self.patch_embeds[i].num_patches + 1 if i == len(
embed_dims) - 1 else self.patch_embeds[i].num_patches
self.pos_embeds.append(
self.create_parameter(
shape=[1, patch_num, embed_dims[i]],
default_initializer=zeros_))
self.add_parameter(f"pos_embeds_{i}", self.pos_embeds[i])
self.pos_drops.append(nn.Dropout(p=drop_rate))
dpr = [
x.numpy()[0]
for x in paddle.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
cur = 0
for k in range(len(depths)):
_block = nn.LayerList([
block_cls(
dim=embed_dims[k],
num_heads=num_heads[k],
mlp_ratio=mlp_ratios[k],
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[cur + i],
norm_layer=norm_layer,
sr_ratio=sr_ratios[k]) for i in range(depths[k])
])
self.blocks.append(_block)
cur += depths[k]
self.norm = norm_layer(embed_dims[-1])
# cls_token
self.cls_token = self.create_parameter(
shape=[1, 1, embed_dims[-1]],
default_initializer=zeros_,
attr=paddle.ParamAttr(regularizer=L2Decay(0.0)))
self.add_parameter("cls_token", self.cls_token)
# classification head
self.head = nn.Linear(embed_dims[-1],
num_classes) if num_classes > 0 else Identity()
# init weights
for pos_emb in self.pos_embeds:
trunc_normal_(pos_emb)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
zeros_(m.bias)
ones_(m.weight)
def forward_features(self, x):
B = x.shape[0]
for i in range(len(self.depths)):
x, (H, W) = self.patch_embeds[i](x)
if i == len(self.depths) - 1:
cls_tokens = self.cls_token.expand([B, -1, -1])
                x = paddle.concat([cls_tokens, x], axis=1)
x = x + self.pos_embeds[i]
x = self.pos_drops[i](x)
for blk in self.blocks[i]:
x = blk(x, H, W)
            if i < len(self.depths) - 1:
                x = x.reshape([B, H, W, -1]).transpose([0, 3, 1, 2])
x = self.norm(x)
return x[:, 0]
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
# PEG from https://arxiv.org/abs/2102.10882
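# A 3x3 depthwise convolution over the 2D token map generates conditional
# positional encodings on the fly, replacing fixed position embeddings.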
class PosCNN(nn.Layer):
def __init__(self, in_chans, embed_dim=768, s=1):
super().__init__()
self.proj = nn.Sequential(
nn.Conv2D(
in_chans,
embed_dim,
3,
s,
1,
bias_attr=paddle.ParamAttr(regularizer=L2Decay(0.0)),
groups=embed_dim,
weight_attr=paddle.ParamAttr(regularizer=L2Decay(0.0)), ))
self.s = s
def forward(self, x, H, W):
B, N, C = x.shape
feat_token = x
cnn_feat = feat_token.transpose([0, 2, 1]).reshape([B, C, H, W])
if self.s == 1:
x = self.proj(cnn_feat) + cnn_feat
else:
x = self.proj(cnn_feat)
x = x.flatten(2).transpose([0, 2, 1])
return x
class CPVTV2(PyramidVisionTransformer):
    """
    Adopts two useful results from CPVT: PEG and global average pooling (GAP),
    so the cls token is no longer required. PEG encodes position on the fly,
    which matters greatly when the input resolution changes between training
    and downstream tasks such as segmentation and detection.
    """
def __init__(self,
img_size=224,
patch_size=4,
in_chans=3,
num_classes=1000,
embed_dims=[64, 128, 256, 512],
num_heads=[1, 2, 4, 8],
mlp_ratios=[4, 4, 4, 4],
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
depths=[3, 4, 6, 3],
sr_ratios=[8, 4, 2, 1],
block_cls=Block):
super().__init__(img_size, patch_size, in_chans, num_classes,
embed_dims, num_heads, mlp_ratios, qkv_bias, qk_scale,
drop_rate, attn_drop_rate, drop_path_rate, norm_layer,
depths, sr_ratios, block_cls)
del self.pos_embeds
del self.cls_token
self.pos_block = nn.LayerList(
[PosCNN(embed_dim, embed_dim) for embed_dim in embed_dims])
self.apply(self._init_weights)
def _init_weights(self, m):
import math
if isinstance(m, nn.Linear):
trunc_normal_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
zeros_(m.bias)
ones_(m.weight)
elif isinstance(m, nn.Conv2D):
fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
fan_out //= m._groups
normal_(0, math.sqrt(2.0 / fan_out))(m.weight)
if m.bias is not None:
zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2D):
            ones_(m.weight)
            zeros_(m.bias)
def forward_features(self, x):
B = x.shape[0]
for i in range(len(self.depths)):
x, (H, W) = self.patch_embeds[i](x)
x = self.pos_drops[i](x)
for j, blk in enumerate(self.blocks[i]):
x = blk(x, H, W)
if j == 0:
x = self.pos_block[i](x, H, W) # PEG here
if i < len(self.depths) - 1:
x = x.reshape([B, H, W, -1]).transpose([0, 3, 1, 2])
x = self.norm(x)
return x.mean(axis=1) # GAP here
class PCPVT(CPVTV2):
def __init__(self,
img_size=224,
patch_size=4,
in_chans=3,
num_classes=1000,
embed_dims=[64, 128, 256],
num_heads=[1, 2, 4],
mlp_ratios=[4, 4, 4],
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
depths=[4, 4, 4],
sr_ratios=[4, 2, 1],
block_cls=SBlock):
super().__init__(img_size, patch_size, in_chans, num_classes,
embed_dims, num_heads, mlp_ratios, qkv_bias, qk_scale,
drop_rate, attn_drop_rate, drop_path_rate, norm_layer,
depths, sr_ratios, block_cls)
class ALTGVT(PCPVT):
"""
alias Twins-SVT
"""
def __init__(self,
img_size=224,
patch_size=4,
in_chans=3,
class_dim=1000,
embed_dims=[64, 128, 256],
num_heads=[1, 2, 4],
mlp_ratios=[4, 4, 4],
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
depths=[4, 4, 4],
sr_ratios=[4, 2, 1],
block_cls=GroupBlock,
wss=[7, 7, 7]):
super().__init__(img_size, patch_size, in_chans, class_dim, embed_dims,
num_heads, mlp_ratios, qkv_bias, qk_scale, drop_rate,
attn_drop_rate, drop_path_rate, norm_layer, depths,
sr_ratios, block_cls)
del self.blocks
self.wss = wss
# transformer encoder
dpr = [
x.numpy()[0]
for x in paddle.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
cur = 0
self.blocks = nn.LayerList()
for k in range(len(depths)):
_block = nn.LayerList([
block_cls(
dim=embed_dims[k],
num_heads=num_heads[k],
mlp_ratio=mlp_ratios[k],
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[cur + i],
norm_layer=norm_layer,
sr_ratio=sr_ratios[k],
ws=1 if i % 2 == 1 else wss[k]) for i in range(depths[k])
])
self.blocks.append(_block)
cur += depths[k]
self.apply(self._init_weights)
def pcpvt_small(pretrained=False, **kwargs):
model = CPVTV2(
patch_size=4,
embed_dims=[64, 128, 320, 512],
num_heads=[1, 2, 5, 8],
mlp_ratios=[8, 8, 4, 4],
qkv_bias=True,
norm_layer=partial(
nn.LayerNorm, epsilon=1e-6),
depths=[3, 4, 6, 3],
sr_ratios=[8, 4, 2, 1],
**kwargs)
return model
def pcpvt_base(pretrained=False, **kwargs):
model = CPVTV2(
patch_size=4,
embed_dims=[64, 128, 320, 512],
num_heads=[1, 2, 5, 8],
mlp_ratios=[8, 8, 4, 4],
qkv_bias=True,
norm_layer=partial(
nn.LayerNorm, epsilon=1e-6),
depths=[3, 4, 18, 3],
sr_ratios=[8, 4, 2, 1],
**kwargs)
return model
def pcpvt_large(pretrained=False, **kwargs):
model = CPVTV2(
patch_size=4,
embed_dims=[64, 128, 320, 512],
num_heads=[1, 2, 5, 8],
mlp_ratios=[8, 8, 4, 4],
qkv_bias=True,
norm_layer=partial(
nn.LayerNorm, epsilon=1e-6),
depths=[3, 8, 27, 3],
sr_ratios=[8, 4, 2, 1],
**kwargs)
return model
def alt_gvt_small(pretrained=False, **kwargs):
model = ALTGVT(
patch_size=4,
embed_dims=[64, 128, 256, 512],
num_heads=[2, 4, 8, 16],
mlp_ratios=[4, 4, 4, 4],
qkv_bias=True,
norm_layer=partial(
nn.LayerNorm, epsilon=1e-6),
depths=[2, 2, 10, 4],
wss=[7, 7, 7, 7],
sr_ratios=[8, 4, 2, 1],
**kwargs)
return model
def alt_gvt_base(pretrained=False, **kwargs):
model = ALTGVT(
patch_size=4,
embed_dims=[96, 192, 384, 768],
num_heads=[3, 6, 12, 24],
mlp_ratios=[4, 4, 4, 4],
qkv_bias=True,
norm_layer=partial(
nn.LayerNorm, epsilon=1e-6),
depths=[2, 2, 18, 2],
wss=[7, 7, 7, 7],
sr_ratios=[8, 4, 2, 1],
        **kwargs)
return model
def alt_gvt_large(pretrained=False, **kwargs):
model = ALTGVT(
patch_size=4,
embed_dims=[128, 256, 512, 1024],
num_heads=[4, 8, 16, 32],
mlp_ratios=[4, 4, 4, 4],
qkv_bias=True,
norm_layer=partial(
nn.LayerNorm, epsilon=1e-6),
depths=[2, 2, 18, 2],
wss=[7, 7, 7, 7],
sr_ratios=[8, 4, 2, 1],
**kwargs)
return model
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import math
import warnings
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import TruncatedNormal, Constant
from paddle.regularizer import L2Decay
from .vision_transformer import trunc_normal_, zeros_, ones_, Identity
__all__ = ["LeViT_128S", "LeViT_128", "LeViT_192", "LeViT_256", "LeViT_384"]
def cal_attention_biases(attention_biases, attention_bias_idxs):
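    """Gather the learned relative-position biases into a dense tensor.

    attention_biases: (num_heads, num_distinct_offsets) learnable table.
    attention_bias_idxs: (N_q, N_k) map from each query/key pair to its offset id.
    Returns a (num_heads, N_q, N_k) tensor added to the attention logits.
    """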
gather_list = []
attention_bias_t = paddle.transpose(attention_biases, (1, 0))
for idx in attention_bias_idxs:
gather = paddle.gather(attention_bias_t, idx)
gather_list.append(gather)
shape0, shape1 = attention_bias_idxs.shape
return paddle.transpose(paddle.concat(gather_list), (1, 0)).reshape(
(0, shape0, shape1))
class Conv2d_BN(nn.Sequential):
def __init__(self,
a,
b,
ks=1,
stride=1,
pad=0,
dilation=1,
groups=1,
bn_weight_init=1,
resolution=-10000):
super().__init__()
self.add_sublayer(
'c',
nn.Conv2D(
a, b, ks, stride, pad, dilation, groups, bias_attr=False))
        bn = nn.BatchNorm2D(b)
        # honor bn_weight_init (0 for residual branches, 1 otherwise)
        Constant(value=float(bn_weight_init))(bn.weight)
        zeros_(bn.bias)
self.add_sublayer('bn', bn)
class Linear_BN(nn.Sequential):
def __init__(self, a, b, bn_weight_init=1):
super().__init__()
self.add_sublayer('c', nn.Linear(a, b, bias_attr=False))
        bn = nn.BatchNorm1D(b)
        # honor bn_weight_init (0 for residual branches, 1 otherwise)
        Constant(value=float(bn_weight_init))(bn.weight)
        zeros_(bn.bias)
self.add_sublayer('bn', bn)
def forward(self, x):
l, bn = self._sub_layers.values()
x = l(x)
return paddle.reshape(bn(x.flatten(0, 1)), x.shape)
class BN_Linear(nn.Sequential):
def __init__(self, a, b, bias=True, std=0.02):
super().__init__()
self.add_sublayer('bn', nn.BatchNorm1D(a))
l = nn.Linear(a, b, bias_attr=bias)
trunc_normal_(l.weight)
if bias:
zeros_(l.bias)
self.add_sublayer('l', l)
def b16(n, activation, resolution=224):
return nn.Sequential(
Conv2d_BN(
3, n // 8, 3, 2, 1, resolution=resolution),
activation(),
Conv2d_BN(
n // 8, n // 4, 3, 2, 1, resolution=resolution // 2),
activation(),
Conv2d_BN(
n // 4, n // 2, 3, 2, 1, resolution=resolution // 4),
activation(),
Conv2d_BN(
n // 2, n, 3, 2, 1, resolution=resolution // 8))
class Residual(nn.Layer):
def __init__(self, m, drop):
super().__init__()
self.m = m
self.drop = drop
    def forward(self, x):
        if self.training and self.drop > 0:
            # stochastic depth: randomly drop the residual branch per sample
            mask = (paddle.rand([x.shape[0], 1, 1]) >=
                    self.drop).astype(x.dtype) / (1 - self.drop)
            return x + self.m(x) * mask.detach()
        else:
            return x + self.m(x)
class Attention(nn.Layer):
def __init__(self,
dim,
key_dim,
num_heads=8,
attn_ratio=4,
activation=None,
resolution=14):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim**-0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio
self.h = self.dh + nh_kd * 2
self.qkv = Linear_BN(dim, self.h)
self.proj = nn.Sequential(
activation(), Linear_BN(
self.dh, dim, bn_weight_init=0))
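        # enumerate every query/key position pair on the feature grid and map its
        # relative offset to a shared learnable bias (attention-bias positional
        # encoding, which replaces explicit position embeddings)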
points = list(itertools.product(range(resolution), range(resolution)))
N = len(points)
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = self.create_parameter(
shape=(num_heads, len(attention_offsets)),
default_initializer=zeros_,
attr=paddle.ParamAttr(regularizer=L2Decay(0.0)))
tensor_idxs = paddle.to_tensor(idxs, dtype='int64')
self.register_buffer('attention_bias_idxs',
paddle.reshape(tensor_idxs, [N, N]))
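    # train(False) caches the gathered bias tensor in `self.ab`; forward() falls
    # back to recomputing it when the cache is absent.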
@paddle.no_grad()
def train(self, mode=True):
if mode:
super().train()
else:
super().eval()
if mode and hasattr(self, 'ab'):
del self.ab
else:
self.ab = cal_attention_biases(self.attention_biases,
self.attention_bias_idxs)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x)
qkv = paddle.reshape(qkv,
[B, N, self.num_heads, self.h // self.num_heads])
q, k, v = paddle.split(
qkv, [self.key_dim, self.key_dim, self.d], axis=3)
q = paddle.transpose(q, perm=[0, 2, 1, 3])
k = paddle.transpose(k, perm=[0, 2, 1, 3])
v = paddle.transpose(v, perm=[0, 2, 1, 3])
k_transpose = paddle.transpose(k, perm=[0, 1, 3, 2])
        if self.training or not hasattr(self, 'ab'):
            attention_biases = cal_attention_biases(self.attention_biases,
                                                    self.attention_bias_idxs)
        else:
            attention_biases = self.ab
        attn = (q @ k_transpose) * self.scale + attention_biases
        attn = F.softmax(attn, axis=-1)
        x = paddle.transpose(attn @ v, perm=[0, 2, 1, 3])
x = paddle.reshape(x, [B, N, self.dh])
x = self.proj(x)
return x
class Subsample(nn.Layer):
def __init__(self, stride, resolution):
super().__init__()
self.stride = stride
self.resolution = resolution
def forward(self, x):
B, N, C = x.shape
x = paddle.reshape(x, [B, self.resolution, self.resolution,
C])[:, ::self.stride, ::self.stride]
x = paddle.reshape(x, [B, -1, C])
return x
class AttentionSubsample(nn.Layer):
def __init__(self,
in_dim,
out_dim,
key_dim,
num_heads=8,
attn_ratio=2,
activation=None,
stride=2,
resolution=14,
resolution_=7):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim**-0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * self.num_heads
self.attn_ratio = attn_ratio
self.resolution_ = resolution_
self.resolution_2 = resolution_**2
self.training = True
h = self.dh + nh_kd
self.kv = Linear_BN(in_dim, h)
self.q = nn.Sequential(
Subsample(stride, resolution), Linear_BN(in_dim, nh_kd))
self.proj = nn.Sequential(activation(), Linear_BN(self.dh, out_dim))
self.stride = stride
self.resolution = resolution
points = list(itertools.product(range(resolution), range(resolution)))
points_ = list(
itertools.product(range(resolution_), range(resolution_)))
N = len(points)
N_ = len(points_)
attention_offsets = {}
idxs = []
        for p1 in points_:
            for p2 in points:
                size = 1
                offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2),
                          abs(p1[1] * stride - p2[1] + (size - 1) / 2))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
self.attention_biases = self.create_parameter(
shape=(num_heads, len(attention_offsets)),
default_initializer=zeros_,
attr=paddle.ParamAttr(regularizer=L2Decay(0.0)))
tensor_idxs_ = paddle.to_tensor(idxs, dtype='int64')
self.register_buffer('attention_bias_idxs',
paddle.reshape(tensor_idxs_, [N_, N]))
@paddle.no_grad()
def train(self, mode=True):
if mode:
super().train()
else:
super().eval()
if mode and hasattr(self, 'ab'):
del self.ab
else:
self.ab = cal_attention_biases(self.attention_biases,
self.attention_bias_idxs)
def forward(self, x):
B, N, C = x.shape
kv = self.kv(x)
kv = paddle.reshape(kv, [B, N, self.num_heads, -1])
k, v = paddle.split(kv, [self.key_dim, self.d], axis=3)
k = paddle.transpose(k, perm=[0, 2, 1, 3]) # BHNC
v = paddle.transpose(v, perm=[0, 2, 1, 3])
q = paddle.reshape(
self.q(x), [B, self.resolution_2, self.num_heads, self.key_dim])
q = paddle.transpose(q, perm=[0, 2, 1, 3])
        if self.training or not hasattr(self, 'ab'):
            attention_biases = cal_attention_biases(self.attention_biases,
                                                    self.attention_bias_idxs)
        else:
            attention_biases = self.ab
        attn = (q @ paddle.transpose(
            k, perm=[0, 1, 3, 2])) * self.scale + attention_biases
        attn = F.softmax(attn, axis=-1)
        x = paddle.reshape(
            paddle.transpose(attn @ v, perm=[0, 2, 1, 3]), [B, -1, self.dh])
x = self.proj(x)
return x
class LeViT(nn.Layer):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self,
img_size=224,
patch_size=16,
in_chans=3,
class_dim=1000,
embed_dim=[192],
key_dim=[64],
depth=[12],
num_heads=[3],
attn_ratio=[2],
mlp_ratio=[2],
hybrid_backbone=None,
down_ops=[],
attention_activation=nn.Hardswish,
mlp_activation=nn.Hardswish,
distillation=True,
drop_path=0):
super().__init__()
self.class_dim = class_dim
self.num_features = embed_dim[-1]
self.embed_dim = embed_dim
self.distillation = distillation
self.patch_embed = hybrid_backbone
self.blocks = []
        # copy to avoid mutating a caller-supplied (or the default) list in place
        down_ops = list(down_ops)
        down_ops.append([''])
resolution = img_size // patch_size
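        # each stage stacks `depth` residual (attention + MLP) blocks, then an
        # optional AttentionSubsample entry from down_ops reduces the resolution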
for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
zip(embed_dim, key_dim, depth, num_heads, attn_ratio,
mlp_ratio, down_ops)):
for _ in range(dpth):
self.blocks.append(
Residual(
Attention(
ed,
kd,
nh,
attn_ratio=ar,
activation=attention_activation,
resolution=resolution, ),
drop_path))
if mr > 0:
h = int(ed * mr)
self.blocks.append(
Residual(
nn.Sequential(
Linear_BN(ed, h),
mlp_activation(),
Linear_BN(
h, ed, bn_weight_init=0), ),
drop_path))
if do[0] == 'Subsample':
#('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
resolution_ = (resolution - 1) // do[5] + 1
self.blocks.append(
AttentionSubsample(
*embed_dim[i:i + 2],
key_dim=do[1],
num_heads=do[2],
attn_ratio=do[3],
activation=attention_activation,
stride=do[5],
resolution=resolution,
resolution_=resolution_))
resolution = resolution_
if do[4] > 0: # mlp_ratio
h = int(embed_dim[i + 1] * do[4])
self.blocks.append(
Residual(
nn.Sequential(
Linear_BN(embed_dim[i + 1], h),
mlp_activation(),
Linear_BN(
h, embed_dim[i + 1], bn_weight_init=0), ),
drop_path))
self.blocks = nn.Sequential(*self.blocks)
# Classifier head
self.head = BN_Linear(embed_dim[-1],
class_dim) if class_dim > 0 else Identity()
if distillation:
self.head_dist = BN_Linear(
embed_dim[-1], class_dim) if class_dim > 0 else Identity()
def forward(self, x):
x = self.patch_embed(x)
x = x.flatten(2)
x = paddle.transpose(x, perm=[0, 2, 1])
x = self.blocks(x)
x = x.mean(1)
if self.distillation:
x = self.head(x), self.head_dist(x)
if not self.training:
x = (x[0] + x[1]) / 2
else:
x = self.head(x)
return x
def model_factory(C, D, X, N, drop_path, class_dim, distillation):
embed_dim = [int(x) for x in C.split('_')]
num_heads = [int(x) for x in N.split('_')]
depth = [int(x) for x in X.split('_')]
act = nn.Hardswish
model = LeViT(
patch_size=16,
embed_dim=embed_dim,
num_heads=num_heads,
key_dim=[D] * 3,
depth=depth,
attn_ratio=[2, 2, 2],
mlp_ratio=[2, 2, 2],
down_ops=[
#('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
['Subsample', D, embed_dim[0] // D, 4, 2, 2],
['Subsample', D, embed_dim[1] // D, 4, 2, 2],
],
attention_activation=act,
mlp_activation=act,
hybrid_backbone=b16(embed_dim[0], activation=act),
class_dim=class_dim,
drop_path=drop_path,
distillation=distillation)
return model
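# Spec keys: C = '_'-joined embed dims per stage, D = key dim, N = '_'-joined
# head counts per stage, X = '_'-joined block counts (depth) per stage.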
specification = {
'LeViT_128S': {
'C': '128_256_384',
'D': 16,
'N': '4_6_8',
'X': '2_3_4',
'drop_path': 0
},
'LeViT_128': {
'C': '128_256_384',
'D': 16,
'N': '4_8_12',
'X': '4_4_4',
'drop_path': 0
},
'LeViT_192': {
'C': '192_288_384',
'D': 32,
'N': '3_5_6',
'X': '4_4_4',
'drop_path': 0
},
'LeViT_256': {
'C': '256_384_512',
'D': 32,
'N': '4_6_8',
'X': '4_4_4',
'drop_path': 0
},
'LeViT_384': {
'C': '384_512_768',
'D': 32,
'N': '6_9_12',
'X': '4_4_4',
'drop_path': 0.1
},
}
def LeViT_128S(class_dim=1000, distillation=True, pretrained=False):
return model_factory(
**specification['LeViT_128S'],
class_dim=class_dim,
distillation=distillation)
def LeViT_128(class_dim=1000, distillation=True, pretrained=False):
return model_factory(
**specification['LeViT_128'],
class_dim=class_dim,
distillation=distillation)
def LeViT_192(class_dim=1000, distillation=True, pretrained=False):
return model_factory(
**specification['LeViT_192'],
class_dim=class_dim,
distillation=distillation)
def LeViT_256(class_dim=1000, distillation=False, pretrained=False):
return model_factory(
**specification['LeViT_256'],
class_dim=class_dim,
distillation=distillation)
def LeViT_384(class_dim=1000, distillation=True, pretrained=False):
return model_factory(
**specification['LeViT_384'],
class_dim=class_dim,
distillation=distillation)
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable
import numpy as np
import paddle
import paddle.nn as nn
from paddle.nn.initializer import TruncatedNormal, Constant
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
__all__ = [
"VisionTransformer", "ViT_small_patch16_224", "ViT_base_patch16_224",
@@ -25,6 +27,7 @@ __all__ = [
]
trunc_normal_ = TruncatedNormal(std=.02)
normal_ = Normal
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)
@@ -141,7 +144,13 @@ class Block(nn.Layer):
norm_layer='nn.LayerNorm',
epsilon=1e-5):
super().__init__()
self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
if isinstance(norm_layer, str):
self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
elif isinstance(norm_layer, Callable):
self.norm1 = norm_layer(dim)
else:
            raise TypeError(
                "The norm_layer must be a str or a callable (e.g. a paddle.nn.Layer class)")
self.attn = Attention(
dim,
num_heads=num_heads,
@@ -151,7 +160,13 @@
proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
if isinstance(norm_layer, str):
self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
elif isinstance(norm_layer, Callable):
self.norm2 = norm_layer(dim)
else:
            raise TypeError(
                "The norm_layer must be a str or a callable (e.g. a paddle.nn.Layer class)")
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim,
hidden_features=mlp_hidden_dim,
......