# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.regularizer import L2Decay

from .vision_transformer import trunc_normal_, normal_, zeros_, ones_, to_2tuple, DropPath, Identity, Mlp
from .vision_transformer import Block as ViTBlock

from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

# Pretrained checkpoint location for each model factory defined in this file.
MODEL_URLS = {
    "pcpvt_small": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/pcpvt_small_pretrained.pdparams",
    "pcpvt_base": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/pcpvt_base_pretrained.pdparams",
    "pcpvt_large": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/pcpvt_large_pretrained.pdparams",
    "alt_gvt_small": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/alt_gvt_small_pretrained.pdparams",
    "alt_gvt_base": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/alt_gvt_base_pretrained.pdparams",
    "alt_gvt_large": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/alt_gvt_large_pretrained.pdparams",
}

# The public names are exactly the model names that have a checkpoint URL.
__all__ = list(MODEL_URLS)


class GroupAttention(nn.Layer):
    """LSA: self attention within a group.

    The token grid is cut into non-overlapping ``ws x ws`` windows and
    multi-head attention is computed independently inside each window.

    Args:
        dim: Channel size of the input tokens.
        num_heads: Number of attention heads; must divide ``dim``.
        qkv_bias: Passed as ``bias_attr`` to the fused q/k/v projection.
        qk_scale: Attention scale; ``head_dim ** -0.5`` is used when falsy.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
        ws: Window side length; must not be 1.

    Raises:
        Exception: If ``ws`` is 1 or ``dim`` is not divisible by
            ``num_heads``.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 ws=1):
        super().__init__()
        # The messages must be f-strings, otherwise the placeholders are
        # printed literally instead of the offending values.
        if ws == 1:
            raise Exception(f"ws {ws} should not be 1")
        if dim % num_heads != 0:
            raise Exception(
                f"dim {dim} should be divided by num_heads {num_heads}.")

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.ws = ws

    def forward(self, x, H, W):
        """Apply windowed attention to ``x`` of shape ``[B, N, C]``.

        ``H`` and ``W`` describe the token grid. The reshapes below require
        ``N == H * W`` with both ``H`` and ``W`` multiples of ``ws``.
        Returns a tensor of shape ``[B, N, C]``.
        """
        B, N, C = x.shape
        h_group, w_group = H // self.ws, W // self.ws
        total_groups = h_group * w_group

        # [B, N, C] -> [B, h_group, w_group, ws, ws, C]
        x = x.reshape([B, h_group, self.ws, w_group, self.ws, C]).transpose(
            [0, 1, 3, 2, 4, 5])

        # -> [3, B, total_groups, num_heads, ws*ws, head_dim]
        qkv = self.qkv(x).reshape([
            B, total_groups, self.ws**2, 3, self.num_heads, C // self.num_heads
        ]).transpose([3, 0, 1, 4, 2, 5])
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = paddle.matmul(q, k.transpose([0, 1, 2, 4, 3])) * self.scale
        # Functional softmax avoids building a new nn.Softmax layer per call.
        attn = F.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)

        # Merge the heads and restore the [h_group, w_group, ws, ws] layout.
        attn = paddle.matmul(attn, v).transpose([0, 1, 3, 2, 4]).reshape(
            [B, h_group, w_group, self.ws, self.ws, C])

        # Undo the window partition: back to [B, N, C].
        x = attn.transpose([0, 1, 3, 2, 4, 5]).reshape([B, N, C])
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Attention(nn.Layer):
    """GSA: using a key to summarize the information for a group to be efficient.

    Queries come from every token. When ``sr_ratio > 1`` the keys and values
    come from a grid reduced by a strided convolution followed by LayerNorm.

    Args:
        dim: Channel size of the input tokens.
        num_heads: Number of attention heads; must divide ``dim``.
        qkv_bias: Passed as ``bias_attr`` to the q and kv projections.
        qk_scale: Attention scale; ``head_dim ** -0.5`` is used when falsy.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
        sr_ratio: Kernel size and stride of the key/value reduction conv;
            no reduction layer is created when it is 1 or less.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.q = nn.Linear(dim, dim, bias_attr=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2D(
                dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        """Attend over ``x`` of shape ``[B, N, C]`` and return ``[B, N, C]``.

        ``H`` and ``W`` are only used when ``sr_ratio > 1``, where the tokens
        are reshaped to a ``[B, C, H, W]`` map (so ``N`` must equal
        ``H * W``).
        """
        B, N, C = x.shape
        head_dim = C // self.num_heads
        q = self.q(x).reshape([B, N, self.num_heads, head_dim]).transpose(
            [0, 2, 1, 3])

        if self.sr_ratio > 1:
            x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])
            # An unpadded conv with kernel == stride produces floor(H / sr)
            # by floor(W / sr) positions. Flooring each axis separately keeps
            # the reshape valid when H or W is not a multiple of sr_ratio;
            # the value is unchanged when both are.
            tmp_n = (H // self.sr_ratio) * (W // self.sr_ratio)
            x_ = self.sr(x_).reshape([B, C, tmp_n]).transpose([0, 2, 1])
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(
                [B, tmp_n, 2, self.num_heads, head_dim]).transpose(
                    [2, 0, 3, 1, 4])
        else:
            kv = self.kv(x).reshape(
                [B, N, 2, self.num_heads, head_dim]).transpose(
                    [2, 0, 3, 1, 4])
        k, v = kv[0], kv[1]

        attn = paddle.matmul(q, k.transpose([0, 1, 3, 2])) * self.scale
        # Functional softmax avoids building a new nn.Softmax layer per call.
        attn = F.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)

        x = paddle.matmul(attn, v).transpose([0, 2, 1, 3]).reshape([B, N, C])
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Layer):
    """Transformer block: normalised attention and MLP, each with a residual.

    Args:
        dim: Channel size of the tokens.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias, qk_scale, attn_drop, sr_ratio: Forwarded to ``Attention``.
        drop: Dropout probability for the attention projection and the MLP.
        drop_path: Probability handed to ``DropPath``; ``Identity`` is used
            when it is not greater than 0.
        act_layer: Activation class handed to ``Mlp``.
        norm_layer: Callable building a normalisation layer from ``dim``.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 sr_ratio=1):
        super().__init__()
        hidden_width = int(dim * mlp_ratio)

        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=hidden_width,
                       act_layer=act_layer,
                       drop=drop)

    def forward(self, x, H, W):
        """Run attention then the MLP on ``x``; ``H``/``W`` go to the attention."""
        attended = self.attn(self.norm1(x), H, W)
        x = x + self.drop_path(attended)
        transformed = self.mlp(self.norm2(x))
        return x + self.drop_path(transformed)


class SBlock(ViTBlock):
    """ViT block exposing the ``(x, H, W)`` call signature used by this file.

    ``sr_ratio`` is accepted but not forwarded to the parent constructor, and
    ``H``/``W`` are ignored in ``forward``.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 sr_ratio=1):
        # Positional order is the one the parent block is called with.
        parent_args = (dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
                       attn_drop, drop_path, act_layer, norm_layer)
        super().__init__(*parent_args)

    def forward(self, x, H, W):
        """Delegate to the parent block, dropping the grid size arguments."""
        out = super().forward(x)
        return out


class GroupBlock(ViTBlock):
    """ViT block whose attention is swapped for ``Attention`` or ``GroupAttention``.

    With ``ws == 1`` the block uses ``Attention`` (with ``sr_ratio``);
    otherwise it uses ``GroupAttention`` with window size ``ws``. The
    remaining arguments are passed positionally to the parent constructor.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 sr_ratio=1,
                 ws=1):
        parent_args = (dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
                       attn_drop, drop_path, act_layer, norm_layer)
        super().__init__(*parent_args)

        # Discard the attention built by the parent before installing ours.
        del self.attn
        if ws != 1:
            self.attn = GroupAttention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
                ws=ws)
        else:
            self.attn = Attention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
                sr_ratio=sr_ratio)

    def forward(self, x, H, W):
        """Run attention then the MLP on ``x``, each with a residual."""
        attended = self.attn(self.norm1(x), H, W)
        x = x + self.drop_path(attended)
        transformed = self.mlp(self.norm2(x))
        return x + self.drop_path(transformed)


class PatchEmbed(nn.Layer):
    """Image to Patch Embedding.

    A convolution with kernel and stride equal to ``patch_size`` projects the
    input to ``embed_dim`` channels; the result is flattened into tokens and
    layer-normalised.

    Raises:
        Exception: If ``img_size`` is not divisible by ``patch_size``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        if img_size % patch_size != 0:
            raise Exception(
                f"img_size {img_size} should be divided by patch_size {patch_size}."
            )

        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)

        # Nominal grid size for the configured image size.
        self.H = self.img_size[0] // self.patch_size[0]
        self.W = self.img_size[1] // self.patch_size[1]
        self.num_patches = self.H * self.W

        self.proj = nn.Conv2D(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Return ``(tokens, (H, W))`` for a 4-D input ``x``.

        The grid size is computed from the actual input height and width,
        not from the ``img_size`` given at construction.
        """
        _, _, height, width = x.shape
        tokens = self.proj(x).flatten(2).transpose([0, 2, 1])
        tokens = self.norm(tokens)
        grid = (height // self.patch_size[0], width // self.patch_size[1])
        return tokens, grid


# Borrowed from PVT: https://github.com/whai362/PVT.git
class PyramidVisionTransformer(nn.Layer):
    """Multi-stage transformer backbone with a classification head.

    Each stage has a ``PatchEmbed``, a learned position embedding, a dropout
    and ``depths[k]`` blocks of type ``block_cls``. A class token is
    concatenated in the last stage and its normalised output feeds the head.

    Args:
        img_size: Input image side length.
        patch_size: Patch size of the first stage; later stages use 2.
        in_chans: Number of input image channels.
        class_num: Output classes; the head is ``Identity`` when not > 0.
        embed_dims: Channel size per stage.
        num_heads: Attention heads per stage.
        mlp_ratios: MLP expansion ratio per stage.
        qkv_bias, qk_scale: Forwarded to every block.
        drop_rate: Dropout probability for position dropout and blocks.
        attn_drop_rate: Attention dropout probability for blocks.
        drop_path_rate: Final value of the linearly increasing drop-path
            schedule over all blocks.
        norm_layer: Callable building a normalisation layer from a width.
        depths: Number of blocks per stage.
        sr_ratios: Spatial-reduction ratio per stage, forwarded to blocks.
        block_cls: Block class instantiated for every position.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 class_num=1000,
                 embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8],
                 mlp_ratios=[4, 4, 4, 4],
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3],
                 sr_ratios=[8, 4, 2, 1],
                 block_cls=Block):
        super().__init__()
        self.class_num = class_num
        self.depths = depths

        # patch_embed
        self.patch_embeds = nn.LayerList()
        self.pos_embeds = nn.ParameterList()
        self.pos_drops = nn.LayerList()
        self.blocks = nn.LayerList()

        for i in range(len(depths)):
            if i == 0:
                self.patch_embeds.append(
                    PatchEmbed(img_size, patch_size, in_chans, embed_dims[i]))
            else:
                self.patch_embeds.append(
                    PatchEmbed(img_size // patch_size // 2**(i - 1), 2,
                               embed_dims[i - 1], embed_dims[i]))
            # The last stage carries one extra position for the class token.
            patch_num = self.patch_embeds[i].num_patches + 1 if i == len(
                embed_dims) - 1 else self.patch_embeds[i].num_patches
            self.pos_embeds.append(
                self.create_parameter(
                    shape=[1, patch_num, embed_dims[i]],
                    default_initializer=zeros_))
            self.pos_drops.append(nn.Dropout(p=drop_rate))

        # Stochastic depth decay rule. Converting the whole tensor at once
        # avoids indexing each element with ``[0]``, which fails when
        # iterating a tensor yields 0-D elements.
        dpr = paddle.linspace(0, drop_path_rate, sum(depths)).numpy().tolist()

        cur = 0
        for k in range(len(depths)):
            _block = nn.LayerList([
                block_cls(
                    dim=embed_dims[k],
                    num_heads=num_heads[k],
                    mlp_ratio=mlp_ratios[k],
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[cur + i],
                    norm_layer=norm_layer,
                    sr_ratio=sr_ratios[k]) for i in range(depths[k])
            ])
            self.blocks.append(_block)
            cur += depths[k]

        self.norm = norm_layer(embed_dims[-1])

        # cls_token
        self.cls_token = self.create_parameter(
            shape=[1, 1, embed_dims[-1]],
            default_initializer=zeros_,
            attr=paddle.ParamAttr(regularizer=L2Decay(0.0)))

        # classification head
        self.head = nn.Linear(embed_dims[-1],
                              class_num) if class_num > 0 else Identity()

        # init weights
        for pos_emb in self.pos_embeds:
            trunc_normal_(pos_emb)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear and LayerNorm sublayers; leave others untouched."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def forward_features(self, x):
        """Run all stages and return the normalised class-token feature."""
        B = x.shape[0]
        last_stage = len(self.depths) - 1
        for i in range(len(self.depths)):
            x, (H, W) = self.patch_embeds[i](x)
            if i == last_stage:
                cls_tokens = self.cls_token.expand([B, -1, -1])
                # paddle.concat takes ``axis``; ``dim`` is not a valid keyword.
                x = paddle.concat([cls_tokens, x], axis=1)
            x = x + self.pos_embeds[i]
            x = self.pos_drops[i](x)
            for blk in self.blocks[i]:
                x = blk(x, H, W)
            if i < last_stage:
                # Back to a [B, C, H, W] map for the next patch embedding.
                x = x.reshape([B, H, W, -1]).transpose([0, 3, 1, 2])
        x = self.norm(x)
        return x[:, 0]

    def forward(self, x):
        """Return the head output for the input image batch ``x``."""
        x = self.forward_features(x)
        x = self.head(x)
        return x


# PEG (positional encoding generator), from https://arxiv.org/abs/2102.10882
class PosCNN(nn.Layer):
    """Positional encoding from a 3x3 grouped convolution over the token grid.

    Args:
        in_chans: Input channels of the convolution.
        embed_dim: Output channels, also used as the number of groups.
        s: Convolution stride; the input is added back only when it is 1.
    """

    def __init__(self, in_chans, embed_dim=768, s=1):
        super().__init__()
        # Both the bias and the weight are excluded from L2 weight decay.
        grouped_conv = nn.Conv2D(
            in_chans,
            embed_dim,
            kernel_size=3,
            stride=s,
            padding=1,
            bias_attr=paddle.ParamAttr(regularizer=L2Decay(0.0)),
            groups=embed_dim,
            weight_attr=paddle.ParamAttr(regularizer=L2Decay(0.0)))
        self.proj = nn.Sequential(grouped_conv)
        self.s = s

    def forward(self, x, H, W):
        """Map tokens ``[B, N, C]`` through the convolution and back to tokens."""
        B, N, C = x.shape
        cnn_feat = x.transpose([0, 2, 1]).reshape([B, C, H, W])
        encoded = self.proj(cnn_feat)
        if self.s == 1:
            # Residual connection is only shape-compatible without striding.
            encoded = encoded + cnn_feat
        return encoded.flatten(2).transpose([0, 2, 1])


class CPVTV2(PyramidVisionTransformer):
    """
    Use useful results from CPVT: PEG and GAP.
    Therefore, the cls token is no longer required.
    PEG encodes position on the fly, which matters when the input resolution
    changes during training (such as segmentation, detection).

    The learned position embeddings and class token created by the parent
    constructor are deleted. One ``PosCNN`` per stage is applied after the
    first block of that stage, and features are averaged over tokens.
    Constructor arguments are passed unchanged to
    ``PyramidVisionTransformer``.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=4,
                 in_chans=3,
                 class_num=1000,
                 embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8],
                 mlp_ratios=[4, 4, 4, 4],
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3],
                 sr_ratios=[8, 4, 2, 1],
                 block_cls=Block):
        super().__init__(img_size, patch_size, in_chans, class_num, embed_dims,
                         num_heads, mlp_ratios, qkv_bias, qk_scale, drop_rate,
                         attn_drop_rate, drop_path_rate, norm_layer, depths,
                         sr_ratios, block_cls)
        del self.pos_embeds
        del self.cls_token
        self.pos_block = nn.LayerList(
            [PosCNN(embed_dim, embed_dim) for embed_dim in embed_dims])
        # Re-run initialisation so the newly added PosCNN layers are covered.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear, LayerNorm, Conv2D and BatchNorm2D sublayers."""
        import math
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2D)):
            # BatchNorm2D uses the same initialisers as LayerNorm rather than
            # torch-style ``m.weight.data.fill_`` / ``m.bias.data.zero_``.
            zeros_(m.bias)
            ones_(m.weight)
        elif isinstance(m, nn.Conv2D):
            # Normal init with std derived from the per-group fan-out.
            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
            fan_out //= m._groups
            normal_(0, math.sqrt(2.0 / fan_out))(m.weight)
            if m.bias is not None:
                zeros_(m.bias)

    def forward_features(self, x):
        """Run all stages and return the token-averaged, normalised feature."""
        B = x.shape[0]
        last_stage = len(self.depths) - 1

        for i in range(len(self.depths)):
            x, (H, W) = self.patch_embeds[i](x)
            x = self.pos_drops[i](x)

            for j, blk in enumerate(self.blocks[i]):
                x = blk(x, H, W)
                if j == 0:
                    x = self.pos_block[i](x, H, W)  # PEG here

            if i < last_stage:
                # Back to a [B, C, H, W] map for the next patch embedding.
                x = x.reshape([B, H, W, x.shape[-1]]).transpose([0, 3, 1, 2])

        x = self.norm(x)
        return x.mean(axis=1)  # GAP here


class PCPVT(CPVTV2):
    """``CPVTV2`` with three-stage defaults and ``SBlock`` as the block class.

    Only the default argument values differ from ``CPVTV2``; every argument
    is handed to the parent constructor unchanged.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=4,
                 in_chans=3,
                 class_num=1000,
                 embed_dims=[64, 128, 256],
                 num_heads=[1, 2, 4],
                 mlp_ratios=[4, 4, 4],
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 depths=[4, 4, 4],
                 sr_ratios=[4, 2, 1],
                 block_cls=SBlock):
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            class_num=class_num,
            embed_dims=embed_dims,
            num_heads=num_heads,
            mlp_ratios=mlp_ratios,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            depths=depths,
            sr_ratios=sr_ratios,
            block_cls=block_cls)


class ALTGVT(PCPVT):
    """
    alias Twins-SVT

    The blocks built by the parent constructor are discarded and rebuilt so
    that, inside every stage, even-indexed blocks receive ``ws=wss[k]`` and
    odd-indexed blocks receive ``ws=1``.

    Args:
        wss: Window size per stage; indexed with the same stage index as
            ``depths``. The remaining arguments are passed unchanged to
            ``PCPVT``.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=4,
                 in_chans=3,
                 class_num=1000,
                 embed_dims=[64, 128, 256],
                 num_heads=[1, 2, 4],
                 mlp_ratios=[4, 4, 4],
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 depths=[4, 4, 4],
                 sr_ratios=[4, 2, 1],
                 block_cls=GroupBlock,
                 wss=[7, 7, 7]):
        super().__init__(img_size, patch_size, in_chans, class_num, embed_dims,
                         num_heads, mlp_ratios, qkv_bias, qk_scale, drop_rate,
                         attn_drop_rate, drop_path_rate, norm_layer, depths,
                         sr_ratios, block_cls)
        del self.blocks
        self.wss = wss
        # transformer encoder
        # Stochastic depth decay rule. Converting the whole tensor at once
        # avoids indexing each element with ``[0]``, which fails when
        # iterating a tensor yields 0-D elements.
        dpr = paddle.linspace(0, drop_path_rate, sum(depths)).numpy().tolist()
        cur = 0
        self.blocks = nn.LayerList()
        for k in range(len(depths)):
            _block = nn.LayerList([
                block_cls(
                    dim=embed_dims[k],
                    num_heads=num_heads[k],
                    mlp_ratio=mlp_ratios[k],
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[cur + i],
                    norm_layer=norm_layer,
                    sr_ratio=sr_ratios[k],
                    ws=1 if i % 2 == 1 else wss[k]) for i in range(depths[k])
            ])
            self.blocks.append(_block)
            cur += depths[k]
        # Initialise the freshly created blocks.
        self.apply(self._init_weights)

def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    if pretrained is False:
        pass
    elif pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
    elif isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )

def pcpvt_small(pretrained=False, use_ssld=False, **kwargs):
    """Build the ``pcpvt_small`` model (a ``CPVTV2`` with depths [3, 4, 6, 3]).

    Args:
        pretrained: ``False``, ``True`` or a string, as handled by
            ``_load_pretrained``.
        use_ssld: Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments for the ``CPVTV2`` constructor.

    Returns:
        The constructed model.
    """
    layer_norm = partial(nn.LayerNorm, epsilon=1e-6)
    net = CPVTV2(
        patch_size=4,
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[8, 8, 4, 4],
        qkv_bias=True,
        norm_layer=layer_norm,
        depths=[3, 4, 6, 3],
        sr_ratios=[8, 4, 2, 1],
        **kwargs)
    checkpoint_url = MODEL_URLS["pcpvt_small"]
    _load_pretrained(pretrained, net, checkpoint_url, use_ssld=use_ssld)
    return net


def pcpvt_base(pretrained=False, use_ssld=False, **kwargs):
    """Build the ``pcpvt_base`` model (a ``CPVTV2`` with depths [3, 4, 18, 3]).

    Args:
        pretrained: ``False``, ``True`` or a string, as handled by
            ``_load_pretrained``.
        use_ssld: Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments for the ``CPVTV2`` constructor.

    Returns:
        The constructed model.
    """
    layer_norm = partial(nn.LayerNorm, epsilon=1e-6)
    net = CPVTV2(
        patch_size=4,
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[8, 8, 4, 4],
        qkv_bias=True,
        norm_layer=layer_norm,
        depths=[3, 4, 18, 3],
        sr_ratios=[8, 4, 2, 1],
        **kwargs)
    checkpoint_url = MODEL_URLS["pcpvt_base"]
    _load_pretrained(pretrained, net, checkpoint_url, use_ssld=use_ssld)
    return net


def pcpvt_large(pretrained=False, use_ssld=False, **kwargs):
    """Build the ``pcpvt_large`` model (a ``CPVTV2`` with depths [3, 8, 27, 3]).

    Args:
        pretrained: ``False``, ``True`` or a string, as handled by
            ``_load_pretrained``.
        use_ssld: Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments for the ``CPVTV2`` constructor.

    Returns:
        The constructed model.
    """
    layer_norm = partial(nn.LayerNorm, epsilon=1e-6)
    net = CPVTV2(
        patch_size=4,
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[8, 8, 4, 4],
        qkv_bias=True,
        norm_layer=layer_norm,
        depths=[3, 8, 27, 3],
        sr_ratios=[8, 4, 2, 1],
        **kwargs)
    checkpoint_url = MODEL_URLS["pcpvt_large"]
    _load_pretrained(pretrained, net, checkpoint_url, use_ssld=use_ssld)
    return net


def alt_gvt_small(pretrained=False, use_ssld=False, **kwargs):
    """Build the ``alt_gvt_small`` model (an ``ALTGVT`` with depths [2, 2, 10, 4]).

    Args:
        pretrained: ``False``, ``True`` or a string, as handled by
            ``_load_pretrained``.
        use_ssld: Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments for the ``ALTGVT`` constructor.

    Returns:
        The constructed model.
    """
    layer_norm = partial(nn.LayerNorm, epsilon=1e-6)
    net = ALTGVT(
        patch_size=4,
        embed_dims=[64, 128, 256, 512],
        num_heads=[2, 4, 8, 16],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True,
        norm_layer=layer_norm,
        depths=[2, 2, 10, 4],
        wss=[7, 7, 7, 7],
        sr_ratios=[8, 4, 2, 1],
        **kwargs)
    checkpoint_url = MODEL_URLS["alt_gvt_small"]
    _load_pretrained(pretrained, net, checkpoint_url, use_ssld=use_ssld)
    return net


def alt_gvt_base(pretrained=False, use_ssld=False, **kwargs):
    """Build the ``alt_gvt_base`` model (an ``ALTGVT`` with depths [2, 2, 18, 2]).

    Args:
        pretrained: ``False``, ``True`` or a string, as handled by
            ``_load_pretrained``.
        use_ssld: Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments for the ``ALTGVT`` constructor.

    Returns:
        The constructed model.
    """
    layer_norm = partial(nn.LayerNorm, epsilon=1e-6)
    net = ALTGVT(
        patch_size=4,
        embed_dims=[96, 192, 384, 768],
        num_heads=[3, 6, 12, 24],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True,
        norm_layer=layer_norm,
        depths=[2, 2, 18, 2],
        wss=[7, 7, 7, 7],
        sr_ratios=[8, 4, 2, 1],
        **kwargs)
    checkpoint_url = MODEL_URLS["alt_gvt_base"]
    _load_pretrained(pretrained, net, checkpoint_url, use_ssld=use_ssld)
    return net


def alt_gvt_large(pretrained=False, use_ssld=False, **kwargs):
    """Build the ``alt_gvt_large`` model (an ``ALTGVT`` with widths [128, 256, 512, 1024]).

    Args:
        pretrained: ``False``, ``True`` or a string, as handled by
            ``_load_pretrained``.
        use_ssld: Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments for the ``ALTGVT`` constructor.

    Returns:
        The constructed model.
    """
    layer_norm = partial(nn.LayerNorm, epsilon=1e-6)
    net = ALTGVT(
        patch_size=4,
        embed_dims=[128, 256, 512, 1024],
        num_heads=[4, 8, 16, 32],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True,
        norm_layer=layer_norm,
        depths=[2, 2, 18, 2],
        wss=[7, 7, 7, 7],
        sr_ratios=[8, 4, 2, 1],
        **kwargs)
    checkpoint_url = MODEL_URLS["alt_gvt_large"]
    _load_pretrained(pretrained, net, checkpoint_url, use_ssld=use_ssld)
    return net