# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This code is heavily based on https://github.com/facebookresearch/deit
# Reference: https://arxiv.org/abs/2012.12877

import paddle
import paddle.nn as nn
from .vision_transformer import VisionTransformer, Identity, trunc_normal_, zeros_

from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

# Every pretrained checkpoint lives under the same bucket prefix and follows
# the "<arch>_pretrained.pdparams" naming scheme, so each URL is derived from
# its architecture name. Insertion order matters: ``__all__`` mirrors it.
MODEL_URLS = {
    arch: ("https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/"
           "{}_pretrained.pdparams".format(arch))
    for arch in (
        "DeiT_tiny_patch16_224",
        "DeiT_small_patch16_224",
        "DeiT_base_patch16_224",
        "DeiT_tiny_distilled_patch16_224",
        "DeiT_small_distilled_patch16_224",
        "DeiT_base_distilled_patch16_224",
        "DeiT_base_patch16_384",
        "DeiT_base_distilled_patch16_384",
    )
}

# The public API is exactly the set of factory functions that have weights.
__all__ = list(MODEL_URLS)


class DistilledVisionTransformer(VisionTransformer):
    """VisionTransformer with an extra distillation token (DeiT).

    Compared with the parent ``VisionTransformer`` this model:

    * adds a learnable ``dist_token`` that is concatenated right after the
      class token in the input sequence,
    * replaces ``pos_embed`` with one that has two extra slots (class token +
      distillation token) on top of the patch tokens,
    * adds a second classification head, ``head_dist``, applied to the
      distillation-token output.

    All constructor arguments are forwarded unchanged to
    ``VisionTransformer.__init__``.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 class_num=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4,
                 qkv_bias=False,
                 norm_layer='nn.LayerNorm',
                 epsilon=1e-5,
                 **kwargs):
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            class_num=class_num,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            epsilon=epsilon,
            **kwargs)

        # Replace the parent's positional embedding. The sequence fed to the
        # blocks is [cls_token, dist_token, patch tokens...], hence "+ 2".
        self.pos_embed = self.create_parameter(
            shape=(1, self.patch_embed.num_patches + 2, self.embed_dim),
            default_initializer=zeros_)
        self.add_parameter("pos_embed", self.pos_embed)

        self.dist_token = self.create_parameter(
            shape=(1, 1, self.embed_dim), default_initializer=zeros_)
        # Register the distillation token under its own name. Registering
        # ``cls_token`` here would be a copy-paste slip, since the parameter
        # just created is ``dist_token``.
        self.add_parameter("dist_token", self.dist_token)

        # Second head for the distillation-token output; mirrors the parent's
        # convention of using Identity when class_num <= 0.
        self.head_dist = nn.Linear(
            self.embed_dim,
            self.class_num) if self.class_num > 0 else Identity()

        trunc_normal_(self.dist_token)
        trunc_normal_(self.pos_embed)
        self.head_dist.apply(self._init_weights)

    def forward_features(self, x):
        """Run the transformer backbone.

        Args:
            x: input batch, passed straight to ``self.patch_embed``
                (presumably an NCHW image tensor — TODO confirm).

        Returns:
            Tuple ``(cls_feature, dist_feature)``: the normalized outputs at
            sequence position 0 (class token) and 1 (distillation token).
        """
        # Batch size read as a tensor value; presumably so the batch
        # dimension can stay dynamic (e.g. for static-graph export) — confirm.
        B = paddle.shape(x)[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand((B, -1, -1))
        dist_token = self.dist_token.expand((B, -1, -1))
        # Order matters: forward() relies on cls at index 0, dist at index 1.
        x = paddle.concat((cls_tokens, dist_token, x), axis=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x[:, 0], x[:, 1]

    def forward(self, x):
        """Return the mean of the class-head and distillation-head outputs."""
        x, x_dist = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        return (x + x_dist) / 2


C
cuicheng01 已提交
112 113 114 115 116 117 118 119 120 121 122
def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    if pretrained is False:
        pass
    elif pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
    elif isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )
littletomatodonkey's avatar
littletomatodonkey 已提交
123


def DeiT_tiny_patch16_224(pretrained=False, use_ssld=False, **kwargs):
    """Build DeiT-Tiny: 16x16 patches, 192-dim embeddings, 12 blocks x 3 heads.

    Args:
        pretrained: ``False``, ``True`` (load from ``MODEL_URLS``) or a local
            checkpoint path; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra ``VisionTransformer`` arguments (e.g. ``class_num``).

    Returns:
        The constructed ``VisionTransformer``.
    """
    arch_cfg = dict(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6)
    model = VisionTransformer(**arch_cfg, **kwargs)
    url = MODEL_URLS["DeiT_tiny_patch16_224"]
    _load_pretrained(pretrained, model, url, use_ssld=use_ssld)
    return model


def DeiT_small_patch16_224(pretrained=False, use_ssld=False, **kwargs):
    """Build DeiT-Small: 16x16 patches, 384-dim embeddings, 12 blocks x 6 heads.

    Args:
        pretrained: ``False``, ``True`` (load from ``MODEL_URLS``) or a local
            checkpoint path; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra ``VisionTransformer`` arguments (e.g. ``class_num``).

    Returns:
        The constructed ``VisionTransformer``.
    """
    arch_cfg = dict(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6)
    model = VisionTransformer(**arch_cfg, **kwargs)
    url = MODEL_URLS["DeiT_small_patch16_224"]
    _load_pretrained(pretrained, model, url, use_ssld=use_ssld)
    return model


def DeiT_base_patch16_224(pretrained=False, use_ssld=False, **kwargs):
    """Build DeiT-Base: 16x16 patches, 768-dim embeddings, 12 blocks x 12 heads.

    Args:
        pretrained: ``False``, ``True`` (load from ``MODEL_URLS``) or a local
            checkpoint path; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra ``VisionTransformer`` arguments (e.g. ``class_num``).

    Returns:
        The constructed ``VisionTransformer``.
    """
    arch_cfg = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6)
    model = VisionTransformer(**arch_cfg, **kwargs)
    url = MODEL_URLS["DeiT_base_patch16_224"]
    _load_pretrained(pretrained, model, url, use_ssld=use_ssld)
    return model


def DeiT_tiny_distilled_patch16_224(pretrained=False, use_ssld=False,
                                    **kwargs):
    """Build distilled DeiT-Tiny: 192-dim embeddings, 12 blocks x 3 heads.

    Args:
        pretrained: ``False``, ``True`` (load from ``MODEL_URLS``) or a local
            checkpoint path; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra ``DistilledVisionTransformer`` arguments.

    Returns:
        The constructed ``DistilledVisionTransformer``.
    """
    arch_cfg = dict(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6)
    model = DistilledVisionTransformer(**arch_cfg, **kwargs)
    url = MODEL_URLS["DeiT_tiny_distilled_patch16_224"]
    _load_pretrained(pretrained, model, url, use_ssld=use_ssld)
    return model


def DeiT_small_distilled_patch16_224(pretrained=False,
                                     use_ssld=False,
                                     **kwargs):
    """Build distilled DeiT-Small: 384-dim embeddings, 12 blocks x 6 heads.

    Args:
        pretrained: ``False``, ``True`` (load from ``MODEL_URLS``) or a local
            checkpoint path; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra ``DistilledVisionTransformer`` arguments.

    Returns:
        The constructed ``DistilledVisionTransformer``.
    """
    arch_cfg = dict(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6)
    model = DistilledVisionTransformer(**arch_cfg, **kwargs)
    url = MODEL_URLS["DeiT_small_distilled_patch16_224"]
    _load_pretrained(pretrained, model, url, use_ssld=use_ssld)
    return model


def DeiT_base_distilled_patch16_224(pretrained=False, use_ssld=False,
                                    **kwargs):
    """Build distilled DeiT-Base: 768-dim embeddings, 12 blocks x 12 heads.

    Args:
        pretrained: ``False``, ``True`` (load from ``MODEL_URLS``) or a local
            checkpoint path; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra ``DistilledVisionTransformer`` arguments.

    Returns:
        The constructed ``DistilledVisionTransformer``.
    """
    arch_cfg = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6)
    model = DistilledVisionTransformer(**arch_cfg, **kwargs)
    url = MODEL_URLS["DeiT_base_distilled_patch16_224"]
    _load_pretrained(pretrained, model, url, use_ssld=use_ssld)
    return model


def DeiT_base_patch16_384(pretrained=False, use_ssld=False, **kwargs):
    """Build DeiT-Base for 384x384 inputs: 768-dim, 12 blocks x 12 heads.

    Args:
        pretrained: ``False``, ``True`` (load from ``MODEL_URLS``) or a local
            checkpoint path; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra ``VisionTransformer`` arguments (e.g. ``class_num``).

    Returns:
        The constructed ``VisionTransformer``.
    """
    arch_cfg = dict(
        img_size=384,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6)
    model = VisionTransformer(**arch_cfg, **kwargs)
    url = MODEL_URLS["DeiT_base_patch16_384"]
    _load_pretrained(pretrained, model, url, use_ssld=use_ssld)
    return model


def DeiT_base_distilled_patch16_384(pretrained=False, use_ssld=False,
                                    **kwargs):
    """Build distilled DeiT-Base for 384x384 inputs: 768-dim, 12 x 12 heads.

    Args:
        pretrained: ``False``, ``True`` (load from ``MODEL_URLS``) or a local
            checkpoint path; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra ``DistilledVisionTransformer`` arguments.

    Returns:
        The constructed ``DistilledVisionTransformer``.
    """
    arch_cfg = dict(
        img_size=384,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6)
    model = DistilledVisionTransformer(**arch_cfg, **kwargs)
    url = MODEL_URLS["DeiT_base_distilled_patch16_384"]
    _load_pretrained(pretrained, model, url, use_ssld=use_ssld)
    return model