Unverified · Commit a870f942 · authored by littletomatodonkey · committed by GitHub

Update vision_transformer.py

Parent fb7c750c
@@ -12,20 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import numpy as np
 import paddle
 import paddle.nn as nn
 from paddle.nn.initializer import TruncatedNormal, Constant

 __all__ = [
-    "VisionTransformer",
-    "ViT_small_patch16_224",
-    "ViT_base_patch16_224", "ViT_base_patch16_384", "ViT_base_patch32_384",
-    "ViT_large_patch16_224", "ViT_large_patch16_384", "ViT_large_patch32_384",
-    "ViT_huge_patch16_224", "ViT_huge_patch32_384"
+    "VisionTransformer", "ViT_small_patch16_224", "ViT_base_patch16_224",
+    "ViT_base_patch16_384", "ViT_base_patch32_384", "ViT_large_patch16_224",
+    "ViT_large_patch16_384", "ViT_large_patch32_384", "ViT_huge_patch16_224",
+    "ViT_huge_patch32_384"
 ]

 trunc_normal_ = TruncatedNormal(std=.02)
 zeros_ = Constant(value=0.)
 ones_ = Constant(value=1.)
@@ -43,12 +41,13 @@ def drop_path(x, drop_prob=0., training=False):
     if drop_prob == 0. or not training:
         return x
     keep_prob = paddle.to_tensor(1 - drop_prob)
-    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
+    shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
     random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
     random_tensor = paddle.floor(random_tensor)  # binarize
     output = x.divide(keep_prob) * random_tensor
     return output


 class DropPath(nn.Layer):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
     """
@@ -70,7 +69,12 @@ class Identity(nn.Layer):

 class Mlp(nn.Layer):
-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+    def __init__(self,
+                 in_features,
+                 hidden_features=None,
+                 out_features=None,
+                 act_layer=nn.GELU,
+                 drop=0.):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
@@ -89,11 +93,17 @@ class Mlp(nn.Layer):

 class Attention(nn.Layer):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+    def __init__(self,
+                 dim,
+                 num_heads=8,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 attn_drop=0.,
+                 proj_drop=0.):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
-        self.scale = qk_scale or head_dim ** -0.5
+        self.scale = qk_scale or head_dim**-0.5
         self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -101,8 +111,9 @@ class Attention(nn.Layer):
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x):
-        B, N, C = x.shape
-        qkv = self.qkv(x).reshape((B, N, 3, self.num_heads, C //
+        # B= paddle.shape(x)[0]
+        N, C = x.shape[1:]
+        qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //
                                    self.num_heads)).transpose((2, 0, 3, 1, 4))
         q, k, v = qkv[0], qkv[1], qkv[2]
@@ -110,26 +121,42 @@ class Attention(nn.Layer):
         attn = nn.functional.softmax(attn, axis=-1)
         attn = self.attn_drop(attn)

-        x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((B, N, C))
+        x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((-1, N, C))
         x = self.proj(x)
         x = self.proj_drop(x)
         return x


 class Block(nn.Layer):
-    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
-                 drop_path=0., act_layer=nn.GELU, norm_layer='nn.LayerNorm', epsilon=1e-5):
+    def __init__(self,
+                 dim,
+                 num_heads,
+                 mlp_ratio=4.,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 act_layer=nn.GELU,
+                 norm_layer='nn.LayerNorm',
+                 epsilon=1e-5):
         super().__init__()
         self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
         self.attn = Attention(
-            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop)
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
         self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
         mlp_hidden_dim = int(dim * mlp_ratio)
-        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
-                       act_layer=act_layer, drop=drop)
+        self.mlp = Mlp(in_features=dim,
+                       hidden_features=mlp_hidden_dim,
+                       act_layer=act_layer,
+                       drop=drop)

     def forward(self, x):
         x = x + self.drop_path(self.attn(self.norm1(x)))
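The Attention and Block changes above follow the same pattern: only the token count N and channel count C are taken from x.shape, and the batch dimension is written as -1 in reshape, so no concrete batch size is traced into the program. A short standalone sketch of that reshape pattern; the shapes and the throwaway Linear layer are illustrative assumptions, not code from this commit:

import paddle

num_heads = 12
x = paddle.rand([2, 197, 768])  # assumed: batch 2, 197 tokens, dim 768
N, C = x.shape[1:]

qkv = paddle.nn.Linear(C, C * 3)(x)
# -1 lets the batch dimension be inferred at run time, mirroring
# reshape((-1, N, 3, self.num_heads, C // self.num_heads)) in the diff.
qkv = qkv.reshape((-1, N, 3, num_heads, C // num_heads)).transpose(
    (2, 0, 3, 1, 4))
q, k, v = qkv[0], qkv[1], qkv[2]
print(q.shape)  # [2, 12, 197, 64]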
@@ -151,13 +178,13 @@ class PatchEmbed(nn.Layer):
         self.patch_size = patch_size
         self.num_patches = num_patches

-        self.proj = nn.Conv2D(in_chans, embed_dim,
-                              kernel_size=patch_size, stride=patch_size)
+        self.proj = nn.Conv2D(
+            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

     def forward(self, x):
         B, C, H, W = x.shape
         assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+            "Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

         x = self.proj(x).flatten(2).transpose((0, 2, 1))
         return x
@@ -167,16 +194,33 @@ class VisionTransformer(nn.Layer):
     """ Vision Transformer with support for patch input
     """

-    def __init__(self, img_size=224, patch_size=16, in_chans=3, class_dim=1000, embed_dim=768, depth=12,
-                 num_heads=12, mlp_ratio=4, qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
-                 drop_path_rate=0., norm_layer='nn.LayerNorm', epsilon=1e-5, **args):
+    def __init__(self,
+                 img_size=224,
+                 patch_size=16,
+                 in_chans=3,
+                 class_dim=1000,
+                 embed_dim=768,
+                 depth=12,
+                 num_heads=12,
+                 mlp_ratio=4,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 norm_layer='nn.LayerNorm',
+                 epsilon=1e-5,
+                 **args):
         super().__init__()
         self.class_dim = class_dim

         self.num_features = self.embed_dim = embed_dim

         self.patch_embed = PatchEmbed(
-            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim)
         num_patches = self.patch_embed.num_patches

         self.pos_embed = self.create_parameter(
@@ -187,23 +231,33 @@ class VisionTransformer(nn.Layer):
         self.add_parameter("cls_token", self.cls_token)
         self.pos_drop = nn.Dropout(p=drop_rate)

-        dpr = [x for x in paddle.linspace(0, drop_path_rate, depth)]
+        dpr = np.linspace(0, drop_path_rate, depth)

         self.blocks = nn.LayerList([
             Block(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
-                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, epsilon=epsilon)
-            for i in range(depth)])
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[i],
+                norm_layer=norm_layer,
+                epsilon=epsilon) for i in range(depth)
+        ])

         self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)

         # Classifier head
-        self.head = nn.Linear(
-            embed_dim, class_dim) if class_dim > 0 else Identity()
+        self.head = nn.Linear(embed_dim,
+                              class_dim) if class_dim > 0 else Identity()

-        trunc_normal_(self.pos_embed)
-        trunc_normal_(self.cls_token)
-        self.apply(self._init_weights)
+        # TODO(littletomatodonkey): same init in static mode
+        if paddle.in_dynamic_mode():
+            trunc_normal_(self.pos_embed)
+            trunc_normal_(self.cls_token)
+            self.apply(self._init_weights)

     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
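Two further export-oriented changes sit in this hunk: the per-block drop-path rates are now plain Python floats from np.linspace rather than paddle.linspace tensors, and the truncated-normal initialization only runs in dynamic mode (the TODO notes that an equivalent static-mode init is still open). A small sketch of both, with assumed depth and drop_path_rate values:

import numpy as np
import paddle

depth, drop_path_rate = 12, 0.1  # assumed values for illustration

# One stochastic-depth rate per Block, increasing linearly with depth.
# np.linspace returns plain floats, so comparisons such as drop_path > 0.
# and the DropPath(drop_path) constructor never pull tensors into the graph.
dpr = np.linspace(0, drop_path_rate, depth)
print([round(float(r), 3) for r in dpr])

# Weight init is applied only in dynamic (imperative) mode; in static mode the
# parameters are expected to come from a trained checkpoint instead.
if paddle.in_dynamic_mode():
    print("dynamic mode: trunc_normal_ init would be applied here")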
@@ -215,7 +269,8 @@ class VisionTransformer(nn.Layer):
             ones_(m.weight)

     def forward_features(self, x):
-        B = x.shape[0]
+        # B = x.shape[0]
+        B = paddle.shape(x)[0]
         x = self.patch_embed(x)
         cls_tokens = self.cls_token.expand((B, -1, -1))
         x = paddle.concat((cls_tokens, x), axis=1)
@@ -234,59 +289,116 @@ class VisionTransformer(nn.Layer):

 def ViT_small_patch16_224(**kwargs):
     model = VisionTransformer(
-        patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, qk_scale=768**-0.5, **kwargs)
+        patch_size=16,
+        embed_dim=768,
+        depth=8,
+        num_heads=8,
+        mlp_ratio=3,
+        qk_scale=768**-0.5,
+        **kwargs)
     return model


 def ViT_base_patch16_224(**kwargs):
     model = VisionTransformer(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
-        epsilon=1e-6, **kwargs)
+        patch_size=16,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4,
+        qkv_bias=True,
+        epsilon=1e-6,
+        **kwargs)
     return model


 def ViT_base_patch16_384(**kwargs):
     model = VisionTransformer(
-        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
-        qkv_bias=True, epsilon=1e-6, **kwargs)
+        img_size=384,
+        patch_size=16,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4,
+        qkv_bias=True,
+        epsilon=1e-6,
+        **kwargs)
     return model


 def ViT_base_patch32_384(**kwargs):
     model = VisionTransformer(
-        img_size=384, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
-        qkv_bias=True, epsilon=1e-6, **kwargs)
+        img_size=384,
+        patch_size=32,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4,
+        qkv_bias=True,
+        epsilon=1e-6,
+        **kwargs)
     return model


 def ViT_large_patch16_224(**kwargs):
     model = VisionTransformer(
-        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
-        epsilon=1e-6, **kwargs)
+        patch_size=16,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4,
+        qkv_bias=True,
+        epsilon=1e-6,
+        **kwargs)
     return model


 def ViT_large_patch16_384(**kwargs):
     model = VisionTransformer(
-        img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
-        qkv_bias=True, epsilon=1e-6, **kwargs)
+        img_size=384,
+        patch_size=16,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4,
+        qkv_bias=True,
+        epsilon=1e-6,
+        **kwargs)
     return model


 def ViT_large_patch32_384(**kwargs):
     model = VisionTransformer(
-        img_size=384, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
-        qkv_bias=True, epsilon=1e-6, **kwargs)
+        img_size=384,
+        patch_size=32,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4,
+        qkv_bias=True,
+        epsilon=1e-6,
+        **kwargs)
     return model


 def ViT_huge_patch16_224(**kwargs):
     model = VisionTransformer(
-        patch_size=16, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs)
+        patch_size=16,
+        embed_dim=1280,
+        depth=32,
+        num_heads=16,
+        mlp_ratio=4,
+        **kwargs)
     return model


 def ViT_huge_patch32_384(**kwargs):
     model = VisionTransformer(
-        img_size=384, patch_size=32, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs)
+        img_size=384,
+        patch_size=32,
+        embed_dim=1280,
+        depth=32,
+        num_heads=16,
+        mlp_ratio=4,
+        **kwargs)
     return model
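The factory functions above differ only in their hyper-parameters. A usage sketch under stated assumptions (the module is importable as vision_transformer, and a random tensor is an acceptable smoke test); the commented export calls only illustrate how the dynamic-batch changes in this commit would typically be exercised and are not part of the diff:

import paddle

from vision_transformer import ViT_base_patch16_224  # assumed import path

model = ViT_base_patch16_224(class_dim=10)
model.eval()

x = paddle.rand([1, 3, 224, 224])  # assumed smoke-test input
out = model(x)
print(out.shape)  # [1, 10]

# With paddle.shape(...) and reshape((-1, ...)) handling the batch dimension,
# the model can be exported with a dynamic batch size (illustrative only):
# static_model = paddle.jit.to_static(
#     model,
#     input_spec=[paddle.static.InputSpec([None, 3, 224, 224], 'float32')])
# paddle.jit.save(static_model, "ViT_base_patch16_224")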