vit_mae.py 26.3 KB
Newer Older
Z
Zhao-Yian 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
import math
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Constant, TruncatedNormal

from ppdet.modeling.shape_spec import ShapeSpec
from ppdet.core.workspace import register, serializable

from .transformer_utils import (zeros_, DropPath, Identity, window_partition,
                                window_unpartition)
from ..initializer import linear_init_

__all__ = ['VisionTransformer2D', 'SimpleFeaturePyramid']


class Mlp(nn.Layer):
    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer='nn.GELU',
                 drop=0.,
                 lr_factor=1.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(
            in_features,
            hidden_features,
            weight_attr=ParamAttr(learning_rate=lr_factor),
            bias_attr=ParamAttr(learning_rate=lr_factor))
        self.act = eval(act_layer)()
        self.fc2 = nn.Linear(
            hidden_features,
            out_features,
            weight_attr=ParamAttr(learning_rate=lr_factor),
            bias_attr=ParamAttr(learning_rate=lr_factor))
        self.drop = nn.Dropout(drop)

        self._init_weights()

    def _init_weights(self):
        linear_init_(self.fc1)
        linear_init_(self.fc2)

    def forward(self, x):
        x = self.drop(self.act(self.fc1(x)))
        x = self.drop(self.fc2(x))
        return x


class Attention(nn.Layer):
    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 attn_bias=False,
                 attn_drop=0.,
                 proj_drop=0.,
                 use_rel_pos=False,
                 rel_pos_zero_init=True,
                 window_size=None,
                 input_size=None,
                 qk_scale=None,
                 lr_factor=1.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim**-0.5
        self.use_rel_pos = use_rel_pos
        self.input_size = input_size
        self.rel_pos_zero_init = rel_pos_zero_init
        self.window_size = window_size
        self.lr_factor = lr_factor

        self.qkv = nn.Linear(
            dim,
            dim * 3,
            weight_attr=ParamAttr(learning_rate=lr_factor),
            bias_attr=ParamAttr(learning_rate=lr_factor)
            if attn_bias else False)
        if qkv_bias:
            self.q_bias = self.create_parameter(
                shape=([dim]), default_initializer=zeros_)
            self.v_bias = self.create_parameter(
                shape=([dim]), default_initializer=zeros_)
        else:
            self.q_bias = None
            self.v_bias = None
        self.proj = nn.Linear(
            dim,
            dim,
            weight_attr=ParamAttr(learning_rate=lr_factor),
            bias_attr=ParamAttr(learning_rate=lr_factor))
        self.attn_drop = nn.Dropout(attn_drop)
        if window_size is None:
            self.window_size = self.input_size[0]

        self._init_weights()

    def _init_weights(self):
        linear_init_(self.qkv)
        linear_init_(self.proj)

        if self.use_rel_pos:
            self.rel_pos_h = self.create_parameter(
                [2 * self.window_size - 1, self.head_dim],
                attr=ParamAttr(learning_rate=self.lr_factor),
                default_initializer=Constant(value=0.))
            self.rel_pos_w = self.create_parameter(
                [2 * self.window_size - 1, self.head_dim],
                attr=ParamAttr(learning_rate=self.lr_factor),
                default_initializer=Constant(value=0.))

            if not self.rel_pos_zero_init:
                TruncatedNormal(self.rel_pos_h, std=0.02)
                TruncatedNormal(self.rel_pos_w, std=0.02)

    def get_rel_pos(self, seq_size, rel_pos):
        max_rel_dist = int(2 * seq_size - 1)
        # Interpolate rel pos if needed.
        if rel_pos.shape[0] != max_rel_dist:
            # Interpolate rel pos.
            rel_pos = rel_pos.reshape([1, rel_pos.shape[0], -1])
            rel_pos = rel_pos.transpose([0, 2, 1])
            rel_pos_resized = F.interpolate(
                rel_pos,
                size=(max_rel_dist, ),
                mode="linear",
                data_format='NCW')
            rel_pos_resized = rel_pos_resized.reshape([-1, max_rel_dist])
            rel_pos_resized = rel_pos_resized.transpose([1, 0])
        else:
            rel_pos_resized = rel_pos

        coords = paddle.arange(seq_size, dtype='float32')
        relative_coords = coords.unsqueeze(-1) - coords.unsqueeze(0)
        relative_coords += (seq_size - 1)
        relative_coords = relative_coords.astype('int64').flatten()

        return paddle.index_select(rel_pos_resized, relative_coords).reshape(
            [seq_size, seq_size, self.head_dim])

    def add_decomposed_rel_pos(self, attn, q, h, w):
        """
        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
        Args:
            attn (Tensor): attention map.
            q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        Returns:
            attn (Tensor): attention map with added relative positional embeddings.
        """
        Rh = self.get_rel_pos(h, self.rel_pos_h)
        Rw = self.get_rel_pos(w, self.rel_pos_w)

        B, _, dim = q.shape
        r_q = q.reshape([B, h, w, dim])
        # bhwc, hch->bhwh1
        # bwhc, wcw->bhw1w
        rel_h = paddle.einsum("bhwc,hkc->bhwk", r_q, Rh).unsqueeze(-1)
        rel_w = paddle.einsum("bhwc,wkc->bhwk", r_q, Rw).unsqueeze(-2)

        attn = attn.reshape([B, h, w, h, w]) + rel_h + rel_w
        return attn.reshape([B, h * w, h * w])

    def forward(self, x):
        B, H, W, C = paddle.shape(x)

        if self.q_bias is not None:
            qkv_bias = paddle.concat(
                (self.q_bias, paddle.zeros_like(self.v_bias), self.v_bias))
            qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
        else:
            qkv = self.qkv(x).reshape(
                [B, H * W, 3, self.num_heads, self.head_dim]).transpose(
                    [2, 0, 3, 1, 4]).reshape(
                        [3, B * self.num_heads, H * W, self.head_dim])

        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = q.matmul(k.transpose([0, 2, 1])) * self.scale

        if self.use_rel_pos:
            attn = self.add_decomposed_rel_pos(attn, q, H, W)

        attn = F.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)
        x = attn.matmul(v).reshape(
            [B, self.num_heads, H * W, self.head_dim]).transpose(
                [0, 2, 1, 3]).reshape([B, H, W, C])
        x = self.proj(x)
        return x


class Block(nn.Layer):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 attn_bias=False,
                 qk_scale=None,
                 init_values=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 use_rel_pos=True,
                 rel_pos_zero_init=True,
                 window_size=None,
                 input_size=None,
                 act_layer='nn.GELU',
                 norm_layer='nn.LayerNorm',
                 lr_factor=1.0,
                 epsilon=1e-5):
        super().__init__()
        self.window_size = window_size

        self.norm1 = eval(norm_layer)(dim,
                                      weight_attr=ParamAttr(
                                          learning_rate=lr_factor,
                                          regularizer=L2Decay(0.0)),
                                      bias_attr=ParamAttr(
                                          learning_rate=lr_factor,
                                          regularizer=L2Decay(0.0)),
                                      epsilon=epsilon)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_bias=attn_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            window_size=window_size,
            input_size=input_size,
            lr_factor=lr_factor)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
        self.norm2 = eval(norm_layer)(dim,
                                      weight_attr=ParamAttr(
                                          learning_rate=lr_factor,
                                          regularizer=L2Decay(0.0)),
                                      bias_attr=ParamAttr(
                                          learning_rate=lr_factor,
                                          regularizer=L2Decay(0.0)),
                                      epsilon=epsilon)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer,
                       drop=drop,
                       lr_factor=lr_factor)
        if init_values is not None:
            self.gamma_1 = self.create_parameter(
                shape=([dim]), default_initializer=Constant(value=init_values))
            self.gamma_2 = self.create_parameter(
                shape=([dim]), default_initializer=Constant(value=init_values))
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x):
        y = self.norm1(x)
        if self.window_size is not None:
            y, pad_hw, num_hw = window_partition(y, self.window_size)
        y = self.attn(y)
        if self.gamma_1 is not None:
            y = self.gamma_1 * y

        if self.window_size is not None:
            y = window_unpartition(y, pad_hw, num_hw, (x.shape[1], x.shape[2]))
        x = x + self.drop_path(y)
        if self.gamma_2 is None:
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))

        return x


class PatchEmbed(nn.Layer):
    """ Image to Patch Embedding
    """

    def __init__(self,
                 img_size=(224, 224),
                 patch_size=16,
                 in_chans=3,
                 embed_dim=768,
                 lr_factor=0.01):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.proj = nn.Conv2D(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            weight_attr=ParamAttr(learning_rate=lr_factor),
            bias_attr=ParamAttr(learning_rate=lr_factor))

    @property
    def num_patches_in_h(self):
        return self.img_size[1] // self.patch_size

    @property
    def num_patches_in_w(self):
        return self.img_size[0] // self.patch_size

    def forward(self, x):
        out = self.proj(x)
        return out


@register
@serializable
class VisionTransformer2D(nn.Layer):
    """ Vision Transformer with support for patch input
    """

    def __init__(self,
                 img_size=(1024, 1024),
                 patch_size=16,
                 in_chans=3,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4,
                 qkv_bias=False,
                 attn_bias=False,
                 qk_scale=None,
                 init_values=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_layer='nn.GELU',
                 norm_layer='nn.LayerNorm',
                 lr_decay_rate=1.0,
                 global_attn_indexes=(2, 5, 8, 11),
                 use_abs_pos=False,
                 use_rel_pos=False,
                 use_abs_pos_emb=False,
                 use_sincos_pos_emb=False,
                 rel_pos_zero_init=True,
                 epsilon=1e-5,
                 final_norm=False,
                 pretrained=None,
                 window_size=None,
                 out_indices=(11, ),
                 with_fpn=False,
                 use_checkpoint=False,
                 *args,
                 **kwargs):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.depth = depth
        self.global_attn_indexes = global_attn_indexes
        self.epsilon = epsilon
        self.with_fpn = with_fpn
        self.use_checkpoint = use_checkpoint

        self.patch_h = img_size[0] // patch_size
        self.patch_w = img_size[1] // patch_size
        self.num_patches = self.patch_h * self.patch_w
        self.use_abs_pos = use_abs_pos
        self.use_abs_pos_emb = use_abs_pos_emb

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim)

        dpr = np.linspace(0, drop_path_rate, depth)
        if use_checkpoint:
            paddle.seed(0)

        if use_abs_pos_emb:
            self.pos_w = self.patch_embed.num_patches_in_w
            self.pos_h = self.patch_embed.num_patches_in_h
            self.pos_embed = self.create_parameter(
                shape=(1, self.pos_w * self.pos_h + 1, embed_dim),
                default_initializer=paddle.nn.initializer.TruncatedNormal(
                    std=.02))
        elif use_sincos_pos_emb:
            pos_embed = self.get_2d_sincos_position_embedding(self.patch_h,
                                                              self.patch_w)

            self.pos_embed = pos_embed
            self.pos_embed = self.create_parameter(shape=pos_embed.shape)
            self.pos_embed.set_value(pos_embed.numpy())
            self.pos_embed.stop_gradient = True
        else:
            self.pos_embed = None

        self.blocks = nn.LayerList([
            Block(
                embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                attn_bias=attn_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=None
                if i in self.global_attn_indexes else window_size,
                input_size=[self.patch_h, self.patch_w],
                act_layer=act_layer,
                lr_factor=self.get_vit_lr_decay_rate(i, lr_decay_rate),
                norm_layer=norm_layer,
                init_values=init_values,
                epsilon=epsilon) for i in range(depth)
        ])

        assert len(out_indices) <= 4, 'out_indices out of bound'
        self.out_indices = out_indices
        self.pretrained = pretrained
        self.init_weight()

        self.out_channels = [embed_dim for _ in range(len(out_indices))]
        self.out_strides = [4, 8, 16, 32][-len(out_indices):] if with_fpn else [
            patch_size for _ in range(len(out_indices))
        ]
        self.norm = Identity()
        if self.with_fpn:
            self.init_fpn(
                embed_dim=embed_dim,
                patch_size=patch_size,
                out_with_norm=final_norm)

    def get_vit_lr_decay_rate(self, layer_id, lr_decay_rate):
        return lr_decay_rate**(self.depth - layer_id)

    def init_weight(self):
        pretrained = self.pretrained
        if pretrained:
            if 'http' in pretrained:
                path = paddle.utils.download.get_weights_path_from_url(
                    pretrained)
            else:
                path = pretrained

            load_state_dict = paddle.load(path)
            model_state_dict = self.state_dict()
            pos_embed_name = "pos_embed"

            if pos_embed_name in load_state_dict.keys(
            ) and self.use_abs_pos_emb:
                load_pos_embed = paddle.to_tensor(
                    load_state_dict[pos_embed_name], dtype="float32")
                if self.pos_embed.shape != load_pos_embed.shape:
                    pos_size = int(math.sqrt(load_pos_embed.shape[1] - 1))
                    model_state_dict[pos_embed_name] = self.resize_pos_embed(
                        load_pos_embed, (pos_size, pos_size),
                        (self.pos_h, self.pos_w))

                    # self.set_state_dict(model_state_dict)
                    load_state_dict[pos_embed_name] = model_state_dict[
                        pos_embed_name]

                    print("Load pos_embed and resize it from {} to {} .".format(
                        load_pos_embed.shape, self.pos_embed.shape))

            self.set_state_dict(load_state_dict)
            print("Load load_state_dict....")

    def init_fpn(self, embed_dim=768, patch_size=16, out_with_norm=False):
        if patch_size == 16:
            self.fpn1 = nn.Sequential(
                nn.Conv2DTranspose(
                    embed_dim, embed_dim, kernel_size=2, stride=2),
                nn.BatchNorm2D(embed_dim),
                nn.GELU(),
                nn.Conv2DTranspose(
                    embed_dim, embed_dim, kernel_size=2, stride=2), )

            self.fpn2 = nn.Sequential(
                nn.Conv2DTranspose(
                    embed_dim, embed_dim, kernel_size=2, stride=2), )

            self.fpn3 = Identity()

            self.fpn4 = nn.MaxPool2D(kernel_size=2, stride=2)
        elif patch_size == 8:
            self.fpn1 = nn.Sequential(
                nn.Conv2DTranspose(
                    embed_dim, embed_dim, kernel_size=2, stride=2), )

            self.fpn2 = Identity()

            self.fpn3 = nn.Sequential(nn.MaxPool2D(kernel_size=2, stride=2), )

            self.fpn4 = nn.Sequential(nn.MaxPool2D(kernel_size=4, stride=4), )

        if not out_with_norm:
            self.norm = Identity()
        else:
            self.norm = nn.LayerNorm(embed_dim, epsilon=self.epsilon)

    def resize_pos_embed(self, pos_embed, old_hw, new_hw):
        """
        Resize pos_embed weight.
        Args:
            pos_embed (Tensor): the pos_embed weight
            old_hw (list[int]): the height and width of old pos_embed
            new_hw (list[int]): the height and width of new pos_embed
        Returns:
            Tensor: the resized pos_embed weight
        """
        cls_pos_embed = pos_embed[:, :1, :]
        pos_embed = pos_embed[:, 1:, :]

        pos_embed = pos_embed.transpose([0, 2, 1])
        pos_embed = pos_embed.reshape([1, -1, old_hw[0], old_hw[1]])
        pos_embed = F.interpolate(
            pos_embed, new_hw, mode='bicubic', align_corners=False)
        pos_embed = pos_embed.flatten(2).transpose([0, 2, 1])
        pos_embed = paddle.concat([cls_pos_embed, pos_embed], axis=1)

        return pos_embed

    def get_2d_sincos_position_embedding(self, h, w, temperature=10000.):
        grid_y, grid_x = paddle.meshgrid(
            paddle.arange(
                h, dtype=paddle.float32),
            paddle.arange(
                w, dtype=paddle.float32))
        assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        pos_dim = self.embed_dim // 4
        omega = paddle.arange(pos_dim, dtype=paddle.float32) / pos_dim
        omega = (1. / (temperature**omega)).unsqueeze(0)

        out_x = grid_x.reshape([-1, 1]).matmul(omega)
        out_y = grid_y.reshape([-1, 1]).matmul(omega)

        pos_emb = paddle.concat(
            [
                paddle.sin(out_y), paddle.cos(out_y), paddle.sin(out_x),
                paddle.cos(out_x)
            ],
            axis=1)

        return pos_emb.reshape([1, h, w, self.embed_dim])

    def forward(self, inputs):
        x = self.patch_embed(inputs['image']).transpose([0, 2, 3, 1])
        B, Hp, Wp, _ = paddle.shape(x)

        if self.use_abs_pos:
            x = x + self.get_2d_sincos_position_embedding(Hp, Wp)

        if self.use_abs_pos_emb:
            x = x + self.resize_pos_embed(self.pos_embed,
                                          (self.pos_h, self.pos_w), (Hp, Wp))

        feats = []
        for idx, blk in enumerate(self.blocks):
            if self.use_checkpoint and self.training:
                x = paddle.distributed.fleet.utils.recompute(
                    blk, x, **{"preserve_rng_state": True})
            else:
                x = blk(x)
            if idx in self.out_indices:
                feats.append(self.norm(x.transpose([0, 3, 1, 2])))

        if self.with_fpn:
            fpns = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
            for i in range(len(feats)):
                feats[i] = fpns[i](feats[i])
        return feats

    @property
    def num_layers(self):
        return len(self.blocks)

    @property
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    @property
    def out_shape(self):
        return [
            ShapeSpec(
                channels=c, stride=s)
            for c, s in zip(self.out_channels, self.out_strides)
        ]


class LayerNorm(nn.Layer):
    """
    A LayerNorm variant, popularized by Transformers, that performs point-wise mean and
    variance normalization over the channel dimension for inputs that have shape
    (batch_size, channels, height, width).    
    Note that, the modified LayerNorm on used in ResBlock and SimpleFeaturePyramid.

    In ViT, we use the nn.LayerNorm
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = self.create_parameter([normalized_shape])
        self.bias = self.create_parameter([normalized_shape])
        self.eps = eps
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / paddle.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


@register
@serializable
class SimpleFeaturePyramid(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 spatial_scales,
                 num_levels=4,
                 use_bias=False):
        """
        Args:
            in_channels (list[int]): input channels of each level which can be 
                derived from the output shape of backbone by from_config
            out_channel (int): output channel of each level.
            spatial_scales (list[float]): list of scaling factors to upsample or downsample
                the input features for creating pyramid features which can be derived from 
                the output shape of backbone by from_config
            num_levels (int): number of levels of output features.
            use_bias (bool): whether use bias or not.
        """
        super(SimpleFeaturePyramid, self).__init__()

        self.in_channels = in_channels[0]
        self.out_channels = out_channels
        self.num_levels = num_levels

        self.stages = []
        dim = self.in_channels
        if num_levels == 4:
            scale_factors = [2.0, 1.0, 0.5]
        elif num_levels == 5:
            scale_factors = [4.0, 2.0, 1.0, 0.5]
        else:
            raise NotImplementedError(
                f"num_levels={num_levels} is not supported yet.")

        dim = in_channels[0]
        for idx, scale in enumerate(scale_factors):
            out_dim = dim
            if scale == 4.0:
                layers = [
                    nn.Conv2DTranspose(
                        dim, dim // 2, kernel_size=2, stride=2),
                    nn.LayerNorm(dim // 2),
                    nn.GELU(),
                    nn.Conv2DTranspose(
                        dim // 2, dim // 4, kernel_size=2, stride=2),
                ]
                out_dim = dim // 4
            elif scale == 2.0:
                layers = [
                    nn.Conv2DTranspose(
                        dim, dim // 2, kernel_size=2, stride=2)
                ]
                out_dim = dim // 2
            elif scale == 1.0:
                layers = []
            elif scale == 0.5:
                layers = [nn.MaxPool2D(kernel_size=2, stride=2)]

            layers.extend([
                nn.Conv2D(
                    out_dim,
                    out_channels,
                    kernel_size=1,
                    bias_attr=use_bias, ), LayerNorm(out_channels), nn.Conv2D(
                        out_channels,
                        out_channels,
                        kernel_size=3,
                        padding=1,
                        bias_attr=use_bias, ), LayerNorm(out_channels)
            ])
            layers = nn.Sequential(*layers)

            stage = -int(math.log2(spatial_scales[0] * scale_factors[idx]))
            self.add_sublayer(f"simfp_{stage}", layers)
            self.stages.append(layers)

        # top block output feature maps.
        self.top_block = nn.Sequential(
            nn.MaxPool2D(
                kernel_size=1, stride=2, padding=0))

    @classmethod
    def from_config(cls, cfg, input_shape):
        return {
            'in_channels': [i.channels for i in input_shape],
            'spatial_scales': [1.0 / i.stride for i in input_shape],
        }

    @property
    def out_shape(self):
        return [
            ShapeSpec(channels=self.out_channels)
            for _ in range(self.num_levels)
        ]

    def forward(self, feats):
        """
        Args:
            x: Tensor of shape (N,C,H,W).
        """
        features = feats[0]
        results = []

        for stage in self.stages:
            results.append(stage(features))

        top_block_in_feature = results[-1]
        results.append(self.top_block(top_block_in_feature))
        assert self.num_levels == len(results)

        return results