vision_transformer.py 14.0 KB
Newer Older
jm_12138's avatar
jm_12138 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

C
cuicheng01 已提交
15
# Code was based on https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
G
gaotingquan 已提交
16
# reference: https://arxiv.org/abs/2010.11929
C
cuicheng01 已提交
17

18
from collections.abc import Callable
G
gaotingquan 已提交
19

20
import numpy as np
jm_12138's avatar
jm_12138 已提交
21 22
import paddle
import paddle.nn as nn
G
gaotingquan 已提交
23
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
jm_12138's avatar
jm_12138 已提交
24

C
cuicheng01 已提交
25 26 27
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

# Download location of the pretrained ``.pdparams`` checkpoint for every
# architecture, keyed by the name of the factory function that builds it.
MODEL_URLS = {
    "ViT_small_patch16_224": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_small_patch16_224_pretrained.pdparams",
    "ViT_base_patch16_224": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch16_224_pretrained.pdparams",
    "ViT_base_patch16_384": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch16_384_pretrained.pdparams",
    "ViT_base_patch32_384": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch32_384_pretrained.pdparams",
    "ViT_large_patch16_224": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_large_patch16_224_pretrained.pdparams",
    "ViT_large_patch16_384": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_large_patch16_384_pretrained.pdparams",
    "ViT_large_patch32_384": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_large_patch32_384_pretrained.pdparams",
}

# The public API is exactly the set of factories that have pretrained weights.
__all__ = list(MODEL_URLS)

# Shared parameter initializers. VisionTransformer uses them on the position
# embedding, the class token, and its Linear / LayerNorm sub-layers.
trunc_normal_ = TruncatedNormal(std=0.02)
# Note: ``normal_`` is the Normal initializer *class*, not an instance.
normal_ = Normal
zeros_ = Constant(value=0.0)
ones_ = Constant(value=1.0)


def to_2tuple(x):
    """Return ``x`` as a 2-tuple.

    A scalar is duplicated (``16 -> (16, 16)``). A tuple or list that already
    holds two entries is returned as a tuple, so image/patch sizes may be
    given either as ``224`` or as ``(224, 224)``.

    Raises:
        ValueError: if a tuple/list whose length is not 2 is given.
    """
    if isinstance(x, (tuple, list)):
        # Previously a pair was wrapped again into ((h, w), (h, w)), which
        # broke the size arithmetic in PatchEmbed.
        if len(x) != 2:
            raise ValueError(
                f"Expected a scalar or a pair, got {len(x)} values: {x!r}")
        return tuple(x)
    return tuple([x] * 2)


def drop_path(x, drop_prob=0., training=False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    The original name is misleading as 'Drop Connect' is a different form of
    dropout in a separate paper.
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956

    Every sample (index along axis 0 of ``x``) is either zeroed entirely
    with probability ``drop_prob`` or kept and scaled by ``1 / (1 - drop_prob)``,
    so the expected value of the output equals ``x``.

    Args:
        x: input tensor; axis 0 is treated as the batch axis.
        drop_prob: probability of dropping a sample's path. ``None`` or 0
            disables dropping.
        training: dropping only happens when True; otherwise ``x`` is
            returned unchanged.

    Returns:
        A tensor with the same shape and dtype as ``x``.
    """
    # ``not drop_prob`` treats None (DropPath's default) the same as 0 instead
    # of failing on ``1 - None`` in training mode.
    if not drop_prob or not training:
        return x
    # Build keep_prob in x's dtype so the arithmetic below does not mix the
    # default float dtype with e.g. float16 inputs.
    keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
    # One random draw per sample, broadcast across all remaining axes.
    shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape).astype(x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    """Layer wrapper around ``drop_path`` (Stochastic Depth per sample).

    Args:
        drop_prob: probability of dropping each sample's residual branch;
            dropping is only applied while ``self.training`` is set.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)


class Identity(nn.Layer):
    """No-op layer that returns its input unchanged.

    Used in this module as a stand-in for a disabled DropPath and for the
    classifier head when ``class_num <= 0``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input


class Mlp(nn.Layer):
    """Two-layer feed-forward network used inside each transformer Block.

    Computes ``drop(fc2(drop(act(fc1(x)))))``; the same dropout layer is
    applied after the activation and after the output projection.

    Args:
        in_features: size of the last input axis.
        hidden_features: hidden width; defaults to ``in_features``.
        out_features: size of the last output axis; defaults to
            ``in_features``.
        act_layer: activation layer class, instantiated with no arguments.
        drop: dropout probability.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Layer):
    """Multi-head self-attention over a sequence of tokens.

    Args:
        dim: token embedding size, split evenly across ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the fused q/k/v projection has a bias.
        qk_scale: optional override of the ``head_dim ** -0.5`` logit scale;
            a falsy value falls back to the default.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        # x is (batch, tokens, channels); batch is left as -1 in the reshapes.
        num_tokens, channels = x.shape[1:]
        head_dim = channels // self.num_heads

        # (B, N, 3*C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(
            (-1, num_tokens, 3, self.num_heads, head_dim))
        qkv = qkv.transpose((2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention weights: (B, heads, N, N)
        scores = q.matmul(k.transpose((0, 1, 3, 2))) * self.scale
        weights = self.attn_drop(nn.functional.softmax(scores, axis=-1))

        # (B, heads, N, head_dim) -> (B, N, heads, head_dim) -> (B, N, C)
        out = weights.matmul(v).transpose((0, 2, 1, 3))
        out = out.reshape((-1, num_tokens, channels))
        return self.proj_drop(self.proj(out))


class Block(nn.Layer):
    """Pre-norm transformer encoder block.

    ``x = x + drop_path(attn(norm1(x)))`` followed by
    ``x = x + drop_path(mlp(norm2(x)))``.

    Args:
        dim: token embedding size.
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden width relative to ``dim``.
        qkv_bias: forwarded to Attention.
        qk_scale: forwarded to Attention.
        drop: dropout after the attention projection and inside the MLP.
        attn_drop: dropout on the attention weights.
        drop_path: stochastic-depth probability; 0 uses Identity instead.
        act_layer: activation layer class used by the MLP.
        norm_layer: a string naming a norm layer class (evaluated with
            ``eval`` and built as ``cls(dim, epsilon=epsilon)``), or a callable
            invoked as ``norm_layer(dim)``.
        epsilon: epsilon for string-specified norm layers.

    Raises:
        TypeError: if ``norm_layer`` is neither a string nor callable.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer='nn.LayerNorm',
                 epsilon=1e-5):
        super().__init__()
        # Sub-layers are created in the same order as before.
        self.norm1 = self._build_norm(norm_layer, dim, epsilon)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
        self.norm2 = self._build_norm(norm_layer, dim, epsilon)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=mlp_hidden_dim,
                       act_layer=act_layer,
                       drop=drop)

    @staticmethod
    def _build_norm(norm_layer, dim, epsilon):
        """Instantiate one normalization layer from a string or a callable."""
        if isinstance(norm_layer, str):
            # NOTE(review): eval() runs arbitrary code; norm_layer must come
            # from trusted configuration only (e.g. 'nn.LayerNorm').
            return eval(norm_layer)(dim, epsilon=epsilon)
        if isinstance(norm_layer, Callable):
            return norm_layer(dim)
        raise TypeError(
            "The norm_layer must be str or paddle.nn.layer.Layer class")

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Layer):
    """Split an image into non-overlapping patches and embed each one.

    A Conv2D whose kernel size and stride both equal the patch size projects
    every patch to ``embed_dim`` channels. The patch grid is then flattened
    into a token axis.

    Args:
        img_size: expected input height/width, expanded with ``to_2tuple``.
        patch_size: patch height/width, expanded with ``to_2tuple``.
        in_chans: number of input channels.
        embed_dim: size of each patch embedding.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        grid_h = img_size[0] // patch_size[0]
        grid_w = img_size[1] // patch_size[1]
        self.num_patches = grid_h * grid_w

        self.proj = nn.Conv2D(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x is unpacked as (B, C, H, W); only H and W are checked.
        _, _, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, embed_dim, gh, gw) -> (B, embed_dim, gh*gw) -> (B, gh*gw, embed_dim)
        tokens = self.proj(x).flatten(2)
        return tokens.transpose((0, 2, 1))


class VisionTransformer(nn.Layer):
    """ Vision Transformer with support for patch input

    The image is split into patch tokens (PatchEmbed). A learnable class
    token is prepended and learnable position embeddings are added. The
    sequence then runs through ``depth`` Blocks, and the normalized class
    token is passed to a linear classification head.

    Args:
        img_size: forwarded to PatchEmbed.
        patch_size: forwarded to PatchEmbed.
        in_chans: forwarded to PatchEmbed.
        class_num: number of classes; ``<= 0`` replaces the head with
            Identity, so ``forward`` returns the class-token features.
        embed_dim: token embedding size.
        depth: number of transformer blocks.
        num_heads: forwarded to every Block.
        mlp_ratio: forwarded to every Block.
        qkv_bias: forwarded to every Block.
        qk_scale: forwarded to every Block.
        drop_rate: dropout after adding position embeddings and inside blocks.
        attn_drop_rate: dropout on attention weights.
        drop_path_rate: stochastic-depth rate, spaced linearly from 0 (first
            block) to this value (last block).
        norm_layer: a string naming a norm layer class (evaluated with
            ``eval`` and built with ``epsilon``), or a callable invoked as
            ``norm_layer(dim)``. It is used for every Block and for the final
            norm.
        epsilon: epsilon for string-specified norm layers.
        **kwargs: accepted and ignored.

    Raises:
        TypeError: if ``norm_layer`` is neither a string nor callable.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 class_num=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer='nn.LayerNorm',
                 epsilon=1e-5,
                 **kwargs):
        super().__init__()
        self.class_num = class_num

        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # One extra position for the prepended class token.
        self.pos_embed = self.create_parameter(
            shape=(1, num_patches + 1, embed_dim), default_initializer=zeros_)
        self.add_parameter("pos_embed", self.pos_embed)
        self.cls_token = self.create_parameter(
            shape=(1, 1, embed_dim), default_initializer=zeros_)
        self.add_parameter("cls_token", self.cls_token)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Per-block stochastic-depth rates, increasing linearly with depth.
        dpr = np.linspace(0, drop_path_rate, depth)

        self.blocks = nn.LayerList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                epsilon=epsilon) for i in range(depth)
        ])

        # Previously built with eval(norm_layer) unconditionally, which
        # crashed for callable norm layers that Block itself accepts.
        self.norm = self._build_norm(norm_layer, embed_dim, epsilon)

        # Classifier head
        self.head = nn.Linear(embed_dim,
                              class_num) if class_num > 0 else Identity()

        trunc_normal_(self.pos_embed)
        trunc_normal_(self.cls_token)
        self.apply(self._init_weights)

    @staticmethod
    def _build_norm(norm_layer, dim, epsilon):
        """Instantiate a normalization layer from a string or a callable."""
        if isinstance(norm_layer, str):
            # NOTE(review): eval() runs arbitrary code; norm_layer must come
            # from trusted configuration only (e.g. 'nn.LayerNorm').
            return eval(norm_layer)(dim, epsilon=epsilon)
        if isinstance(norm_layer, Callable):
            return norm_layer(dim)
        raise TypeError(
            "The norm_layer must be str or paddle.nn.layer.Layer class")

    def _init_weights(self, m):
        """Truncated-normal Linear weights with zero bias; LayerNorm set to
        weight 1 / bias 0. Missing (None) parameters are skipped."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                zeros_(m.bias)
            if m.weight is not None:
                ones_(m.weight)

    def forward_features(self, x):
        """Return the final-normalized class-token embedding of each sample."""
        # NOTE(review): paddle.shape (instead of x.shape[0]) presumably keeps
        # the batch size dynamic for static-graph export; confirm.
        B = paddle.shape(x)[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand((B, -1, -1))
        x = paddle.concat((cls_tokens, x), axis=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    """Optionally load pretrained weights into ``model``.

    Args:
        pretrained: ``False`` to skip loading; ``True`` to load from
            ``model_url`` via ``load_dygraph_pretrain_from_url``; a string is
            passed to ``load_dygraph_pretrain`` (presumably a local weights
            path).
        model: network that receives the weights.
        model_url: URL used when ``pretrained`` is ``True``.
        use_ssld: forwarded to ``load_dygraph_pretrain_from_url``.

    Raises:
        RuntimeError: for any other ``pretrained`` value.
    """
    if pretrained is False:
        return
    if pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
        return
    if not isinstance(pretrained, str):
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )
    load_dygraph_pretrain(model, pretrained)


def ViT_small_patch16_224(pretrained=False, use_ssld=False, **kwargs):
    """Build the ``ViT_small_patch16_224`` model (224x224 input by default).

    Configuration: 16x16 patches, 768-dim tokens, 8 blocks, 8 heads, MLP
    ratio 3, and an explicit attention scale of ``768 ** -0.5``.

    Args:
        pretrained: ``False``, ``True`` (load from ``MODEL_URLS``), or a
            string; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra VisionTransformer arguments (e.g. ``class_num``).

    Returns:
        The constructed VisionTransformer.
    """
    arch = dict(
        patch_size=16,
        embed_dim=768,
        depth=8,
        num_heads=8,
        mlp_ratio=3,
        qk_scale=768**-0.5)
    model = VisionTransformer(**arch, **kwargs)
    url = MODEL_URLS["ViT_small_patch16_224"]
    _load_pretrained(pretrained, model, url, use_ssld=use_ssld)
    return model


def ViT_base_patch16_224(pretrained=False, use_ssld=False, **kwargs):
    """Build the ``ViT_base_patch16_224`` model (224x224 input by default).

    Configuration: 16x16 patches, 768-dim tokens, 12 blocks, 12 heads, MLP
    ratio 4, q/k/v bias, norm epsilon 1e-6.

    Args:
        pretrained: ``False``, ``True`` (load from ``MODEL_URLS``), or a
            string; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra VisionTransformer arguments (e.g. ``class_num``).

    Returns:
        The constructed VisionTransformer.
    """
    arch = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6)
    model = VisionTransformer(**arch, **kwargs)
    url = MODEL_URLS["ViT_base_patch16_224"]
    _load_pretrained(pretrained, model, url, use_ssld=use_ssld)
    return model


def ViT_base_patch16_384(pretrained=False, use_ssld=False, **kwargs):
    """Build the ``ViT_base_patch16_384`` model for 384x384 inputs.

    Configuration: 16x16 patches, 768-dim tokens, 12 blocks, 12 heads, MLP
    ratio 4, q/k/v bias, norm epsilon 1e-6.

    Args:
        pretrained: ``False``, ``True`` (load from ``MODEL_URLS``), or a
            string; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra VisionTransformer arguments (e.g. ``class_num``).

    Returns:
        The constructed VisionTransformer.
    """
    arch = dict(
        img_size=384,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6)
    model = VisionTransformer(**arch, **kwargs)
    url = MODEL_URLS["ViT_base_patch16_384"]
    _load_pretrained(pretrained, model, url, use_ssld=use_ssld)
    return model


def ViT_base_patch32_384(pretrained=False, use_ssld=False, **kwargs):
    """Build the ``ViT_base_patch32_384`` model for 384x384 inputs.

    Configuration: 32x32 patches, 768-dim tokens, 12 blocks, 12 heads, MLP
    ratio 4, q/k/v bias, norm epsilon 1e-6.

    Args:
        pretrained: ``False``, ``True`` (load from ``MODEL_URLS``), or a
            string; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra VisionTransformer arguments (e.g. ``class_num``).

    Returns:
        The constructed VisionTransformer.
    """
    arch = dict(
        img_size=384,
        patch_size=32,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6)
    model = VisionTransformer(**arch, **kwargs)
    url = MODEL_URLS["ViT_base_patch32_384"]
    _load_pretrained(pretrained, model, url, use_ssld=use_ssld)
    return model


def ViT_large_patch16_224(pretrained=False, use_ssld=False, **kwargs):
    """Build the ``ViT_large_patch16_224`` model (224x224 input by default).

    Configuration: 16x16 patches, 1024-dim tokens, 24 blocks, 16 heads, MLP
    ratio 4, q/k/v bias, norm epsilon 1e-6.

    Args:
        pretrained: ``False``, ``True`` (load from ``MODEL_URLS``), or a
            string; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra VisionTransformer arguments (e.g. ``class_num``).

    Returns:
        The constructed VisionTransformer.
    """
    arch = dict(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6)
    model = VisionTransformer(**arch, **kwargs)
    url = MODEL_URLS["ViT_large_patch16_224"]
    _load_pretrained(pretrained, model, url, use_ssld=use_ssld)
    return model


def ViT_large_patch16_384(pretrained=False, use_ssld=False, **kwargs):
    """Build the ``ViT_large_patch16_384`` model for 384x384 inputs.

    Configuration: 16x16 patches, 1024-dim tokens, 24 blocks, 16 heads, MLP
    ratio 4, q/k/v bias, norm epsilon 1e-6.

    Args:
        pretrained: ``False``, ``True`` (load from ``MODEL_URLS``), or a
            string; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra VisionTransformer arguments (e.g. ``class_num``).

    Returns:
        The constructed VisionTransformer.
    """
    arch = dict(
        img_size=384,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6)
    model = VisionTransformer(**arch, **kwargs)
    url = MODEL_URLS["ViT_large_patch16_384"]
    _load_pretrained(pretrained, model, url, use_ssld=use_ssld)
    return model


def ViT_large_patch32_384(pretrained=False, use_ssld=False, **kwargs):
    """Build the ``ViT_large_patch32_384`` model for 384x384 inputs.

    Configuration: 32x32 patches, 1024-dim tokens, 24 blocks, 16 heads, MLP
    ratio 4, q/k/v bias, norm epsilon 1e-6.

    Args:
        pretrained: ``False``, ``True`` (load from ``MODEL_URLS``), or a
            string; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra VisionTransformer arguments (e.g. ``class_num``).

    Returns:
        The constructed VisionTransformer.
    """
    arch = dict(
        img_size=384,
        patch_size=32,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6)
    model = VisionTransformer(**arch, **kwargs)
    url = MODEL_URLS["ViT_large_patch32_384"]
    _load_pretrained(pretrained, model, url, use_ssld=use_ssld)
    return model