From 747a659891acf6c04ee04568922605f283d02394 Mon Sep 17 00:00:00 2001 From: jm12138 <2286040843@qq.com> Date: Thu, 21 Jan 2021 13:05:05 +0800 Subject: [PATCH] Add ViT model (#570) * Add the ViT model --- .../ViT_base_patch16_224.yaml | 74 +++++ .../ViT_base_patch16_384.yaml | 74 +++++ .../ViT_base_patch32_384.yaml | 74 +++++ .../ViT_huge_patch16_224.yaml | 74 +++++ .../ViT_huge_patch32_384.yaml | 74 +++++ .../ViT_large_patch16_224.yaml | 74 +++++ .../ViT_large_patch16_384.yaml | 74 +++++ .../ViT_large_patch32_384.yaml | 74 +++++ .../ViT_small_patch16_224.yaml | 74 +++++ ppcls/modeling/architectures/__init__.py | 2 +- .../architectures/vision_transformer.py | 284 ++++++++++++++++++ 11 files changed, 951 insertions(+), 1 deletion(-) create mode 100644 configs/VisionTransformer/ViT_base_patch16_224.yaml create mode 100644 configs/VisionTransformer/ViT_base_patch16_384.yaml create mode 100644 configs/VisionTransformer/ViT_base_patch32_384.yaml create mode 100644 configs/VisionTransformer/ViT_huge_patch16_224.yaml create mode 100644 configs/VisionTransformer/ViT_huge_patch32_384.yaml create mode 100644 configs/VisionTransformer/ViT_large_patch16_224.yaml create mode 100644 configs/VisionTransformer/ViT_large_patch16_384.yaml create mode 100644 configs/VisionTransformer/ViT_large_patch32_384.yaml create mode 100644 configs/VisionTransformer/ViT_small_patch16_224.yaml create mode 100644 ppcls/modeling/architectures/vision_transformer.py diff --git a/configs/VisionTransformer/ViT_base_patch16_224.yaml b/configs/VisionTransformer/ViT_base_patch16_224.yaml new file mode 100644 index 00000000..ec394b0f --- /dev/null +++ b/configs/VisionTransformer/ViT_base_patch16_224.yaml @@ -0,0 +1,74 @@ +mode: 'train' +ARCHITECTURE: + name: 'ViT_base_patch16_224' + +pretrained_model: "" +model_save_dir: "./output/" +classes_num: 1000 +total_images: 1281167 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 120 +topk: 5 +image_shape: [3, 224, 224] + +use_mix: False +ls_epsilon: -1 + +LEARNING_RATE: + function: 'Cosine' + params: + lr: 0.005 + +OPTIMIZER: + function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.000100 + +TRAIN: + batch_size: 48 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/train_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1./255. 
+ mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: '' + - ToCHWImage: + +VALID: + batch_size: 48 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/val_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + size: 248 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: '' + - ToCHWImage: diff --git a/configs/VisionTransformer/ViT_base_patch16_384.yaml b/configs/VisionTransformer/ViT_base_patch16_384.yaml new file mode 100644 index 00000000..c09246d4 --- /dev/null +++ b/configs/VisionTransformer/ViT_base_patch16_384.yaml @@ -0,0 +1,74 @@ +mode: 'train' +ARCHITECTURE: + name: 'ViT_base_patch16_384' + +pretrained_model: "" +model_save_dir: "./output/" +classes_num: 1000 +total_images: 1281167 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 120 +topk: 5 +image_shape: [3, 384, 384] + +use_mix: False +ls_epsilon: -1 + +LEARNING_RATE: + function: 'Cosine' + params: + lr: 0.005 + +OPTIMIZER: + function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.000100 + +TRAIN: + batch_size: 48 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/train_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 384 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1./255. + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: '' + - ToCHWImage: + +VALID: + batch_size: 48 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/val_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + size: 384 + - CropImage: + size: 384 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: '' + - ToCHWImage: diff --git a/configs/VisionTransformer/ViT_base_patch32_384.yaml b/configs/VisionTransformer/ViT_base_patch32_384.yaml new file mode 100644 index 00000000..fdd67e85 --- /dev/null +++ b/configs/VisionTransformer/ViT_base_patch32_384.yaml @@ -0,0 +1,74 @@ +mode: 'train' +ARCHITECTURE: + name: 'ViT_base_patch32_384' + +pretrained_model: "" +model_save_dir: "./output/" +classes_num: 1000 +total_images: 1281167 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 120 +topk: 5 +image_shape: [3, 384, 384] + +use_mix: False +ls_epsilon: -1 + +LEARNING_RATE: + function: 'Cosine' + params: + lr: 0.005 + +OPTIMIZER: + function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.000100 + +TRAIN: + batch_size: 48 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/train_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 384 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1./255. 
+ mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: '' + - ToCHWImage: + +VALID: + batch_size: 48 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/val_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + size: 384 + - CropImage: + size: 384 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: '' + - ToCHWImage: \ No newline at end of file diff --git a/configs/VisionTransformer/ViT_huge_patch16_224.yaml b/configs/VisionTransformer/ViT_huge_patch16_224.yaml new file mode 100644 index 00000000..33d8225e --- /dev/null +++ b/configs/VisionTransformer/ViT_huge_patch16_224.yaml @@ -0,0 +1,74 @@ +mode: 'train' +ARCHITECTURE: + name: 'ViT_huge_patch16_224' + +pretrained_model: "" +model_save_dir: "./output/" +classes_num: 1000 +total_images: 1281167 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 120 +topk: 5 +image_shape: [3, 224, 224] + +use_mix: False +ls_epsilon: -1 + +LEARNING_RATE: + function: 'Cosine' + params: + lr: 0.001 + +OPTIMIZER: + function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.000100 + +TRAIN: + batch_size: 16 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/train_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + +VALID: + batch_size: 16 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/val_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + size: 248 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: diff --git a/configs/VisionTransformer/ViT_huge_patch32_384.yaml b/configs/VisionTransformer/ViT_huge_patch32_384.yaml new file mode 100644 index 00000000..9f06a813 --- /dev/null +++ b/configs/VisionTransformer/ViT_huge_patch32_384.yaml @@ -0,0 +1,74 @@ +mode: 'train' +ARCHITECTURE: + name: 'ViT_huge_patch32_384' + +pretrained_model: "" +model_save_dir: "./output/" +classes_num: 1000 +total_images: 1281167 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 120 +topk: 5 +image_shape: [3, 384, 384] + +use_mix: False +ls_epsilon: -1 + +LEARNING_RATE: + function: 'Cosine' + params: + lr: 0.001 + +OPTIMIZER: + function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.000100 + +TRAIN: + batch_size: 16 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/train_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 384 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1./255. 
+ mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + +VALID: + batch_size: 16 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/val_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + size: 384 + - CropImage: + size: 384 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: diff --git a/configs/VisionTransformer/ViT_large_patch16_224.yaml b/configs/VisionTransformer/ViT_large_patch16_224.yaml new file mode 100644 index 00000000..5e60e10e --- /dev/null +++ b/configs/VisionTransformer/ViT_large_patch16_224.yaml @@ -0,0 +1,74 @@ +mode: 'train' +ARCHITECTURE: + name: 'ViT_large_patch16_224' + +pretrained_model: "" +model_save_dir: "./output/" +classes_num: 1000 +total_images: 1281167 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 120 +topk: 5 +image_shape: [3, 224, 224] + +use_mix: False +ls_epsilon: -1 + +LEARNING_RATE: + function: 'Cosine' + params: + lr: 0.003 + +OPTIMIZER: + function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.000100 + +TRAIN: + batch_size: 32 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/train_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1./255. + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: '' + - ToCHWImage: + +VALID: + batch_size: 32 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/val_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + size: 248 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: '' + - ToCHWImage: diff --git a/configs/VisionTransformer/ViT_large_patch16_384.yaml b/configs/VisionTransformer/ViT_large_patch16_384.yaml new file mode 100644 index 00000000..0d5fe9f9 --- /dev/null +++ b/configs/VisionTransformer/ViT_large_patch16_384.yaml @@ -0,0 +1,74 @@ +mode: 'train' +ARCHITECTURE: + name: 'ViT_large_patch16_384' + +pretrained_model: "" +model_save_dir: "./output/" +classes_num: 1000 +total_images: 1281167 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 120 +topk: 5 +image_shape: [3, 384, 384] + +use_mix: False +ls_epsilon: -1 + +LEARNING_RATE: + function: 'Cosine' + params: + lr: 0.003 + +OPTIMIZER: + function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.000100 + +TRAIN: + batch_size: 32 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/train_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 384 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1./255. 
+ mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: '' + - ToCHWImage: + +VALID: + batch_size: 32 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/val_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + size: 384 + - CropImage: + size: 384 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: '' + - ToCHWImage: diff --git a/configs/VisionTransformer/ViT_large_patch32_384.yaml b/configs/VisionTransformer/ViT_large_patch32_384.yaml new file mode 100644 index 00000000..8fdd98fb --- /dev/null +++ b/configs/VisionTransformer/ViT_large_patch32_384.yaml @@ -0,0 +1,74 @@ +mode: 'train' +ARCHITECTURE: + name: 'ViT_large_patch32_384' + +pretrained_model: "" +model_save_dir: "./output/" +classes_num: 1000 +total_images: 1281167 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 120 +topk: 5 +image_shape: [3, 384, 384] + +use_mix: False +ls_epsilon: -1 + +LEARNING_RATE: + function: 'Cosine' + params: + lr: 0.003 + +OPTIMIZER: + function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.000100 + +TRAIN: + batch_size: 32 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/train_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 384 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1./255. + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: '' + - ToCHWImage: + +VALID: + batch_size: 32 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/val_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + size: 384 + - CropImage: + size: 384 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: '' + - ToCHWImage: \ No newline at end of file diff --git a/configs/VisionTransformer/ViT_small_patch16_224.yaml b/configs/VisionTransformer/ViT_small_patch16_224.yaml new file mode 100644 index 00000000..b8f5c0d3 --- /dev/null +++ b/configs/VisionTransformer/ViT_small_patch16_224.yaml @@ -0,0 +1,74 @@ +mode: 'train' +ARCHITECTURE: + name: 'ViT_small_patch16_224' + +pretrained_model: "" +model_save_dir: "./output/" +classes_num: 1000 +total_images: 1281167 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 120 +topk: 5 +image_shape: [3, 224, 224] + +use_mix: False +ls_epsilon: -1 + +LEARNING_RATE: + function: 'Cosine' + params: + lr: 0.01 + +OPTIMIZER: + function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.000100 + +TRAIN: + batch_size: 64 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/train_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1./255. 
+ mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + +VALID: + batch_size: 64 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/val_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + size: 248 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: diff --git a/ppcls/modeling/architectures/__init__.py b/ppcls/modeling/architectures/__init__.py index c7e5aaef..aa2b93bd 100644 --- a/ppcls/modeling/architectures/__init__.py +++ b/ppcls/modeling/architectures/__init__.py @@ -43,5 +43,5 @@ from .squeezenet import SqueezeNet1_0, SqueezeNet1_1 from .vgg import VGG11, VGG13, VGG16, VGG19 from .darknet import DarkNet53 from .regnet import RegNetX_200MF, RegNetX_4GF, RegNetX_32GF, RegNetY_200MF, RegNetY_4GF, RegNetY_32GF - +from .vision_transformer import ViT_small_patch16_224, ViT_base_patch16_224, ViT_base_patch16_384, ViT_base_patch32_384, ViT_large_patch16_224, ViT_large_patch16_384, ViT_large_patch32_384, ViT_huge_patch16_224, ViT_huge_patch32_384 from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0 diff --git a/ppcls/modeling/architectures/vision_transformer.py b/ppcls/modeling/architectures/vision_transformer.py new file mode 100644 index 00000000..cd8bbdea --- /dev/null +++ b/ppcls/modeling/architectures/vision_transformer.py @@ -0,0 +1,284 @@ +""" Vision Transformer (ViT) in Paddle +A Paddle implement of Vision Transformers as described in +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 +The official jax code is released and available at https://github.com/google-research/vision_transformer +""" +import paddle +import paddle.nn as nn +from paddle.nn.initializer import TruncatedNormal, Constant + + +__all__ = [ + "VisionTransformer", + "ViT_small_patch16_224", + "ViT_base_patch16_224", "ViT_base_patch16_384", "ViT_base_patch32_384", + "ViT_large_patch16_224", "ViT_large_patch16_384", "ViT_large_patch32_384", + "ViT_huge_patch16_224", "ViT_huge_patch32_384" +] + + +trunc_normal_ = TruncatedNormal(std=.02) +zeros_ = Constant(value=0.) +ones_ = Constant(value=1.) + + +def to_2tuple(x): + return tuple([x] * 2) + + +def drop_path(x, drop_prob=0., training=False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Identity(nn.Layer): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, input): + return input + + +class Mlp(nn.Layer): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Layer): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape((B, N, 3, self.num_heads, C // + self.num_heads)).transpose((2, 0, 3, 1, 4)) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale + attn = nn.functional.softmax(attn, axis=-1) + attn = self.attn_drop(attn) + + x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((B, N, C)) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Layer): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer='nn.LayerNorm', epsilon=1e-5): + super().__init__() + self.norm1 = eval(norm_layer)(dim, epsilon=epsilon) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity() + self.norm2 = eval(norm_layer)(dim, epsilon=epsilon) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Layer): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * \ + (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2D(in_chans, embed_dim, + kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
+ + x = self.proj(x).flatten(2).transpose((0, 2, 1)) + return x + + +class VisionTransformer(nn.Layer): + """ Vision Transformer with support for patch input + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, class_dim=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4, qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer='nn.LayerNorm', epsilon=1e-5, **args): + super().__init__() + self.class_dim = class_dim + + self.num_features = self.embed_dim = embed_dim + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.pos_embed = self.create_parameter( + shape=(1, num_patches + 1, embed_dim), default_initializer=zeros_) + self.add_parameter("pos_embed", self.pos_embed) + self.cls_token = self.create_parameter( + shape=(1, 1, embed_dim), default_initializer=zeros_) + self.add_parameter("cls_token", self.cls_token) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x for x in paddle.linspace(0, drop_path_rate, depth)] + + self.blocks = nn.LayerList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, epsilon=epsilon) + for i in range(depth)]) + + self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon) + + # Classifier head + self.head = nn.Linear( + embed_dim, class_dim) if class_dim > 0 else Identity() + + trunc_normal_(self.pos_embed) + trunc_normal_(self.cls_token) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + zeros_(m.bias) + ones_(m.weight) + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + cls_tokens = self.cls_token.expand((B, -1, -1)) + x = paddle.concat((cls_tokens, x), axis=1) + x = x + self.pos_embed + x = self.pos_drop(x) + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def ViT_small_patch16_224(**kwargs): + model = VisionTransformer( + patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, qk_scale=768**-0.5, **kwargs) + return model + + +def ViT_base_patch16_224(**kwargs): + model = VisionTransformer( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + epsilon=1e-6, **kwargs) + return model + + +def ViT_base_patch16_384(**kwargs): + model = VisionTransformer( + img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + qkv_bias=True, epsilon=1e-6, **kwargs) + return model + + +def ViT_base_patch32_384(**kwargs): + model = VisionTransformer( + img_size=384, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + qkv_bias=True, epsilon=1e-6, **kwargs) + return model + + +def ViT_large_patch16_224(**kwargs): + model = VisionTransformer( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + epsilon=1e-6, **kwargs) + return model + + +def ViT_large_patch16_384(**kwargs): + model = VisionTransformer( + img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, + qkv_bias=True, epsilon=1e-6, **kwargs) + return model + + +def ViT_large_patch32_384(**kwargs): + model = VisionTransformer( + img_size=384, 
patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, + qkv_bias=True, epsilon=1e-6, **kwargs) + return model + + +def ViT_huge_patch16_224(**kwargs): + model = VisionTransformer( + patch_size=16, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs) + return model + + +def ViT_huge_patch32_384(**kwargs): + model = VisionTransformer( + img_size=384, patch_size=32, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs) + return model -- GitLab
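
As a quick smoke test of the new entry points exported from ppcls.modeling.architectures, the sketch below builds one of the added models and runs a dummy forward pass. It is a minimal sketch, not part of the patch itself, and assumes PaddlePaddle 2.0+ is installed and the repository root is on PYTHONPATH; the input shape follows image_shape: [3, 224, 224] and classes_num: 1000 from ViT_base_patch16_224.yaml.

    import paddle

    # New architecture exported by this patch via ppcls/modeling/architectures/__init__.py.
    from ppcls.modeling.architectures import ViT_base_patch16_224

    # 1000-way classification head, matching classes_num in the config.
    model = ViT_base_patch16_224(class_dim=1000)
    model.eval()

    # NCHW dummy batch matching image_shape: [3, 224, 224].
    x = paddle.randn([1, 3, 224, 224])
    with paddle.no_grad():
        logits = model(x)
    print(logits.shape)  # expected: [1, 1000]

The same pattern should apply to the other eight variants added here; only the constructor name and, for the 384-pixel models, the input resolution change.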