Unverified commit e5f75295, authored by Tingquan Gao, committed by GitHub

Add SwinTransformer (#721)

* Add SwinTransformer_large

* Add SwinTransformer

* Optimize user experience

* Add the doc of SwinTransformer

* Modify the format of the corner mark
Parent commit fb3e0233
@@ -7,6 +7,8 @@
PaddleClas is a toolset for image classification tasks prepared for the industry and academia. It helps users train better computer vision models and apply them in real scenarios.

**Recent update**

- 2021.05.14 Add `SwinTransformer` series pretrained models, whose Top-1 Acc on ImageNet-1k dataset reaches 87.19%.
- 2021.04.15 Add `MixNet` and `ReXNet` pretrained models, `MixNet_L`'s Top-1 Acc on ImageNet-1k reaches 78.6% and `ReXNet_3_0` reaches 82.09%.
- 2021.03.02 Add support for model quantization.
- 2021.02.01 Add `RepVGG` pretrained models, whose Top-1 Acc on ImageNet-1k dataset reaches 79.65%.
@@ -63,10 +65,11 @@
- [Inception series](#Inception_series)
- [EfficientNet and ResNeXt101_wsl series](#EfficientNet_and_ResNeXt101_wsl_series)
- [ResNeSt and RegNet series](#ResNeSt_and_RegNet_series)
- [ViT and DeiT series](#ViT_and_DeiT)
- [RepVGG series](#RepVGG)
- [MixNet series](#MixNet)
- [ReXNet series](#ReXNet)
- [SwinTransformer series](#SwinTransformer)
- [Others](#Others)
- HS-ResNet: arxiv link: [https://arxiv.org/pdf/2010.07621.pdf](https://arxiv.org/pdf/2010.07621.pdf). Code and models are coming soon!
- Model training/evaluation
@@ -353,10 +356,10 @@
| RegNetX_4GF | 0.785 | 0.9416 | 6.46478 | 11.19862 | 8 | 22.1 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RegNetX_4GF_pretrained.pdparams) |

<a name="ViT_and_DeiT"></a>
### ViT and DeiT series

Accuracy and inference time metrics of ViT and DeiT series models are shown as follows. More detailed information can be found in the [ViT and DeiT series tutorial](./docs/en/models/ViT_and_DeiT_en.md).

| Model | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | Flops(G) | Params(M) | Download Address |
@@ -430,6 +433,25 @@
| ReXNet_2_0 | 0.8122 | 0.9536 | | | 1.561 | 16.449 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_2_0_pretrained.pdparams) |
| ReXNet_3_0 | 0.8209 | 0.9612 | | | 3.445 | 34.833 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_3_0_pretrained.pdparams) |
<a name="SwinTransformer"></a>
### SwinTransformer
Accuracy and inference time metrics of SwinTransformer series models are shown as follows. More detailed information can be refered to [SwinTransformer series tutorial](./docs/en/models/SwinTransformer_en.md).
| Model | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | Flops(G) | Params(M) | Download Address |
| ---------- | --------- | --------- | ---------------- | ---------------- | -------- | --------- | ------------------------------------------------------------ |
| SwinTransformer_tiny_patch4_window7_224 | 0.8069 | 0.9534 | | | 4.5 | 28 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_tiny_patch4_window7_224_pretrained.pdparams) |
| SwinTransformer_small_patch4_window7_224 | 0.8275 | 0.9613 | | | 8.7 | 50 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_small_patch4_window7_224_pretrained.pdparams) |
| SwinTransformer_base_patch4_window7_224 | 0.8300 | 0.9626 | | | 15.4 | 88 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window7_224_pretrained.pdparams) |
| SwinTransformer_base_patch4_window12_384 | 0.8439 | 0.9693 | | | 47.1 | 88 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window12_384_pretrained.pdparams) |
| SwinTransformer_base_patch4_window7_224<sup>[1]</sup> | 0.8487 | 0.9746 | | | 15.4 | 88 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window7_224_22kto1k_pretrained.pdparams) |
| SwinTransformer_base_patch4_window12_384<sup>[1]</sup> | 0.8642 | 0.9807 | | | 47.1 | 88 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window12_384_22kto1k_pretrained.pdparams) |
| SwinTransformer_large_patch4_window7_224<sup>[1]</sup> | 0.8596 | 0.9783 | | | 34.5 | 197 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_large_patch4_window7_224_22kto1k_pretrained.pdparams) |
| SwinTransformer_large_patch4_window12_384<sup>[1]</sup> | 0.8719 | 0.9823 | | | 103.9 | 197 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_large_patch4_window12_384_22kto1k_pretrained.pdparams) |
[1]: Pretrained on the ImageNet-22k dataset, then fine-tuned on ImageNet-1k.
<a name="Others"></a> <a name="Others"></a>
### Others ### Others
......
@@ -8,7 +8,8 @@
**Recent update**

- 2021.05.14 Add the `SwinTransformer` series models; Top-1 Acc on ImageNet-1k reaches 87.19%.
- 2021.04.15 Add the `MixNet` and `ReXNet` series models; on ImageNet-1k, `MixNet_L` reaches 78.6% Top-1 Acc and `ReXNet_3_0` reaches 82.09%.
- 2021.03.02 Add the quantization method and usage tutorial for classification models.
- 2021.02.01 Add the `RepVGG` series models; Top-1 Acc on ImageNet-1k reaches 79.65%.
- 2021.01.27 Add the `ViT` and `DeiT` models; on ImageNet-1k, `ViT` reaches 85.13% Top-1 Acc and `DeiT` reaches 85.1%.
@@ -65,10 +66,11 @@
- [Inception series](#Inception系列)
- [EfficientNet and ResNeXt101_wsl series](#EfficientNet与ResNeXt101_wsl系列)
- [ResNeSt and RegNet series](#ResNeSt与RegNet系列)
- [ViT and DeiT series](#ViT_and_DeiT系列)
- [RepVGG series](#RepVGG系列)
- [MixNet series](#MixNet系列)
- [ReXNet series](#ReXNet系列)
- [SwinTransformer series](#SwinTransformer系列)
- [Other models](#其他模型)
- HS-ResNet: arXiv link: [https://arxiv.org/pdf/2010.07621.pdf](https://arxiv.org/pdf/2010.07621.pdf). Code and pretrained models are coming soon.
- Model training/evaluation
@@ -358,10 +360,10 @@
| RegNetX_4GF | 0.785 | 0.9416 | 6.46478 | 11.19862 | 8 | 22.1 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RegNetX_4GF_pretrained.pdparams) |

<a name="ViT_and_DeiT系列"></a>
### ViT and DeiT series

Accuracy and speed metrics of the ViT (Vision Transformer) and DeiT (Data-efficient Image Transformers) series models are shown in the table below. For more details, see the [ViT and DeiT series model documentation](./docs/zh_CN/models/ViT_and_DeiT.md).

| Model | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | Flops(G) | Params(M) | Download Address |
@@ -434,6 +436,25 @@
| ReXNet_2_0 | 0.8122 | 0.9536 | | | 1.561 | 16.449 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_2_0_pretrained.pdparams) |
| ReXNet_3_0 | 0.8209 | 0.9612 | | | 3.445 | 34.833 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_3_0_pretrained.pdparams) |
<a name="SwinTransformer系列"></a>
### SwinTransformer系列
关于SwinTransformer系列模型的精度、速度指标如下表所示,更多介绍可以参考:[SwinTransformer系列模型文档](./docs/zh_CN/models/SwinTransformer.md)
| 模型 | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | Flops(G) | Params(M) | 下载地址 |
| ---------- | --------- | --------- | ---------------- | ---------------- | -------- | --------- | ------------------------------------------------------------ |
| SwinTransformer_tiny_patch4_window7_224 | 0.8069 | 0.9534 | | | 4.5 | 28 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_tiny_patch4_window7_224_pretrained.pdparams) |
| SwinTransformer_small_patch4_window7_224 | 0.8275 | 0.9613 | | | 8.7 | 50 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_small_patch4_window7_224_pretrained.pdparams) |
| SwinTransformer_base_patch4_window7_224 | 0.8300 | 0.9626 | | | 15.4 | 88 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window7_224_pretrained.pdparams) |
| SwinTransformer_base_patch4_window12_384 | 0.8439 | 0.9693 | | | 47.1 | 88 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window12_384_pretrained.pdparams) |
| SwinTransformer_base_patch4_window7_224<sup>[1]</sup> | 0.8487 | 0.9746 | | | 15.4 | 88 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window7_224_22kto1k_pretrained.pdparams) |
| SwinTransformer_base_patch4_window12_384<sup>[1]</sup> | 0.8642 | 0.9807 | | | 47.1 | 88 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window12_384_22kto1k_pretrained.pdparams) |
| SwinTransformer_large_patch4_window7_224<sup>[1]</sup> | 0.8596 | 0.9783 | | | 34.5 | 197 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_large_patch4_window7_224_22kto1k_pretrained.pdparams) |
| SwinTransformer_large_patch4_window12_384<sup>[1]</sup> | 0.8719 | 0.9823 | | | 103.9 | 197 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_large_patch4_window12_384_22kto1k_pretrained.pdparams) |
[1]: Pretrained on the ImageNet-22k dataset, then fine-tuned on ImageNet-1k.
<a name="其他模型"></a> <a name="其他模型"></a>
### 其他模型 ### 其他模型
......
mode: 'train'
ARCHITECTURE:
    name: 'SwinTransformer_large_patch4_window12_384'

pretrained_model: ""
model_save_dir: "./output/"
classes_num: 1000
total_images: 1281167
save_interval: 1
validate: True
valid_interval: 1
epochs: 120
topk: 5
image_shape: [3, 384, 384]

use_mix: False
ls_epsilon: -1

LEARNING_RATE:
    function: 'Piecewise'
    params:
        lr: 0.1
        decay_epochs: [30, 60, 90]
        gamma: 0.1

OPTIMIZER:
    function: 'Momentum'
    params:
        momentum: 0.9
    regularizer:
        function: 'L2'
        factor: 0.000100

TRAIN:
    batch_size: 256
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/train_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 384  # must match the 384x384 input resolution this model expects
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1./255.
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:

VALID:
    batch_size: 128
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/val_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            size: [384, 384]
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:
mode: 'train'
ARCHITECTURE:
    name: 'SwinTransformer_large_patch4_window7_224'

pretrained_model: ""
model_save_dir: "./output/"
classes_num: 1000
total_images: 1281167
save_interval: 1
validate: True
valid_interval: 1
epochs: 120
topk: 5
image_shape: [3, 224, 224]

use_mix: False
ls_epsilon: -1

LEARNING_RATE:
    function: 'Piecewise'
    params:
        lr: 0.1
        decay_epochs: [30, 60, 90]
        gamma: 0.1

OPTIMIZER:
    function: 'Momentum'
    params:
        momentum: 0.9
    regularizer:
        function: 'L2'
        factor: 0.000100

TRAIN:
    batch_size: 256
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/train_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1./255.
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:

VALID:
    batch_size: 64
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/val_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:
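Both configs above declare the same optimization setup: a piecewise-constant learning rate (base 0.1, multiplied by gamma 0.1 at epochs 30/60/90) with SGD momentum 0.9 and L2 weight decay 1e-4. As a rough sketch, not part of this commit, of what those blocks map to in PaddlePaddle 2.x (the `Linear` layer is only a stand-in for the real network):

```python
import paddle

# Piecewise schedule from LEARNING_RATE: lr 0.1, gamma 0.1 at epochs 30/60/90.
scheduler = paddle.optimizer.lr.PiecewiseDecay(
    boundaries=[30, 60, 90],            # LEARNING_RATE.params.decay_epochs
    values=[0.1, 0.01, 0.001, 0.0001])  # lr * gamma**k on each interval

model = paddle.nn.Linear(8, 2)          # placeholder for the real network
optimizer = paddle.optimizer.Momentum(
    learning_rate=scheduler,
    momentum=0.9,                                   # OPTIMIZER.params.momentum
    parameters=model.parameters(),
    weight_decay=paddle.regularizer.L2Decay(1e-4))  # regularizer: L2, factor 0.0001

# Call scheduler.step() once per epoch so the boundaries act as epoch indices.
```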
# SwinTransformer
## Overview
Swin Transformer is a new vision Transformer that can serve as a general-purpose backbone for computer vision. It is a hierarchical Transformer whose representations are computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping local windows while still allowing cross-window connections. [Paper](https://arxiv.org/abs/2103.14030)
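The pretrained models can be used through the factory functions added in this commit; a minimal sketch, assuming the factories are importable from `ppcls.modeling.architectures` as in this repository (the import path and the default 1000-class head are assumptions):

```python
import paddle
from ppcls.modeling.architectures import SwinTransformer_tiny_patch4_window7_224

model = SwinTransformer_tiny_patch4_window7_224()
model.eval()
x = paddle.randn([1, 3, 224, 224])  # NCHW input at the model's native resolution
with paddle.no_grad():
    logits = model(x)
print(logits.shape)  # expected [1, 1000] with the default ImageNet-1k head
```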
## Accuracy, FLOPS and Parameters
| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Params<br>(M) |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| SwinTransformer_tiny_patch4_window7_224 | 0.8069 | 0.9534 | 0.812 | 0.955 | 4.5 | 28 |
| SwinTransformer_small_patch4_window7_224 | 0.8275 | 0.9613 | 0.832 | 0.962 | 8.7 | 50 |
| SwinTransformer_base_patch4_window7_224 | 0.8300 | 0.9626 | 0.835 | 0.965 | 15.4 | 88 |
| SwinTransformer_base_patch4_window12_384 | 0.8439 | 0.9693 | 0.845 | 0.970 | 47.1 | 88 |
| SwinTransformer_base_patch4_window7_224<sup>[1]</sup> | 0.8487 | 0.9746 | 0.852 | 0.975 | 15.4 | 88 |
| SwinTransformer_base_patch4_window12_384<sup>[1]</sup> | 0.8642 | 0.9807 | 0.864 | 0.980 | 47.1 | 88 |
| SwinTransformer_large_patch4_window7_224<sup>[1]</sup> | 0.8596 | 0.9783 | 0.863 | 0.979 | 34.5 | 197 |
| SwinTransformer_large_patch4_window12_384<sup>[1]</sup> | 0.8719 | 0.9823 | 0.873 | 0.982 | 103.9 | 197 |
[1]: Pretrained on the ImageNet-22k dataset, then fine-tuned on ImageNet-1k.

**Note**: The accuracy gap relative to the reference results comes from differences in data preprocessing.
@@ -9,27 +9,27 @@ DeiT (Data-efficient Image Transformers) series models were proposed by Facebook

## Accuracy, FLOPS and Parameters

| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Params<br>(M) |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| ViT_small_patch16_224 | 0.7769 | 0.9342 | 0.7785 | 0.9342 | | |
| ViT_base_patch16_224 | 0.8195 | 0.9617 | 0.8178 | 0.9613 | | |
| ViT_base_patch16_384 | 0.8414 | 0.9717 | 0.8420 | 0.9722 | | |
| ViT_base_patch32_384 | 0.8176 | 0.9613 | 0.8166 | 0.9613 | | |
| ViT_large_patch16_224 | 0.8323 | 0.9650 | 0.8306 | 0.9644 | | |
| ViT_large_patch16_384 | 0.8513 | 0.9736 | 0.8517 | 0.9736 | | |
| ViT_large_patch32_384 | 0.8153 | 0.9608 | 0.815 | - | | |

| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Params<br>(M) |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| DeiT_tiny_patch16_224 | 0.718 | 0.910 | 0.722 | 0.911 | | |
| DeiT_small_patch16_224 | 0.796 | 0.949 | 0.799 | 0.950 | | |
| DeiT_base_patch16_224 | 0.817 | 0.957 | 0.818 | 0.956 | | |
| DeiT_base_patch16_384 | 0.830 | 0.962 | 0.829 | 0.972 | | |
| DeiT_tiny_distilled_patch16_224 | 0.741 | 0.918 | 0.745 | 0.919 | | |
| DeiT_small_distilled_patch16_224 | 0.809 | 0.953 | 0.812 | 0.954 | | |
| DeiT_base_distilled_patch16_224 | 0.831 | 0.964 | 0.834 | 0.965 | | |
| DeiT_base_distilled_patch16_384 | 0.851 | 0.973 | 0.852 | 0.972 | | |

Params, FLOPs, Inference speed and other information are coming soon.
# SwinTransformer

## Overview

Swin Transformer is a new vision Transformer network that can serve as a general-purpose backbone for computer vision. It is a hierarchical Transformer whose representations are computed with shifted windows. Restricting self-attention computation to non-overlapping local windows while still allowing cross-window connections improves both efficiency and performance. [Paper link](https://arxiv.org/abs/2103.14030)

## Accuracy, FLOPS and Parameters
| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Params<br>(M) |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| SwinTransformer_tiny_patch4_window7_224 | 0.8069 | 0.9534 | 0.812 | 0.955 | 4.5 | 28 |
| SwinTransformer_small_patch4_window7_224 | 0.8275 | 0.9613 | 0.832 | 0.962 | 8.7 | 50 |
| SwinTransformer_base_patch4_window7_224 | 0.8300 | 0.9626 | 0.835 | 0.965 | 15.4 | 88 |
| SwinTransformer_base_patch4_window12_384 | 0.8439 | 0.9693 | 0.845 | 0.970 | 47.1 | 88 |
| SwinTransformer_base_patch4_window7_224<sup>[1]</sup> | 0.8487 | 0.9746 | 0.852 | 0.975 | 15.4 | 88 |
| SwinTransformer_base_patch4_window12_384<sup>[1]</sup> | 0.8642 | 0.9807 | 0.864 | 0.980 | 47.1 | 88 |
| SwinTransformer_large_patch4_window7_224<sup>[1]</sup> | 0.8596 | 0.9783 | 0.863 | 0.979 | 34.5 | 197 |
| SwinTransformer_large_patch4_window12_384<sup>[1]</sup> | 0.8719 | 0.9823 | 0.873 | 0.982 | 103.9 | 197 |
[1]: Pretrained on the ImageNet-22k dataset, then fine-tuned on ImageNet-1k.

**Note**: The accuracy gap relative to the Reference results comes from differences in data preprocessing.
@@ -11,26 +11,26 @@

## Accuracy, FLOPS and Parameters

| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Params<br>(M) |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| ViT_small_patch16_224 | 0.7769 | 0.9342 | 0.7785 | 0.9342 | | |
| ViT_base_patch16_224 | 0.8195 | 0.9617 | 0.8178 | 0.9613 | | |
| ViT_base_patch16_384 | 0.8414 | 0.9717 | 0.8420 | 0.9722 | | |
| ViT_base_patch32_384 | 0.8176 | 0.9613 | 0.8166 | 0.9613 | | |
| ViT_large_patch16_224 | 0.8323 | 0.9650 | 0.8306 | 0.9644 | | |
| ViT_large_patch16_384 | 0.8513 | 0.9736 | 0.8517 | 0.9736 | | |
| ViT_large_patch32_384 | 0.8153 | 0.9608 | 0.815 | - | | |

| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Params<br>(M) |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| DeiT_tiny_patch16_224 | 0.718 | 0.910 | 0.722 | 0.911 | | |
| DeiT_small_patch16_224 | 0.796 | 0.949 | 0.799 | 0.950 | | |
| DeiT_base_patch16_224 | 0.817 | 0.957 | 0.818 | 0.956 | | |
| DeiT_base_patch16_384 | 0.830 | 0.962 | 0.829 | 0.972 | | |
| DeiT_tiny_distilled_patch16_224 | 0.741 | 0.918 | 0.745 | 0.919 | | |
| DeiT_small_distilled_patch16_224 | 0.809 | 0.953 | 0.812 | 0.954 | | |
| DeiT_base_distilled_patch16_224 | 0.831 | 0.964 | 0.834 | 0.965 | | |
| DeiT_base_distilled_patch16_384 | 0.851 | 0.973 | 0.852 | 0.972 | | |

Information on Params, FLOPs, inference speed, etc. is coming soon.
@@ -32,7 +32,10 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ''))

import argparse
import shutil
import textwrap
from difflib import SequenceMatcher
from prettytable import PrettyTable

import cv2
import numpy as np
import tarfile
@@ -45,57 +48,148 @@ __all__ = ['PaddleClas']
BASE_DIR = os.path.expanduser("~/.paddleclas/")
BASE_INFERENCE_MODEL_DIR = os.path.join(BASE_DIR, 'inference_model')
BASE_IMAGES_DIR = os.path.join(BASE_DIR, 'images')

model_series = {
    "AlexNet": ["AlexNet"],
    "DarkNet": ["DarkNet53"],
    "DeiT": [
        "DeiT_base_distilled_patch16_224", "DeiT_base_distilled_patch16_384",
        "DeiT_base_patch16_224", "DeiT_base_patch16_384",
        "DeiT_small_distilled_patch16_224", "DeiT_small_patch16_224",
        "DeiT_tiny_distilled_patch16_224", "DeiT_tiny_patch16_224"
    ],
    "DenseNet": [
        "DenseNet121", "DenseNet161", "DenseNet169", "DenseNet201",
        "DenseNet264"
    ],
    "DPN": ["DPN68", "DPN92", "DPN98", "DPN107", "DPN131"],
    "EfficientNet": [
        "EfficientNetB0", "EfficientNetB0_small", "EfficientNetB1",
        "EfficientNetB2", "EfficientNetB3", "EfficientNetB4", "EfficientNetB5",
        "EfficientNetB6", "EfficientNetB7"
    ],
    "GhostNet":
    ["GhostNet_x0_5", "GhostNet_x1_0", "GhostNet_x1_3", "GhostNet_x1_3_ssld"],
    "HRNet": [
        "HRNet_W18_C", "HRNet_W30_C", "HRNet_W32_C", "HRNet_W40_C",
        "HRNet_W44_C", "HRNet_W48_C", "HRNet_W64_C", "HRNet_W18_C_ssld",
        "HRNet_W48_C_ssld"
    ],
    "Inception": ["GoogLeNet", "InceptionV3", "InceptionV4"],
    "MobileNetV1": [
        "MobileNetV1_x0_25", "MobileNetV1_x0_5", "MobileNetV1_x0_75",
        "MobileNetV1", "MobileNetV1_ssld"
    ],
    "MobileNetV2": [
        "MobileNetV2_x0_25", "MobileNetV2_x0_5", "MobileNetV2_x0_75",
        "MobileNetV2", "MobileNetV2_x1_5", "MobileNetV2_x2_0",
        "MobileNetV2_ssld"
    ],
    "MobileNetV3": [
        "MobileNetV3_small_x0_35", "MobileNetV3_small_x0_5",
        "MobileNetV3_small_x0_75", "MobileNetV3_small_x1_0",
        "MobileNetV3_small_x1_25", "MobileNetV3_large_x0_35",
        "MobileNetV3_large_x0_5", "MobileNetV3_large_x0_75",
        "MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25",
        "MobileNetV3_small_x1_0_ssld", "MobileNetV3_large_x1_0_ssld"
    ],
    "RegNet": ["RegNetX_4GF"],
    "Res2Net": [
        "Res2Net50_14w_8s", "Res2Net50_26w_4s", "Res2Net50_vd_26w_4s",
        "Res2Net200_vd_26w_4s", "Res2Net101_vd_26w_4s",
        "Res2Net50_vd_26w_4s_ssld", "Res2Net101_vd_26w_4s_ssld",
        "Res2Net200_vd_26w_4s_ssld"
    ],
    "ResNeSt": ["ResNeSt50", "ResNeSt50_fast_1s1x64d"],
    "ResNet": [
        "ResNet18", "ResNet18_vd", "ResNet34", "ResNet34_vd", "ResNet50",
        "ResNet50_vc", "ResNet50_vd", "ResNet50_vd_v2", "ResNet101",
        "ResNet101_vd", "ResNet152", "ResNet152_vd", "ResNet200_vd",
        "ResNet34_vd_ssld", "ResNet50_vd_ssld", "ResNet50_vd_ssld_v2",
        "ResNet101_vd_ssld", "Fix_ResNet50_vd_ssld_v2", "ResNet50_ACNet_deploy"
    ],
    "ResNeXt": [
        "ResNeXt50_32x4d", "ResNeXt50_vd_32x4d", "ResNeXt50_64x4d",
        "ResNeXt50_vd_64x4d", "ResNeXt101_32x4d", "ResNeXt101_vd_32x4d",
        "ResNeXt101_32x8d_wsl", "ResNeXt101_32x16d_wsl",
        "ResNeXt101_32x32d_wsl", "ResNeXt101_32x48d_wsl",
        "Fix_ResNeXt101_32x48d_wsl", "ResNeXt101_64x4d", "ResNeXt101_vd_64x4d",
        "ResNeXt152_32x4d", "ResNeXt152_vd_32x4d", "ResNeXt152_64x4d",
        "ResNeXt152_vd_64x4d"
    ],
    "SENet": [
        "SENet154_vd", "SE_HRNet_W64_C_ssld", "SE_ResNet18_vd",
        "SE_ResNet34_vd", "SE_ResNet50_vd", "SE_ResNeXt50_32x4d",
        "SE_ResNeXt50_vd_32x4d", "SE_ResNeXt101_32x4d"
    ],
    "ShuffleNetV2": [
        "ShuffleNetV2_swish", "ShuffleNetV2_x0_25", "ShuffleNetV2_x0_33",
        "ShuffleNetV2_x0_5", "ShuffleNetV2_x1_0", "ShuffleNetV2_x1_5",
        "ShuffleNetV2_x2_0"
    ],
    "SqueezeNet": ["SqueezeNet1_0", "SqueezeNet1_1"],
    "SwinTransformer": [
        "SwinTransformer_large_patch4_window7_224_22kto1k",
        "SwinTransformer_large_patch4_window12_384_22kto1k",
        "SwinTransformer_base_patch4_window7_224_22kto1k",
        "SwinTransformer_base_patch4_window12_384_22kto1k",
        "SwinTransformer_base_patch4_window12_384",
        "SwinTransformer_base_patch4_window7_224",
        "SwinTransformer_small_patch4_window7_224",
        "SwinTransformer_tiny_patch4_window7_224"
    ],
    "VGG": ["VGG11", "VGG13", "VGG16", "VGG19"],
    "VisionTransformer": [
        "ViT_base_patch16_224", "ViT_base_patch16_384", "ViT_base_patch32_384",
        "ViT_large_patch16_224", "ViT_large_patch16_384",
        "ViT_large_patch32_384", "ViT_small_patch16_224"
    ],
    "Xception": [
        "Xception41", "Xception41_deeplab", "Xception65", "Xception65_deeplab",
        "Xception71"
    ]
}


class ModelNameError(Exception):
    """ ModelNameError
    """

    def __init__(self, message=''):
        super().__init__(message)


def print_info():
    table = PrettyTable(['Series', 'Name'])
    for series in model_series:
        names = textwrap.fill(" ".join(model_series[series]), width=100)
        table.add_row([series, names])
    print('Inference models that Paddle provides are listed as follows:')
    print(table)


def get_model_names():
    model_names = []
    for series in model_series:
        model_names += model_series[series]
    return model_names


def similar_architectures(name='', names=[], thresh=0.1, topk=10):
    """
    Find the architectures in `names` most similar to `name` by fuzzy
    string matching.
    """
    scores = []
    for idx, n in enumerate(names):
        if n.startswith('__'):
            continue
        score = SequenceMatcher(None, n.lower(), name.lower()).quick_ratio()
        if score > thresh:
            scores.append((idx, score))
    scores.sort(key=lambda x: x[1], reverse=True)
    similar_names = [names[s[0]] for s in scores[:min(topk, len(scores))]]
    return similar_names
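
# Example (illustrative only): suggest close matches for a mistyped name.
#   similar_architectures('SwinTransformer_tiny', get_model_names())
#   -> ['SwinTransformer_tiny_patch4_window7_224', ...]
# PaddleClas uses this below to turn an unknown model_name into a helpful
# ModelNameError instead of a bare Exception.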


def download_with_progressbar(url, save_path):
    response = requests.get(url, stream=True)
    total_size_in_bytes = int(response.headers.get('content-length', 0))
@@ -227,17 +321,26 @@ def parse_args(mMain=True, add_help=True):
class PaddleClas(object):
    print_info()

    def __init__(self, **kwargs):
        model_names = get_model_names()
        process_params = parse_args(mMain=False, add_help=False)
        process_params.__dict__.update(**kwargs)

        if not os.path.exists(process_params.model_file):
            if process_params.model_name is None:
                raise ModelNameError(
                    'Please input model name that you want to use!')
            similar_names = similar_architectures(process_params.model_name,
                                                  model_names)
            model_list = ', '.join(similar_names)
            if process_params.model_name not in similar_names:
                err = "{} does not exist! Maybe you want: [{}]".format(
                    process_params.model_name, model_list)
                raise ModelNameError(err)
            if process_params.model_name in model_names:
                url = 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/{}_infer.tar'.format(
                    process_params.model_name)
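
# Behavior sketch (hypothetical session): a mistyped name now fails with
# fuzzy suggestions rather than a bare Exception, e.g.
#   PaddleClas(model_name='SwinTransformer_tiny')
#   -> ModelNameError: SwinTransformer_tiny does not exist! Maybe you want:
#      [SwinTransformer_tiny_patch4_window7_224, ...]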
......
@@ -47,6 +47,6 @@ from .vision_transformer import ViT_small_patch16_224, ViT_base_patch16_224, ViT
from .distilled_vision_transformer import DeiT_tiny_patch16_224, DeiT_small_patch16_224, DeiT_base_patch16_224, DeiT_tiny_distilled_patch16_224, DeiT_small_distilled_patch16_224, DeiT_base_distilled_patch16_224, DeiT_base_patch16_384, DeiT_base_distilled_patch16_384
from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0, ResNeXt101_32x16d_wsl_distill_ResNet50_vd
from .repvgg import RepVGG_A0, RepVGG_A1, RepVGG_A2, RepVGG_B0, RepVGG_B1, RepVGG_B2, RepVGG_B3, RepVGG_B1g2, RepVGG_B1g4, RepVGG_B2g2, RepVGG_B2g4, RepVGG_B3g2, RepVGG_B3g4
from .swin_transformer import SwinTransformer_tiny_patch4_window7_224, SwinTransformer_small_patch4_window7_224, SwinTransformer_base_patch4_window7_224, SwinTransformer_base_patch4_window12_384, SwinTransformer_large_patch4_window7_224, SwinTransformer_large_patch4_window12_384
from .mixnet import MixNet_S, MixNet_M, MixNet_L
from .rexnet import ReXNet_1_0, ReXNet_1_3, ReXNet_1_5, ReXNet_2_0, ReXNet_3_0
@@ -759,3 +759,24 @@ def SwinTransformer_base_patch4_window12_384(**args):
        drop_path_rate=0.5,  # NOTE: does not appear in the official code
        **args)
    return model


def SwinTransformer_large_patch4_window7_224(**args):
    model = SwinTransformer(
        embed_dim=192,
        depths=[2, 2, 18, 2],
        num_heads=[6, 12, 24, 48],
        window_size=7,
        **args)
    return model


def SwinTransformer_large_patch4_window12_384(**args):
    model = SwinTransformer(
        img_size=384,
        embed_dim=192,
        depths=[2, 2, 18, 2],
        num_heads=[6, 12, 24, 48],
        window_size=12,
        **args)
    return model
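
# Usage sketch: both factories accept the same keyword overrides as the other
# Swin variants in this file (the class_dim argument below is an assumption):
#   model = SwinTransformer_large_patch4_window12_384(class_dim=1000)
#   out = model(paddle.randn([1, 3, 384, 384]))  # -> [1, class_dim]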