Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
747a6598
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
747a6598
编写于
1月 21, 2021
作者:
jm_12138
提交者:
GitHub
1月 21, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add ViT model (#570)
* Add the ViT model
上级
d0ecff1b
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
951 addition
and
1 deletion
+951
-1
configs/VisionTransformer/ViT_base_patch16_224.yaml
configs/VisionTransformer/ViT_base_patch16_224.yaml
+74
-0
configs/VisionTransformer/ViT_base_patch16_384.yaml
configs/VisionTransformer/ViT_base_patch16_384.yaml
+74
-0
configs/VisionTransformer/ViT_base_patch32_384.yaml
configs/VisionTransformer/ViT_base_patch32_384.yaml
+74
-0
configs/VisionTransformer/ViT_huge_patch16_224.yaml
configs/VisionTransformer/ViT_huge_patch16_224.yaml
+74
-0
configs/VisionTransformer/ViT_huge_patch32_384.yaml
configs/VisionTransformer/ViT_huge_patch32_384.yaml
+74
-0
configs/VisionTransformer/ViT_large_patch16_224.yaml
configs/VisionTransformer/ViT_large_patch16_224.yaml
+74
-0
configs/VisionTransformer/ViT_large_patch16_384.yaml
configs/VisionTransformer/ViT_large_patch16_384.yaml
+74
-0
configs/VisionTransformer/ViT_large_patch32_384.yaml
configs/VisionTransformer/ViT_large_patch32_384.yaml
+74
-0
configs/VisionTransformer/ViT_small_patch16_224.yaml
configs/VisionTransformer/ViT_small_patch16_224.yaml
+74
-0
ppcls/modeling/architectures/__init__.py
ppcls/modeling/architectures/__init__.py
+1
-1
ppcls/modeling/architectures/vision_transformer.py
ppcls/modeling/architectures/vision_transformer.py
+284
-0
未找到文件。
configs/VisionTransformer/ViT_base_patch16_224.yaml
0 → 100644
浏览文件 @
747a6598
mode
:
'
train'
ARCHITECTURE
:
name
:
'
ViT_base_patch16_224'
pretrained_model
:
"
"
model_save_dir
:
"
./output/"
classes_num
:
1000
total_images
:
1281167
save_interval
:
1
validate
:
True
valid_interval
:
1
epochs
:
120
topk
:
5
image_shape
:
[
3
,
224
,
224
]
use_mix
:
False
ls_epsilon
:
-1
LEARNING_RATE
:
function
:
'
Cosine'
params
:
lr
:
0.005
OPTIMIZER
:
function
:
'
Momentum'
params
:
momentum
:
0.9
regularizer
:
function
:
'
L2'
factor
:
0.000100
TRAIN
:
batch_size
:
48
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/train_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.5
,
0.5
,
0.5
]
order
:
'
'
-
ToCHWImage
:
VALID
:
batch_size
:
48
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/val_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
ResizeImage
:
size
:
248
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.5
,
0.5
,
0.5
]
order
:
'
'
-
ToCHWImage
:
configs/VisionTransformer/ViT_base_patch16_384.yaml
0 → 100644
浏览文件 @
747a6598
mode
:
'
train'
ARCHITECTURE
:
name
:
'
ViT_base_patch16_384'
pretrained_model
:
"
"
model_save_dir
:
"
./output/"
classes_num
:
1000
total_images
:
1281167
save_interval
:
1
validate
:
True
valid_interval
:
1
epochs
:
120
topk
:
5
image_shape
:
[
3
,
384
,
384
]
use_mix
:
False
ls_epsilon
:
-1
LEARNING_RATE
:
function
:
'
Cosine'
params
:
lr
:
0.005
OPTIMIZER
:
function
:
'
Momentum'
params
:
momentum
:
0.9
regularizer
:
function
:
'
L2'
factor
:
0.000100
TRAIN
:
batch_size
:
48
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/train_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
RandCropImage
:
size
:
384
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.5
,
0.5
,
0.5
]
order
:
'
'
-
ToCHWImage
:
VALID
:
batch_size
:
48
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/val_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
ResizeImage
:
size
:
384
-
CropImage
:
size
:
384
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.5
,
0.5
,
0.5
]
order
:
'
'
-
ToCHWImage
:
configs/VisionTransformer/ViT_base_patch32_384.yaml
0 → 100644
浏览文件 @
747a6598
mode
:
'
train'
ARCHITECTURE
:
name
:
'
ViT_base_patch32_384'
pretrained_model
:
"
"
model_save_dir
:
"
./output/"
classes_num
:
1000
total_images
:
1281167
save_interval
:
1
validate
:
True
valid_interval
:
1
epochs
:
120
topk
:
5
image_shape
:
[
3
,
384
,
384
]
use_mix
:
False
ls_epsilon
:
-1
LEARNING_RATE
:
function
:
'
Cosine'
params
:
lr
:
0.005
OPTIMIZER
:
function
:
'
Momentum'
params
:
momentum
:
0.9
regularizer
:
function
:
'
L2'
factor
:
0.000100
TRAIN
:
batch_size
:
48
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/train_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
RandCropImage
:
size
:
384
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.5
,
0.5
,
0.5
]
order
:
'
'
-
ToCHWImage
:
VALID
:
batch_size
:
48
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/val_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
ResizeImage
:
size
:
384
-
CropImage
:
size
:
384
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.5
,
0.5
,
0.5
]
order
:
'
'
-
ToCHWImage
:
\ No newline at end of file
configs/VisionTransformer/ViT_huge_patch16_224.yaml
0 → 100644
浏览文件 @
747a6598
mode
:
'
train'
ARCHITECTURE
:
name
:
'
ViT_huge_patch16_224'
pretrained_model
:
"
"
model_save_dir
:
"
./output/"
classes_num
:
1000
total_images
:
1281167
save_interval
:
1
validate
:
True
valid_interval
:
1
epochs
:
120
topk
:
5
image_shape
:
[
3
,
224
,
224
]
use_mix
:
False
ls_epsilon
:
-1
LEARNING_RATE
:
function
:
'
Cosine'
params
:
lr
:
0.001
OPTIMIZER
:
function
:
'
Momentum'
params
:
momentum
:
0.9
regularizer
:
function
:
'
L2'
factor
:
0.000100
TRAIN
:
batch_size
:
16
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/train_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
VALID
:
batch_size
:
16
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/val_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
ResizeImage
:
size
:
248
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
configs/VisionTransformer/ViT_huge_patch32_384.yaml
0 → 100644
浏览文件 @
747a6598
mode
:
'
train'
ARCHITECTURE
:
name
:
'
ViT_huge_patch32_384'
pretrained_model
:
"
"
model_save_dir
:
"
./output/"
classes_num
:
1000
total_images
:
1281167
save_interval
:
1
validate
:
True
valid_interval
:
1
epochs
:
120
topk
:
5
image_shape
:
[
3
,
384
,
384
]
use_mix
:
False
ls_epsilon
:
-1
LEARNING_RATE
:
function
:
'
Cosine'
params
:
lr
:
0.001
OPTIMIZER
:
function
:
'
Momentum'
params
:
momentum
:
0.9
regularizer
:
function
:
'
L2'
factor
:
0.000100
TRAIN
:
batch_size
:
16
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/train_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
RandCropImage
:
size
:
384
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
VALID
:
batch_size
:
16
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/val_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
ResizeImage
:
size
:
384
-
CropImage
:
size
:
384
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
configs/VisionTransformer/ViT_large_patch16_224.yaml
0 → 100644
浏览文件 @
747a6598
mode
:
'
train'
ARCHITECTURE
:
name
:
'
ViT_large_patch16_224'
pretrained_model
:
"
"
model_save_dir
:
"
./output/"
classes_num
:
1000
total_images
:
1281167
save_interval
:
1
validate
:
True
valid_interval
:
1
epochs
:
120
topk
:
5
image_shape
:
[
3
,
224
,
224
]
use_mix
:
False
ls_epsilon
:
-1
LEARNING_RATE
:
function
:
'
Cosine'
params
:
lr
:
0.003
OPTIMIZER
:
function
:
'
Momentum'
params
:
momentum
:
0.9
regularizer
:
function
:
'
L2'
factor
:
0.000100
TRAIN
:
batch_size
:
32
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/train_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.5
,
0.5
,
0.5
]
order
:
'
'
-
ToCHWImage
:
VALID
:
batch_size
:
32
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/val_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
ResizeImage
:
size
:
248
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.5
,
0.5
,
0.5
]
order
:
'
'
-
ToCHWImage
:
configs/VisionTransformer/ViT_large_patch16_384.yaml
0 → 100644
浏览文件 @
747a6598
mode
:
'
train'
ARCHITECTURE
:
name
:
'
ViT_large_patch16_384'
pretrained_model
:
"
"
model_save_dir
:
"
./output/"
classes_num
:
1000
total_images
:
1281167
save_interval
:
1
validate
:
True
valid_interval
:
1
epochs
:
120
topk
:
5
image_shape
:
[
3
,
384
,
384
]
use_mix
:
False
ls_epsilon
:
-1
LEARNING_RATE
:
function
:
'
Cosine'
params
:
lr
:
0.003
OPTIMIZER
:
function
:
'
Momentum'
params
:
momentum
:
0.9
regularizer
:
function
:
'
L2'
factor
:
0.000100
TRAIN
:
batch_size
:
32
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/train_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
RandCropImage
:
size
:
384
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.5
,
0.5
,
0.5
]
order
:
'
'
-
ToCHWImage
:
VALID
:
batch_size
:
32
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/val_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
ResizeImage
:
size
:
384
-
CropImage
:
size
:
384
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.5
,
0.5
,
0.5
]
order
:
'
'
-
ToCHWImage
:
configs/VisionTransformer/ViT_large_patch32_384.yaml
0 → 100644
浏览文件 @
747a6598
mode
:
'
train'
ARCHITECTURE
:
name
:
'
ViT_large_patch32_384'
pretrained_model
:
"
"
model_save_dir
:
"
./output/"
classes_num
:
1000
total_images
:
1281167
save_interval
:
1
validate
:
True
valid_interval
:
1
epochs
:
120
topk
:
5
image_shape
:
[
3
,
384
,
384
]
use_mix
:
False
ls_epsilon
:
-1
LEARNING_RATE
:
function
:
'
Cosine'
params
:
lr
:
0.003
OPTIMIZER
:
function
:
'
Momentum'
params
:
momentum
:
0.9
regularizer
:
function
:
'
L2'
factor
:
0.000100
TRAIN
:
batch_size
:
32
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/train_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
RandCropImage
:
size
:
384
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.5
,
0.5
,
0.5
]
order
:
'
'
-
ToCHWImage
:
VALID
:
batch_size
:
32
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/val_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
ResizeImage
:
size
:
384
-
CropImage
:
size
:
384
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.5
,
0.5
,
0.5
]
order
:
'
'
-
ToCHWImage
:
\ No newline at end of file
configs/VisionTransformer/ViT_small_patch16_224.yaml
0 → 100644
浏览文件 @
747a6598
mode
:
'
train'
ARCHITECTURE
:
name
:
'
ViT_small_patch16_224'
pretrained_model
:
"
"
model_save_dir
:
"
./output/"
classes_num
:
1000
total_images
:
1281167
save_interval
:
1
validate
:
True
valid_interval
:
1
epochs
:
120
topk
:
5
image_shape
:
[
3
,
224
,
224
]
use_mix
:
False
ls_epsilon
:
-1
LEARNING_RATE
:
function
:
'
Cosine'
params
:
lr
:
0.01
OPTIMIZER
:
function
:
'
Momentum'
params
:
momentum
:
0.9
regularizer
:
function
:
'
L2'
factor
:
0.000100
TRAIN
:
batch_size
:
64
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/train_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
VALID
:
batch_size
:
64
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/val_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
ResizeImage
:
size
:
248
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
ppcls/modeling/architectures/__init__.py
浏览文件 @
747a6598
...
@@ -43,5 +43,5 @@ from .squeezenet import SqueezeNet1_0, SqueezeNet1_1
...
@@ -43,5 +43,5 @@ from .squeezenet import SqueezeNet1_0, SqueezeNet1_1
from
.vgg
import
VGG11
,
VGG13
,
VGG16
,
VGG19
from
.vgg
import
VGG11
,
VGG13
,
VGG16
,
VGG19
from
.darknet
import
DarkNet53
from
.darknet
import
DarkNet53
from
.regnet
import
RegNetX_200MF
,
RegNetX_4GF
,
RegNetX_32GF
,
RegNetY_200MF
,
RegNetY_4GF
,
RegNetY_32GF
from
.regnet
import
RegNetX_200MF
,
RegNetX_4GF
,
RegNetX_32GF
,
RegNetY_200MF
,
RegNetY_4GF
,
RegNetY_32GF
from
.vision_transformer
import
ViT_small_patch16_224
,
ViT_base_patch16_224
,
ViT_base_patch16_384
,
ViT_base_patch32_384
,
ViT_large_patch16_224
,
ViT_large_patch16_384
,
ViT_large_patch32_384
,
ViT_huge_patch16_224
,
ViT_huge_patch32_384
from
.distillation_models
import
ResNet50_vd_distill_MobileNetV3_large_x1_0
from
.distillation_models
import
ResNet50_vd_distill_MobileNetV3_large_x1_0
ppcls/modeling/architectures/vision_transformer.py
0 → 100644
浏览文件 @
747a6598
""" Vision Transformer (ViT) in Paddle
A Paddle implement of Vision Transformers as described in
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
The official jax code is released and available at https://github.com/google-research/vision_transformer
"""
import
paddle
import
paddle.nn
as
nn
from
paddle.nn.initializer
import
TruncatedNormal
,
Constant
__all__
=
[
"VisionTransformer"
,
"ViT_small_patch16_224"
,
"ViT_base_patch16_224"
,
"ViT_base_patch16_384"
,
"ViT_base_patch32_384"
,
"ViT_large_patch16_224"
,
"ViT_large_patch16_384"
,
"ViT_large_patch32_384"
,
"ViT_huge_patch16_224"
,
"ViT_huge_patch32_384"
]
trunc_normal_
=
TruncatedNormal
(
std
=
.
02
)
zeros_
=
Constant
(
value
=
0.
)
ones_
=
Constant
(
value
=
1.
)
def
to_2tuple
(
x
):
return
tuple
([
x
]
*
2
)
def
drop_path
(
x
,
drop_prob
=
0.
,
training
=
False
):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
"""
if
drop_prob
==
0.
or
not
training
:
return
x
keep_prob
=
1
-
drop_prob
shape
=
(
x
.
shape
[
0
],)
+
(
1
,)
*
(
x
.
ndim
-
1
)
random_tensor
=
keep_prob
+
paddle
.
rand
(
shape
,
dtype
=
x
.
dtype
)
random_tensor
.
floor_
()
# binarize
output
=
x
.
div
(
keep_prob
)
*
random_tensor
return
output
class
DropPath
(
nn
.
Layer
):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def
__init__
(
self
,
drop_prob
=
None
):
super
(
DropPath
,
self
).
__init__
()
self
.
drop_prob
=
drop_prob
def
forward
(
self
,
x
):
return
drop_path
(
x
,
self
.
drop_prob
,
self
.
training
)
class
Identity
(
nn
.
Layer
):
def
__init__
(
self
):
super
(
Identity
,
self
).
__init__
()
def
forward
(
self
,
input
):
return
input
class
Mlp
(
nn
.
Layer
):
def
__init__
(
self
,
in_features
,
hidden_features
=
None
,
out_features
=
None
,
act_layer
=
nn
.
GELU
,
drop
=
0.
):
super
().
__init__
()
out_features
=
out_features
or
in_features
hidden_features
=
hidden_features
or
in_features
self
.
fc1
=
nn
.
Linear
(
in_features
,
hidden_features
)
self
.
act
=
act_layer
()
self
.
fc2
=
nn
.
Linear
(
hidden_features
,
out_features
)
self
.
drop
=
nn
.
Dropout
(
drop
)
def
forward
(
self
,
x
):
x
=
self
.
fc1
(
x
)
x
=
self
.
act
(
x
)
x
=
self
.
drop
(
x
)
x
=
self
.
fc2
(
x
)
x
=
self
.
drop
(
x
)
return
x
class
Attention
(
nn
.
Layer
):
def
__init__
(
self
,
dim
,
num_heads
=
8
,
qkv_bias
=
False
,
qk_scale
=
None
,
attn_drop
=
0.
,
proj_drop
=
0.
):
super
().
__init__
()
self
.
num_heads
=
num_heads
head_dim
=
dim
//
num_heads
self
.
scale
=
qk_scale
or
head_dim
**
-
0.5
self
.
qkv
=
nn
.
Linear
(
dim
,
dim
*
3
,
bias_attr
=
qkv_bias
)
self
.
attn_drop
=
nn
.
Dropout
(
attn_drop
)
self
.
proj
=
nn
.
Linear
(
dim
,
dim
)
self
.
proj_drop
=
nn
.
Dropout
(
proj_drop
)
def
forward
(
self
,
x
):
B
,
N
,
C
=
x
.
shape
qkv
=
self
.
qkv
(
x
).
reshape
((
B
,
N
,
3
,
self
.
num_heads
,
C
//
self
.
num_heads
)).
transpose
((
2
,
0
,
3
,
1
,
4
))
q
,
k
,
v
=
qkv
[
0
],
qkv
[
1
],
qkv
[
2
]
attn
=
(
q
.
matmul
(
k
.
transpose
((
0
,
1
,
3
,
2
))))
*
self
.
scale
attn
=
nn
.
functional
.
softmax
(
attn
,
axis
=-
1
)
attn
=
self
.
attn_drop
(
attn
)
x
=
(
attn
.
matmul
(
v
)).
transpose
((
0
,
2
,
1
,
3
)).
reshape
((
B
,
N
,
C
))
x
=
self
.
proj
(
x
)
x
=
self
.
proj_drop
(
x
)
return
x
class
Block
(
nn
.
Layer
):
def
__init__
(
self
,
dim
,
num_heads
,
mlp_ratio
=
4.
,
qkv_bias
=
False
,
qk_scale
=
None
,
drop
=
0.
,
attn_drop
=
0.
,
drop_path
=
0.
,
act_layer
=
nn
.
GELU
,
norm_layer
=
'nn.LayerNorm'
,
epsilon
=
1e-5
):
super
().
__init__
()
self
.
norm1
=
eval
(
norm_layer
)(
dim
,
epsilon
=
epsilon
)
self
.
attn
=
Attention
(
dim
,
num_heads
=
num_heads
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
attn_drop
=
attn_drop
,
proj_drop
=
drop
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self
.
drop_path
=
DropPath
(
drop_path
)
if
drop_path
>
0.
else
Identity
()
self
.
norm2
=
eval
(
norm_layer
)(
dim
,
epsilon
=
epsilon
)
mlp_hidden_dim
=
int
(
dim
*
mlp_ratio
)
self
.
mlp
=
Mlp
(
in_features
=
dim
,
hidden_features
=
mlp_hidden_dim
,
act_layer
=
act_layer
,
drop
=
drop
)
def
forward
(
self
,
x
):
x
=
x
+
self
.
drop_path
(
self
.
attn
(
self
.
norm1
(
x
)))
x
=
x
+
self
.
drop_path
(
self
.
mlp
(
self
.
norm2
(
x
)))
return
x
class
PatchEmbed
(
nn
.
Layer
):
""" Image to Patch Embedding
"""
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
16
,
in_chans
=
3
,
embed_dim
=
768
):
super
().
__init__
()
img_size
=
to_2tuple
(
img_size
)
patch_size
=
to_2tuple
(
patch_size
)
num_patches
=
(
img_size
[
1
]
//
patch_size
[
1
])
*
\
(
img_size
[
0
]
//
patch_size
[
0
])
self
.
img_size
=
img_size
self
.
patch_size
=
patch_size
self
.
num_patches
=
num_patches
self
.
proj
=
nn
.
Conv2D
(
in_chans
,
embed_dim
,
kernel_size
=
patch_size
,
stride
=
patch_size
)
def
forward
(
self
,
x
):
B
,
C
,
H
,
W
=
x
.
shape
assert
H
==
self
.
img_size
[
0
]
and
W
==
self
.
img_size
[
1
],
\
f
"Input image size (
{
H
}
*
{
W
}
) doesn't match model (
{
self
.
img_size
[
0
]
}
*
{
self
.
img_size
[
1
]
}
)."
x
=
self
.
proj
(
x
).
flatten
(
2
).
transpose
((
0
,
2
,
1
))
return
x
class
VisionTransformer
(
nn
.
Layer
):
""" Vision Transformer with support for patch input
"""
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
16
,
in_chans
=
3
,
class_dim
=
1000
,
embed_dim
=
768
,
depth
=
12
,
num_heads
=
12
,
mlp_ratio
=
4
,
qkv_bias
=
False
,
qk_scale
=
None
,
drop_rate
=
0.
,
attn_drop_rate
=
0.
,
drop_path_rate
=
0.
,
norm_layer
=
'nn.LayerNorm'
,
epsilon
=
1e-5
,
**
args
):
super
().
__init__
()
self
.
class_dim
=
class_dim
self
.
num_features
=
self
.
embed_dim
=
embed_dim
self
.
patch_embed
=
PatchEmbed
(
img_size
=
img_size
,
patch_size
=
patch_size
,
in_chans
=
in_chans
,
embed_dim
=
embed_dim
)
num_patches
=
self
.
patch_embed
.
num_patches
self
.
pos_embed
=
self
.
create_parameter
(
shape
=
(
1
,
num_patches
+
1
,
embed_dim
),
default_initializer
=
zeros_
)
self
.
add_parameter
(
"pos_embed"
,
self
.
pos_embed
)
self
.
cls_token
=
self
.
create_parameter
(
shape
=
(
1
,
1
,
embed_dim
),
default_initializer
=
zeros_
)
self
.
add_parameter
(
"cls_token"
,
self
.
cls_token
)
self
.
pos_drop
=
nn
.
Dropout
(
p
=
drop_rate
)
dpr
=
[
x
for
x
in
paddle
.
linspace
(
0
,
drop_path_rate
,
depth
)]
self
.
blocks
=
nn
.
LayerList
([
Block
(
dim
=
embed_dim
,
num_heads
=
num_heads
,
mlp_ratio
=
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
i
],
norm_layer
=
norm_layer
,
epsilon
=
epsilon
)
for
i
in
range
(
depth
)])
self
.
norm
=
eval
(
norm_layer
)(
embed_dim
,
epsilon
=
epsilon
)
# Classifier head
self
.
head
=
nn
.
Linear
(
embed_dim
,
class_dim
)
if
class_dim
>
0
else
Identity
()
trunc_normal_
(
self
.
pos_embed
)
trunc_normal_
(
self
.
cls_token
)
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
zeros_
(
m
.
bias
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
zeros_
(
m
.
bias
)
ones_
(
m
.
weight
)
def
forward_features
(
self
,
x
):
B
=
x
.
shape
[
0
]
x
=
self
.
patch_embed
(
x
)
cls_tokens
=
self
.
cls_token
.
expand
((
B
,
-
1
,
-
1
))
x
=
paddle
.
concat
((
cls_tokens
,
x
),
axis
=
1
)
x
=
x
+
self
.
pos_embed
x
=
self
.
pos_drop
(
x
)
for
blk
in
self
.
blocks
:
x
=
blk
(
x
)
x
=
self
.
norm
(
x
)
return
x
[:,
0
]
def
forward
(
self
,
x
):
x
=
self
.
forward_features
(
x
)
x
=
self
.
head
(
x
)
return
x
def
ViT_small_patch16_224
(
**
kwargs
):
model
=
VisionTransformer
(
patch_size
=
16
,
embed_dim
=
768
,
depth
=
8
,
num_heads
=
8
,
mlp_ratio
=
3
,
qk_scale
=
768
**-
0.5
,
**
kwargs
)
return
model
def
ViT_base_patch16_224
(
**
kwargs
):
model
=
VisionTransformer
(
patch_size
=
16
,
embed_dim
=
768
,
depth
=
12
,
num_heads
=
12
,
mlp_ratio
=
4
,
qkv_bias
=
True
,
epsilon
=
1e-6
,
**
kwargs
)
return
model
def
ViT_base_patch16_384
(
**
kwargs
):
model
=
VisionTransformer
(
img_size
=
384
,
patch_size
=
16
,
embed_dim
=
768
,
depth
=
12
,
num_heads
=
12
,
mlp_ratio
=
4
,
qkv_bias
=
True
,
epsilon
=
1e-6
,
**
kwargs
)
return
model
def
ViT_base_patch32_384
(
**
kwargs
):
model
=
VisionTransformer
(
img_size
=
384
,
patch_size
=
32
,
embed_dim
=
768
,
depth
=
12
,
num_heads
=
12
,
mlp_ratio
=
4
,
qkv_bias
=
True
,
epsilon
=
1e-6
,
**
kwargs
)
return
model
def
ViT_large_patch16_224
(
**
kwargs
):
model
=
VisionTransformer
(
patch_size
=
16
,
embed_dim
=
1024
,
depth
=
24
,
num_heads
=
16
,
mlp_ratio
=
4
,
qkv_bias
=
True
,
epsilon
=
1e-6
,
**
kwargs
)
return
model
def
ViT_large_patch16_384
(
**
kwargs
):
model
=
VisionTransformer
(
img_size
=
384
,
patch_size
=
16
,
embed_dim
=
1024
,
depth
=
24
,
num_heads
=
16
,
mlp_ratio
=
4
,
qkv_bias
=
True
,
epsilon
=
1e-6
,
**
kwargs
)
return
model
def
ViT_large_patch32_384
(
**
kwargs
):
model
=
VisionTransformer
(
img_size
=
384
,
patch_size
=
32
,
embed_dim
=
1024
,
depth
=
24
,
num_heads
=
16
,
mlp_ratio
=
4
,
qkv_bias
=
True
,
epsilon
=
1e-6
,
**
kwargs
)
return
model
def
ViT_huge_patch16_224
(
**
kwargs
):
model
=
VisionTransformer
(
patch_size
=
16
,
embed_dim
=
1280
,
depth
=
32
,
num_heads
=
16
,
mlp_ratio
=
4
,
**
kwargs
)
return
model
def
ViT_huge_patch32_384
(
**
kwargs
):
model
=
VisionTransformer
(
img_size
=
384
,
patch_size
=
32
,
embed_dim
=
1280
,
depth
=
32
,
num_heads
=
16
,
mlp_ratio
=
4
,
**
kwargs
)
return
model
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录