tnt.py 12.4 KB
Newer Older
jm_12138's avatar
jm_12138 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16 17 18 19 20 21 22 23 24 25 26 27
import math
import numpy as np

import paddle
import paddle.nn as nn

from paddle.nn.initializer import TruncatedNormal, Constant

from ppcls.arch.backbone.base.theseus_layer import Identity
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

MODEL_URLS = {
    "TNT_small":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/TNT_small_pretrained.pdparams"
}

# `from <module> import *` indexes __all__ as a sequence of names, so it must
# be a real list; a dict_keys view is not indexable and makes that import fail.
__all__ = list(MODEL_URLS.keys())

33 34 35 36 37 38 39 40 41 42 43 44 45 46
# Shared parameter initializers reused by the layers below:
# truncated normal (std 0.02) for weights, constant 0 / constant 1 otherwise.
trunc_normal_ = TruncatedNormal(std=0.02)
zeros_ = Constant(value=0.0)
ones_ = Constant(value=1.0)


def drop_path(x, drop_prob=0., training=False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    The original name is misleading as 'Drop Connect' is a different form of
    dropout in a separate paper. See discussion:
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956

    Each index along the first axis of ``x`` is either zeroed or scaled by
    ``1 / (1 - drop_prob)``, so the expected value of the output equals ``x``.

    Args:
        x: input tensor; one keep/drop decision is drawn per index of axis 0.
        drop_prob: probability of dropping a sample. ``None`` or ``0``
            disables dropping.
        training: the random mask is applied only when this is true.

    Returns:
        ``x`` itself when dropping is disabled, otherwise a new tensor with
        the per-sample mask and rescaling applied.
    """
    if drop_prob is None or drop_prob == 0. or not training:
        return x
    # Create keep_prob in x's dtype so the add/divide below never mix dtypes
    # (e.g. float16 activations, or a numpy float64 drop_prob).
    keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
    # One random value per sample, broadcast over all remaining axes.
    shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
    random_tensor = paddle.add(keep_prob, paddle.rand(shape, dtype=x.dtype))
    random_tensor = paddle.floor(random_tensor)  # binarize: 1 with prob keep_prob
    output = x.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    """Layer wrapper around :func:`drop_path` (per-sample stochastic depth).

    Args:
        drop_prob: probability of dropping a sample's residual branch. It is
            forwarded to ``drop_path`` together with this layer's
            ``training`` flag, so nothing is dropped in eval mode.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        prob, is_training = self.drop_prob, self.training
        return drop_path(x, prob, is_training)


class Mlp(nn.Layer):
    """Two-layer feed-forward network.

    Computes Linear -> activation -> Dropout -> Linear -> Dropout. The same
    Dropout layer is applied after both stages.

    Args:
        in_features: number of input features.
        hidden_features: width of the hidden layer; defaults to ``in_features``.
        out_features: number of output features; defaults to ``in_features``.
        act_layer: activation layer class, instantiated with no arguments.
        drop: dropout probability.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Layer):
    """Multi-head self-attention with separate query/key and value widths.

    Queries and keys are projected to ``hidden_dim`` channels and split into
    ``num_heads`` heads. Values and the output projection keep ``dim``
    channels.

    Args:
        dim: input/output channel count (also the value width).
        hidden_dim: total query/key width across all heads.
        num_heads: number of attention heads.
        qkv_bias: whether the q/k and v projections carry a bias.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
    """

    def __init__(self,
                 dim,
                 hidden_dim,
                 num_heads=8,
                 qkv_bias=False,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        # Standard 1/sqrt(d_k) scaling of the dot-product scores.
        self.scale = self.head_dim**-0.5

        self.qk = nn.Linear(dim, hidden_dim * 2, bias_attr=qkv_bias)
        self.v = nn.Linear(dim, dim, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, seq_len, channels = x.shape
        heads = self.num_heads

        # (B, N, 2 * hidden_dim) -> (2, B, heads, N, head_dim)
        qk = self.qk(x).reshape((batch, seq_len, 2, heads, self.head_dim))
        qk = qk.transpose((2, 0, 3, 1, 4))
        q, k = qk[0], qk[1]

        # (B, N, dim) -> (B, heads, N, dim // heads)
        v = self.v(x).reshape((batch, seq_len, heads, channels // heads))
        v = v.transpose((0, 2, 1, 3))

        scores = paddle.matmul(q, k.transpose((0, 1, 3, 2))) * self.scale
        weights = self.attn_drop(nn.functional.softmax(scores, axis=-1))

        out = paddle.matmul(weights, v)
        # Merge heads back: (B, heads, N, d) -> (B, N, heads * d).
        merged_width = out.shape[-1] * out.shape[-3]
        out = out.transpose((0, 2, 1, 3)).reshape(
            (batch, seq_len, merged_width))
        return self.proj_drop(self.proj(out))


class Block(nn.Layer):
    """One TNT block: an inner transformer over pixel tokens and an outer
    transformer over patch tokens.

    The updated pixel tokens are flattened per patch, projected to ``dim``,
    and added to every patch token except index 0. In :class:`TNT` that
    token is the prepended class token.

    Args:
        dim: width of the patch (outer) tokens.
        in_dim: width of the pixel (inner) tokens.
        num_pixel: number of pixel tokens per patch.
        num_heads: heads in the outer attention.
        in_num_head: heads in the inner attention.
        mlp_ratio: hidden-width multiplier for the outer MLP.
        qkv_bias: whether the attention projections carry a bias.
        drop: dropout probability for attention output and MLPs.
        attn_drop: dropout probability on attention weights.
        drop_path: stochastic-depth rate; ``<= 0`` uses an identity layer.
        act_layer: activation layer class for the MLPs.
        norm_layer: normalization layer class.
    """

    def __init__(self,
                 dim,
                 in_dim,
                 num_pixel,
                 num_heads=12,
                 in_num_head=4,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        # Inner transformer
        self.norm_in = norm_layer(in_dim)
        self.attn_in = Attention(
            in_dim,
            in_dim,
            num_heads=in_num_head,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop)

        self.norm_mlp_in = norm_layer(in_dim)
        self.mlp_in = Mlp(in_features=in_dim,
                          hidden_features=int(in_dim * 4),
                          out_features=in_dim,
                          act_layer=act_layer,
                          drop=drop)

        self.norm1_proj = norm_layer(in_dim)
        self.proj = nn.Linear(in_dim * num_pixel, dim)
        # Outer transformer
        self.norm_out = norm_layer(dim)
        self.attn_out = Attention(
            dim,
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()

        self.norm_mlp = norm_layer(dim)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       out_features=dim,
                       act_layer=act_layer,
                       drop=drop)

    def forward(self, pixel_embed, patch_embed):
        # inner
        pixel_embed = paddle.add(pixel_embed, self.drop_path(
            self.attn_in(self.norm_in(pixel_embed))))
        pixel_embed = paddle.add(pixel_embed, self.drop_path(
            self.mlp_in(self.norm_mlp_in(pixel_embed))))
        # outer
        B, N, _ = patch_embed.shape
        # Flatten the pixel tokens of each of the N - 1 patches and project
        # them to the patch width; the reshape expects pixel_embed to hold
        # B * (N - 1) patches.
        pixel_to_patch = self.proj(
            self.norm1_proj(pixel_embed).reshape((B, N - 1, -1)))
        # Rebuild patch_embed instead of assigning into patch_embed[:, 1:]
        # in place, so the caller's tensor is never mutated. Token 0 passes
        # through unchanged, as before.
        patch_embed = paddle.concat(
            [
                patch_embed[:, :1],
                paddle.add(patch_embed[:, 1:], pixel_to_patch)
            ],
            axis=1)
        patch_embed = paddle.add(patch_embed, self.drop_path(
            self.attn_out(self.norm_out(patch_embed))))
        patch_embed = paddle.add(patch_embed, self.drop_path(
            self.mlp(self.norm_mlp(patch_embed))))
        return pixel_embed, patch_embed


class PixelEmbed(nn.Layer):
    """Embed an image into pixel-level tokens for TNT's inner transformer.

    A 7x7 convolution with stride ``stride`` maps the image to ``in_dim``
    channels. The feature map is then cut into non-overlapping
    ``new_patch_size`` x ``new_patch_size`` windows. ``pixel_pos`` is added
    to every window, and each window is returned as a sequence of
    ``new_patch_size ** 2`` tokens of width ``in_dim``.

    Args:
        img_size: expected input height and width (checked in ``forward``).
        patch_size: outer patch side length; sets ``num_patches``.
        in_chans: number of input image channels.
        in_dim: channel width of the pixel tokens.
        stride: stride of the initial convolution.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 in_dim=48,
                 stride=4):
        super().__init__()
        self.img_size = img_size
        self.num_patches = (img_size // patch_size)**2
        self.in_dim = in_dim
        # Window size on the strided feature map that covers one patch.
        self.new_patch_size = math.ceil(patch_size / stride)

        self.proj = nn.Conv2D(
            in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride)

    def forward(self, x, pixel_pos):
        _, _, H, W = x.shape
        assert H == self.img_size and W == self.img_size, f"Input image size ({H}*{W}) doesn't match model ({self.img_size}*{self.img_size})."

        win = self.new_patch_size
        features = self.proj(x)
        # (B, in_dim * win * win, L) with kernel == stride == win.
        columns = nn.functional.unfold(features, win, win)
        windows = columns.transpose((0, 2, 1)).reshape(
            (-1, self.in_dim, win, win))
        windows = windows + pixel_pos
        tokens = windows.reshape((-1, self.in_dim, win * win))
        # (B * L, win * win, in_dim)
        return tokens.transpose((0, 2, 1))


class TNT(nn.Layer):
    """Transformer in Transformer (TNT) backbone with an optional linear head.

    :class:`PixelEmbed` turns the image into pixel tokens. Each patch's pixel
    tokens are also flattened and projected into a patch token, a class
    token is prepended, and ``depth`` :class:`Block` layers refine both token
    streams. The normalized class token is passed through ``self.head`` when
    ``class_num > 0``; otherwise it is returned as the feature vector.

    Args:
        img_size: input height/width (PixelEmbed asserts the input matches).
        patch_size: side length of each outer patch.
        in_chans: number of input image channels.
        embed_dim: width of the patch (outer) tokens.
        in_dim: width of the pixel (inner) tokens.
        depth: number of TNT blocks.
        num_heads: attention heads in the outer transformer.
        in_num_head: attention heads in the inner transformer.
        mlp_ratio: hidden-width multiplier for the outer MLP.
        qkv_bias: whether attention projections carry a bias.
        drop_rate: dropout on the embeddings and inside each block.
        attn_drop_rate: dropout on attention weights.
        drop_path_rate: stochastic-depth rate of the last block; rates grow
            linearly from 0 across the blocks.
        norm_layer: normalization layer class.
        first_stride: stride of PixelEmbed's initial convolution.
        class_num: output size of the head; ``<= 0`` disables the head.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 embed_dim=768,
                 in_dim=48,
                 depth=12,
                 num_heads=12,
                 in_num_head=4,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 first_stride=4,
                 class_num=1000):
        super().__init__()
        self.class_num = class_num
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim

        self.pixel_embed = PixelEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            in_dim=in_dim,
            stride=first_stride)
        num_patches = self.pixel_embed.num_patches
        self.num_patches = num_patches
        new_patch_size = self.pixel_embed.new_patch_size
        num_pixel = new_patch_size**2

        self.norm1_proj = norm_layer(num_pixel * in_dim)
        self.proj = nn.Linear(num_pixel * in_dim, embed_dim)
        self.norm2_proj = norm_layer(embed_dim)

        self.cls_token = self.create_parameter(
            shape=(1, 1, embed_dim), default_initializer=zeros_)
        self.add_parameter("cls_token", self.cls_token)

        self.patch_pos = self.create_parameter(
            shape=(1, num_patches + 1, embed_dim), default_initializer=zeros_)
        self.add_parameter("patch_pos", self.patch_pos)

        self.pixel_pos = self.create_parameter(
            shape=(1, in_dim, new_patch_size, new_patch_size),
            default_initializer=zeros_)
        self.add_parameter("pixel_pos", self.pixel_pos)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic depth decay rule. Convert to plain Python floats so numpy
        # float64 scalars do not leak into DropPath, where they could make
        # paddle.to_tensor infer float64 instead of the default dtype.
        dpr = [float(rate) for rate in np.linspace(0, drop_path_rate, depth)]

        self.blocks = nn.LayerList([
            Block(
                dim=embed_dim,
                in_dim=in_dim,
                num_pixel=num_pixel,
                num_heads=num_heads,
                in_num_head=in_num_head,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=rate,
                norm_layer=norm_layer) for rate in dpr
        ])
        self.norm = norm_layer(embed_dim)

        if class_num > 0:
            self.head = nn.Linear(embed_dim, class_num)

        trunc_normal_(self.cls_token)
        trunc_normal_(self.patch_pos)
        trunc_normal_(self.pixel_pos)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize Linear (truncated-normal weight, zero bias) and
        LayerNorm (unit weight, zero bias) sublayers; absent params are skipped."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                zeros_(m.bias)
            if m.weight is not None:
                ones_(m.weight)

    def forward_features(self, x):
        """Return the normalized class-token features for the image batch ``x``."""
        B = x.shape[0]
        pixel_embed = self.pixel_embed(x, self.pixel_pos)

        # Flatten every patch's pixel tokens into one vector and project it to
        # a patch token; expected shape afterwards: (B, num_patches, embed_dim).
        patch_embed = self.norm2_proj(
            self.proj(
                self.norm1_proj(
                    pixel_embed.reshape((-1, self.num_patches,
                                         pixel_embed.shape[-1] *
                                         pixel_embed.shape[-2])))))
        patch_embed = paddle.concat(
            (self.cls_token.expand((B, -1, -1)), patch_embed), axis=1)
        patch_embed = patch_embed + self.patch_pos
        patch_embed = self.pos_drop(patch_embed)

        for blk in self.blocks:
            pixel_embed, patch_embed = blk(pixel_embed, patch_embed)

        patch_embed = self.norm(patch_embed)
        return patch_embed[:, 0]

    def forward(self, x):
        x = self.forward_features(x)

        if self.class_num > 0:
            x = self.head(x)
        return x


def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    if pretrained is False:
        pass
    elif pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
    elif isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )


def TNT_small(pretrained=False, **kwargs):
    """Build the small TNT configuration (embed_dim 384, in_dim 24, depth 12).

    Args:
        pretrained: ``False``, ``True`` (load weights from
            ``MODEL_URLS["TNT_small"]``), or a str handed to the local loader.
        **kwargs: extra keyword arguments forwarded to :class:`TNT`
            (e.g. ``class_num``). Repeating one of the fixed settings below
            raises TypeError.

    Returns:
        The constructed TNT model.
    """
    small_config = dict(
        patch_size=16,
        embed_dim=384,
        in_dim=24,
        depth=12,
        num_heads=6,
        in_num_head=4,
        qkv_bias=False)
    model = TNT(**small_config, **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["TNT_small"])
    return model