# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Code was heavily based on https://github.com/facebookresearch/deit

import paddle
import paddle.nn as nn
from .vision_transformer import VisionTransformer, Identity, trunc_normal_, zeros_

from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

# Download location of the pretrained weights for every model factory in this
# module. The factories look their own name up here and hand the URL to
# _load_pretrained when called with ``pretrained=True``.
MODEL_URLS = {
    name: ("https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/" +
           name + "_pretrained.pdparams")
    for name in (
        "DeiT_tiny_patch16_224",
        "DeiT_small_patch16_224",
        "DeiT_base_patch16_224",
        "DeiT_tiny_distilled_patch16_224",
        "DeiT_small_distilled_patch16_224",
        "DeiT_base_distilled_patch16_224",
        "DeiT_base_patch16_384",
        "DeiT_base_distilled_patch16_384",
    )
}

# The public API is exactly the set of model factories that have weights above
# (dict iteration yields the keys in insertion order).
__all__ = list(MODEL_URLS)


class DistilledVisionTransformer(VisionTransformer):
    """VisionTransformer variant with an extra learnable distillation token.

    The sequence fed to the transformer blocks is
    ``[cls_token, dist_token, patch tokens...]``. After the final norm, the
    class-token output goes through ``self.head`` and the distillation-token
    output goes through ``self.head_dist``. ``forward`` returns the mean of
    the two head outputs.

    All constructor arguments are forwarded unchanged to
    ``VisionTransformer``.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 class_num=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4,
                 qkv_bias=False,
                 norm_layer='nn.LayerNorm',
                 epsilon=1e-5,
                 **kwargs):
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            class_num=class_num,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            epsilon=epsilon,
            **kwargs)

        # Replace the parent's position embedding with one that has
        # num_patches + 2 slots. The two extra slots cover the class token
        # and the distillation token, which forward_features prepends to
        # the patch tokens.
        self.pos_embed = self.create_parameter(
            shape=(1, self.patch_embed.num_patches + 2, self.embed_dim),
            default_initializer=zeros_)
        self.add_parameter("pos_embed", self.pos_embed)

        # Learnable distillation token, expanded over the batch in
        # forward_features. Register it under its own name, in the same way
        # as pos_embed above.
        self.dist_token = self.create_parameter(
            shape=(1, 1, self.embed_dim), default_initializer=zeros_)
        self.add_parameter("dist_token", self.dist_token)

        # Second classifier reading the distillation-token feature. Like the
        # main head, it becomes Identity() when class_num is not positive.
        self.head_dist = nn.Linear(
            self.embed_dim,
            self.class_num) if self.class_num > 0 else Identity()

        # Re-initialise the freshly created parameters, overriding the zero
        # defaults above, with trunc_normal_ from .vision_transformer.
        # head_dist goes through the parent's _init_weights hook.
        trunc_normal_(self.dist_token)
        trunc_normal_(self.pos_embed)
        self.head_dist.apply(self._init_weights)

    def forward_features(self, x):
        """Run the backbone and return the two token features.

        Args:
            x: input batch for ``self.patch_embed`` (presumably images laid
                out as (N, C, H, W); TODO confirm against PatchEmbed).

        Returns:
            A tuple ``(cls_feature, dist_feature)``: sequence positions 0
            and 1 of the normalised transformer output.
        """
        # paddle.shape returns the shape as a tensor, so the batch size stays
        # dynamic rather than being baked in as a Python int.
        B = paddle.shape(x)[0]
        x = self.patch_embed(x)

        # Prepend [cls_token, dist_token] to every sample's patch sequence.
        cls_tokens = self.cls_token.expand((B, -1, -1))
        dist_token = self.dist_token.expand((B, -1, -1))
        x = paddle.concat((cls_tokens, dist_token, x), axis=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x[:, 0], x[:, 1]

    def forward(self, x):
        """Return the average of the class-head and distillation-head outputs."""
        x, x_dist = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        return (x + x_dist) / 2


def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    """Optionally load pretrained weights into ``model``.

    Args:
        pretrained: ``False`` leaves ``model`` untouched. ``True`` loads the
            weights found at ``model_url``. A ``str`` is handed to
            ``load_dygraph_pretrain`` (presumably a local weights path).
            The bool cases are matched by identity, so e.g. ``0``/``1`` are
            rejected.
        model: the network whose parameters are loaded.
        model_url: remote weights location, used only when
            ``pretrained is True``.
        use_ssld: forwarded to ``load_dygraph_pretrain_from_url``.

    Raises:
        RuntimeError: if ``pretrained`` is not ``True``, ``False`` or a str.
    """
    if pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
        return
    if isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
        return
    if pretrained is not False:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )


def DeiT_tiny_patch16_224(pretrained=False, use_ssld=False, **kwargs):
    """Build the (non-distilled) DeiT tiny model with 16x16 patches.

    Configuration: 12 blocks, 192-dim embeddings, 3 heads, MLP ratio 4,
    qkv bias enabled, norm epsilon 1e-6. ``img_size`` is left at the
    ``VisionTransformer`` default unless given in ``kwargs``.

    Args:
        pretrained: ``False``, ``True`` (weights from ``MODEL_URLS``) or a
            str forwarded to the local loader; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra ``VisionTransformer`` keyword arguments.

    Returns:
        The constructed ``VisionTransformer``.
    """
    net = VisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, epsilon=1e-6, **kwargs)
    weights_url = MODEL_URLS["DeiT_tiny_patch16_224"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net


def DeiT_small_patch16_224(pretrained=False, use_ssld=False, **kwargs):
    """Build the (non-distilled) DeiT small model with 16x16 patches.

    Configuration: 12 blocks, 384-dim embeddings, 6 heads, MLP ratio 4,
    qkv bias enabled, norm epsilon 1e-6. ``img_size`` is left at the
    ``VisionTransformer`` default unless given in ``kwargs``.

    Args:
        pretrained: ``False``, ``True`` (weights from ``MODEL_URLS``) or a
            str forwarded to the local loader; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra ``VisionTransformer`` keyword arguments.

    Returns:
        The constructed ``VisionTransformer``.
    """
    net = VisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        qkv_bias=True, epsilon=1e-6, **kwargs)
    weights_url = MODEL_URLS["DeiT_small_patch16_224"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net


def DeiT_base_patch16_224(pretrained=False, use_ssld=False, **kwargs):
    """Build the (non-distilled) DeiT base model with 16x16 patches.

    Configuration: 12 blocks, 768-dim embeddings, 12 heads, MLP ratio 4,
    qkv bias enabled, norm epsilon 1e-6. ``img_size`` is left at the
    ``VisionTransformer`` default unless given in ``kwargs``.

    Args:
        pretrained: ``False``, ``True`` (weights from ``MODEL_URLS``) or a
            str forwarded to the local loader; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra ``VisionTransformer`` keyword arguments.

    Returns:
        The constructed ``VisionTransformer``.
    """
    net = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, epsilon=1e-6, **kwargs)
    weights_url = MODEL_URLS["DeiT_base_patch16_224"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net


def DeiT_tiny_distilled_patch16_224(pretrained=False, use_ssld=False, **kwargs):
    """Build the distilled DeiT tiny model with 16x16 patches.

    Uses ``DistilledVisionTransformer`` (extra distillation token and head)
    with 12 blocks, 192-dim embeddings, 3 heads, MLP ratio 4, qkv bias
    enabled and norm epsilon 1e-6. ``img_size`` defaults to 224.

    Args:
        pretrained: ``False``, ``True`` (weights from ``MODEL_URLS``) or a
            str forwarded to the local loader; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra ``DistilledVisionTransformer`` keyword arguments.

    Returns:
        The constructed ``DistilledVisionTransformer``.
    """
    net = DistilledVisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, epsilon=1e-6, **kwargs)
    weights_url = MODEL_URLS["DeiT_tiny_distilled_patch16_224"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net


def DeiT_small_distilled_patch16_224(pretrained=False, use_ssld=False,
                                     **kwargs):
    """Build the distilled DeiT small model with 16x16 patches.

    Uses ``DistilledVisionTransformer`` (extra distillation token and head)
    with 12 blocks, 384-dim embeddings, 6 heads, MLP ratio 4, qkv bias
    enabled and norm epsilon 1e-6. ``img_size`` defaults to 224.

    Args:
        pretrained: ``False``, ``True`` (weights from ``MODEL_URLS``) or a
            str forwarded to the local loader; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra ``DistilledVisionTransformer`` keyword arguments.

    Returns:
        The constructed ``DistilledVisionTransformer``.
    """
    net = DistilledVisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        qkv_bias=True, epsilon=1e-6, **kwargs)
    weights_url = MODEL_URLS["DeiT_small_distilled_patch16_224"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net


def DeiT_base_distilled_patch16_224(pretrained=False, use_ssld=False, **kwargs):
    """Build the distilled DeiT base model with 16x16 patches.

    Uses ``DistilledVisionTransformer`` (extra distillation token and head)
    with 12 blocks, 768-dim embeddings, 12 heads, MLP ratio 4, qkv bias
    enabled and norm epsilon 1e-6. ``img_size`` defaults to 224.

    Args:
        pretrained: ``False``, ``True`` (weights from ``MODEL_URLS``) or a
            str forwarded to the local loader; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra ``DistilledVisionTransformer`` keyword arguments.

    Returns:
        The constructed ``DistilledVisionTransformer``.
    """
    net = DistilledVisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, epsilon=1e-6, **kwargs)
    weights_url = MODEL_URLS["DeiT_base_distilled_patch16_224"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net


def DeiT_base_patch16_384(pretrained=False, use_ssld=False, **kwargs):
    """Build the (non-distilled) DeiT base model for 384x384 inputs.

    Configuration: img_size 384, 16x16 patches, 12 blocks, 768-dim
    embeddings, 12 heads, MLP ratio 4, qkv bias enabled, norm epsilon 1e-6.

    Args:
        pretrained: ``False``, ``True`` (weights from ``MODEL_URLS``) or a
            str forwarded to the local loader; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra ``VisionTransformer`` keyword arguments.

    Returns:
        The constructed ``VisionTransformer``.
    """
    net = VisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        mlp_ratio=4, qkv_bias=True, epsilon=1e-6, **kwargs)
    weights_url = MODEL_URLS["DeiT_base_patch16_384"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net


def DeiT_base_distilled_patch16_384(pretrained=False, use_ssld=False, **kwargs):
    """Build the distilled DeiT base model for 384x384 inputs.

    Uses ``DistilledVisionTransformer`` (extra distillation token and head)
    with img_size 384, 16x16 patches, 12 blocks, 768-dim embeddings,
    12 heads, MLP ratio 4, qkv bias enabled and norm epsilon 1e-6.

    Args:
        pretrained: ``False``, ``True`` (weights from ``MODEL_URLS``) or a
            str forwarded to the local loader; see ``_load_pretrained``.
        use_ssld: forwarded to ``_load_pretrained``.
        **kwargs: extra ``DistilledVisionTransformer`` keyword arguments.

    Returns:
        The constructed ``DistilledVisionTransformer``.
    """
    net = DistilledVisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        mlp_ratio=4, qkv_bias=True, epsilon=1e-6, **kwargs)
    weights_url = MODEL_URLS["DeiT_base_distilled_patch16_384"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net