diff --git a/README.md b/README.md
index 5c41a1e4a53e0540f2feb42cdd3d25863ce35d24..2ddbb254341a7a3b0f924bda682c058f0c58a10c 100644
--- a/README.md
+++ b/README.md
@@ -7,6 +7,8 @@
PaddleClas is a toolset for image classification tasks prepared for the industry and academia. It helps users train better computer vision models and apply them in real scenarios.
**Recent update**
+
+- 2021.05.14 Add `SwinTransformer` series pretrained models, whose Top-1 Acc on ImageNet-1k dataset reaches 87.19%.
- 2021.04.15 Add `MixNet` and `ReXNet` pretrained models, `MixNet_L`'s Top-1 Acc on ImageNet-1k reaches 78.6% and `ReXNet_3_0` reaches 82.09%.
- 2021.03.02 Add support for model quantization.
- 2021.02.01 Add `RepVGG` pretrained models, whose Top-1 Acc on ImageNet-1k dataset reaches 79.65%.
@@ -63,10 +65,11 @@ PaddleClas is a toolset for image classification tasks prepared for the industry
- [Inception series](#Inception_series)
- [EfficientNet and ResNeXt101_wsl series](#EfficientNet_and_ResNeXt101_wsl_series)
- [ResNeSt and RegNet series](#ResNeSt_and_RegNet_series)
- - [Transformer series](#Transformer)
+ - [ViT and DeiT series](#ViT_and_DeiT)
- [RepVGG series](#RepVGG)
- [MixNet series](#MixNet)
- [ReXNet series](#ReXNet)
+ - [SwinTransformer series](#SwinTransformer)
- [Others](#Others)
- HS-ResNet: arxiv link: [https://arxiv.org/pdf/2010.07621.pdf](https://arxiv.org/pdf/2010.07621.pdf). Code and models are coming soon!
- Model training/evaluation
@@ -353,10 +356,10 @@ Accuracy and inference time metrics of ResNeSt and RegNet series models are show
| RegNetX_4GF | 0.785 | 0.9416 | 6.46478 | 11.19862 | 8 | 22.1 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RegNetX_4GF_pretrained.pdparams) |
-
-### Transformer series
+
+### ViT and DeiT series
-Accuracy and inference time metrics of ViT and DeiT series models are shown as follows. More detailed information can be refered to [Transformer series tutorial](./docs/en/models/Transformer_en.md).
+Accuracy and inference time metrics of ViT and DeiT series models are shown as follows. For more details, please refer to the [ViT and DeiT series tutorial](./docs/en/models/ViT_and_DeiT_en.md).
| Model | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | Flops(G) | Params(M) | Download Address |
@@ -430,6 +433,25 @@ Accuracy and inference time metrics of ReXNet series models are shown as follows
| ReXNet_2_0 | 0.8122 | 0.9536 | | | 1.561 | 16.449 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_2_0_pretrained.pdparams) |
| ReXNet_3_0 | 0.8209 | 0.9612 | | | 3.445 | 34.833 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_3_0_pretrained.pdparams) |
+
+
+### SwinTransformer
+
+Accuracy and inference time metrics of SwinTransformer series models are shown as follows. For more details, please refer to the [SwinTransformer series tutorial](./docs/en/models/SwinTransformer_en.md).
+
+| Model | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | Flops(G) | Params(M) | Download Address |
+| ---------- | --------- | --------- | ---------------- | ---------------- | -------- | --------- | ------------------------------------------------------------ |
+| SwinTransformer_tiny_patch4_window7_224 | 0.8069 | 0.9534 | | | 4.5 | 28 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_tiny_patch4_window7_224_pretrained.pdparams) |
+| SwinTransformer_small_patch4_window7_224 | 0.8275 | 0.9613 | | | 8.7 | 50 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_small_patch4_window7_224_pretrained.pdparams) |
+| SwinTransformer_base_patch4_window7_224 | 0.8300 | 0.9626 | | | 15.4 | 88 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window7_224_pretrained.pdparams) |
+| SwinTransformer_base_patch4_window12_384 | 0.8439 | 0.9693 | | | 47.1 | 88 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window12_384_pretrained.pdparams) |
+| SwinTransformer_base_patch4_window7_224[1] | 0.8487 | 0.9746 | | | 15.4 | 88 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window7_224_22kto1k_pretrained.pdparams) |
+| SwinTransformer_base_patch4_window12_384[1] | 0.8642 | 0.9807 | | | 47.1 | 88 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window12_384_22kto1k_pretrained.pdparams) |
+| SwinTransformer_large_patch4_window7_224[1] | 0.8596 | 0.9783 | | | 34.5 | 197 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_large_patch4_window7_224_22kto1k_pretrained.pdparams) |
+| SwinTransformer_large_patch4_window12_384[1] | 0.8719 | 0.9823 | | | 103.9 | 197 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_large_patch4_window12_384_22kto1k_pretrained.pdparams) |
+
+[1]: Pre-trained on the ImageNet-22k dataset and then fine-tuned on the ImageNet-1k dataset.
+
### Others
diff --git a/README_cn.md b/README_cn.md
index 602c9aacb7b2e10bd96b483dfc16490d30d13c37..e3291c263873339d80959a845c25e444b482633f 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -8,7 +8,8 @@
**近期更新**
-- 2021.04.15 添加`MixNet`和`ReXNet`系列模型,在ImageNet-1k上`MixNet_L` 模型Top1 Acc可达78.6%,`ReXNet_3_0`模型可达82.09%
+- 2021.05.14 添加`SwinTransformer`系列模型,在ImageNet-1k上,Top1 Acc可达87.19%
+- 2021.04.15 添加`MixNet`和`ReXNet`系列模型,在ImageNet-1k上`MixNet_L`模型Top1 Acc可达78.6%,`ReXNet_3_0`模型可达82.09%
- 2021.03.02 添加分类模型量化方法与使用教程。
- 2021.02.01 添加`RepVGG`系列模型,在ImageNet-1k上Top-1 Acc可达79.65%。
- 2021.01.27 添加`ViT`与`DeiT`模型,在ImageNet-1k上,`ViT`模型Top-1 Acc可达85.13%,`DeiT`模型可达85.1%。
@@ -65,10 +66,11 @@
- [Inception系列](#Inception系列)
- [EfficientNet与ResNeXt101_wsl系列](#EfficientNet与ResNeXt101_wsl系列)
- [ResNeSt与RegNet系列](#ResNeSt与RegNet系列)
- - [Transformer系列](#Transformer系列)
+ - [ViT与DeiT系列](#ViT_and_DeiT系列)
- [RepVGG系列](#RepVGG系列)
- [MixNet系列](#MixNet系列)
- [ReXNet系列](#ReXNet系列)
+ - [SwinTransformer系列](#SwinTransformer系列)
- [其他模型](#其他模型)
- HS-ResNet: arxiv文章链接: [https://arxiv.org/pdf/2010.07621.pdf](https://arxiv.org/pdf/2010.07621.pdf)。 代码和预训练模型即将开源,敬请期待。
- 模型训练/评估
@@ -358,10 +360,10 @@ ResNeSt与RegNet系列模型的精度、速度指标如下表所示,更多关
| RegNetX_4GF | 0.785 | 0.9416 | 6.46478 | 11.19862 | 8 | 22.1 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RegNetX_4GF_pretrained.pdparams) |
-
-### Transformer系列
+
+### ViT_and_DeiT系列
-ViT(Vision Transformer)与DeiT(Data-efficient Image Transformers)系列模型的精度、速度指标如下表所示. 更多关于该系列模型的介绍可以参考: [Transformer系列模型文档](./docs/zh_CN/models/Transformer.md)。
+ViT(Vision Transformer)与DeiT(Data-efficient Image Transformers)系列模型的精度、速度指标如下表所示。更多关于该系列模型的介绍可以参考:[ViT_and_DeiT系列模型文档](./docs/zh_CN/models/ViT_and_DeiT.md)。
| 模型 | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | Flops(G) | Params(M) | 下载地址 |
@@ -434,6 +436,25 @@ ViT(Vision Transformer)与DeiT(Data-efficient Image Transformers)系列
| ReXNet_2_0 | 0.8122 | 0.9536 | | | 1.561 | 16.449 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_2_0_pretrained.pdparams) |
| ReXNet_3_0 | 0.8209 | 0.9612 | | | 3.445 | 34.833 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_3_0_pretrained.pdparams) |
+
+
+### SwinTransformer系列
+
+关于SwinTransformer系列模型的精度、速度指标如下表所示,更多介绍可以参考:[SwinTransformer系列模型文档](./docs/zh_CN/models/SwinTransformer.md)。
+
+| 模型 | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | Flops(G) | Params(M) | 下载地址 |
+| ---------- | --------- | --------- | ---------------- | ---------------- | -------- | --------- | ------------------------------------------------------------ |
+| SwinTransformer_tiny_patch4_window7_224 | 0.8069 | 0.9534 | | | 4.5 | 28 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_tiny_patch4_window7_224_pretrained.pdparams) |
+| SwinTransformer_small_patch4_window7_224 | 0.8275 | 0.9613 | | | 8.7 | 50 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_small_patch4_window7_224_pretrained.pdparams) |
+| SwinTransformer_base_patch4_window7_224 | 0.8300 | 0.9626 | | | 15.4 | 88 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window7_224_pretrained.pdparams) |
+| SwinTransformer_base_patch4_window12_384 | 0.8439 | 0.9693 | | | 47.1 | 88 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window12_384_pretrained.pdparams) |
+| SwinTransformer_base_patch4_window7_224[1] | 0.8487 | 0.9746 | | | 15.4 | 88 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window7_224_22kto1k_pretrained.pdparams) |
+| SwinTransformer_base_patch4_window12_384[1] | 0.8642 | 0.9807 | | | 47.1 | 88 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window12_384_22kto1k_pretrained.pdparams) |
+| SwinTransformer_large_patch4_window7_224[1] | 0.8596 | 0.9783 | | | 34.5 | 197 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_large_patch4_window7_224_22kto1k_pretrained.pdparams) |
+| SwinTransformer_large_patch4_window12_384[1] | 0.8719 | 0.9823 | | | 103.9 | 197 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_large_patch4_window12_384_22kto1k_pretrained.pdparams) |
+
+[1]:基于ImageNet22k数据集预训练,然后在ImageNet1k数据集迁移学习得到。
+
### 其他模型
diff --git a/configs/SwinTransformer/SwinTransformer_large_patch4_window12_384.yaml b/configs/SwinTransformer/SwinTransformer_large_patch4_window12_384.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4d05754d142ee6268b6b5b632eb3c2bf7cfd1eeb
--- /dev/null
+++ b/configs/SwinTransformer/SwinTransformer_large_patch4_window12_384.yaml
@@ -0,0 +1,72 @@
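+# Training config for SwinTransformer_large_patch4_window12_384 on ImageNet-1k: 1000 classes, 1,281,167 training images, 384x384 input.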
+mode: 'train'
+ARCHITECTURE:
+ name: 'SwinTransformer_large_patch4_window12_384'
+
+pretrained_model: ""
+model_save_dir: "./output/"
+classes_num: 1000
+total_images: 1281167
+save_interval: 1
+validate: True
+valid_interval: 1
+epochs: 120
+topk: 5
+image_shape: [3, 384, 384]
+
+use_mix: False
+ls_epsilon: -1
+
+LEARNING_RATE:
+ function: 'Piecewise'
+ params:
+ lr: 0.1
+ decay_epochs: [30, 60, 90]
+ gamma: 0.1
+
+OPTIMIZER:
+ function: 'Momentum'
+ params:
+ momentum: 0.9
+ regularizer:
+ function: 'L2'
+ factor: 0.000100
+
+TRAIN:
+ batch_size: 256
+ num_workers: 4
+ file_list: "./dataset/ILSVRC2012/train_list.txt"
+ data_dir: "./dataset/ILSVRC2012/"
+ shuffle_seed: 0
+ transforms:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - RandCropImage:
+            size: 384
+ - RandFlipImage:
+ flip_code: 1
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ - ToCHWImage:
+
+VALID:
+ batch_size: 128
+ num_workers: 4
+ file_list: "./dataset/ILSVRC2012/val_list.txt"
+ data_dir: "./dataset/ILSVRC2012/"
+ shuffle_seed: 0
+ transforms:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - ResizeImage:
+ size: [384, 384]
+ - NormalizeImage:
+ scale: 1.0/255.0
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ - ToCHWImage:
diff --git a/configs/SwinTransformer/SwinTransformer_large_patch4_window7_224.yaml b/configs/SwinTransformer/SwinTransformer_large_patch4_window7_224.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f2481da1121611f1c278115fbe28b18e41a7286c
--- /dev/null
+++ b/configs/SwinTransformer/SwinTransformer_large_patch4_window7_224.yaml
@@ -0,0 +1,74 @@
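+# Training config for SwinTransformer_large_patch4_window7_224 on ImageNet-1k: 1000 classes, 1,281,167 training images, 224x224 input.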
+mode: 'train'
+ARCHITECTURE:
+ name: 'SwinTransformer_large_patch4_window7_224'
+
+pretrained_model: ""
+model_save_dir: "./output/"
+classes_num: 1000
+total_images: 1281167
+save_interval: 1
+validate: True
+valid_interval: 1
+epochs: 120
+topk: 5
+image_shape: [3, 224, 224]
+
+use_mix: False
+ls_epsilon: -1
+
+LEARNING_RATE:
+ function: 'Piecewise'
+ params:
+ lr: 0.1
+ decay_epochs: [30, 60, 90]
+ gamma: 0.1
+
+OPTIMIZER:
+ function: 'Momentum'
+ params:
+ momentum: 0.9
+ regularizer:
+ function: 'L2'
+ factor: 0.000100
+
+TRAIN:
+ batch_size: 256
+ num_workers: 4
+ file_list: "./dataset/ILSVRC2012/train_list.txt"
+ data_dir: "./dataset/ILSVRC2012/"
+ shuffle_seed: 0
+ transforms:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - RandCropImage:
+ size: 224
+ - RandFlipImage:
+ flip_code: 1
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ - ToCHWImage:
+
+VALID:
+ batch_size: 64
+ num_workers: 4
+ file_list: "./dataset/ILSVRC2012/val_list.txt"
+ data_dir: "./dataset/ILSVRC2012/"
+ shuffle_seed: 0
+ transforms:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - ResizeImage:
+ resize_short: 256
+ - CropImage:
+ size: 224
+ - NormalizeImage:
+ scale: 1.0/255.0
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ - ToCHWImage:
diff --git a/docs/en/models/SwinTransformer_en.md b/docs/en/models/SwinTransformer_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..11d45d6c401c57f31d586b6c740d968b304c3574
--- /dev/null
+++ b/docs/en/models/SwinTransformer_en.md
@@ -0,0 +1,22 @@
+# SwinTransformer
+
+## Overview
+Swin Transformer is a new vision Transformer that can serve as a general-purpose backbone for computer vision. It is a hierarchical Transformer whose representation is computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping local windows while still allowing cross-window connections. [Paper](https://arxiv.org/abs/2103.14030).
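+
+As a quick sanity check, the snippet below builds one of these backbones directly from the PaddleClas architecture factories and runs a forward pass on a random image tensor. It is a minimal sketch, assuming the default constructor arguments (1000 ImageNet-1k classes) and that the corresponding `*_pretrained.pdparams` file from the table below has already been downloaded.
+
+```python
+import paddle
+
+from ppcls.modeling.architectures import SwinTransformer_tiny_patch4_window7_224
+
+# Build the tiny 224x224 variant with default arguments (assumed: class_dim=1000).
+model = SwinTransformer_tiny_patch4_window7_224()
+
+# Optionally load the pretrained weights linked in the table below
+# (assumed to be downloaded to the current directory).
+state_dict = paddle.load("SwinTransformer_tiny_patch4_window7_224_pretrained.pdparams")
+model.set_state_dict(state_dict)
+
+model.eval()
+logits = model(paddle.randn([1, 3, 224, 224]))
+print(logits.shape)  # expected: [1, 1000]
+```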
+
+
+## Accuracy, FLOPS and Parameters
+
+| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Params<br>(M) |
+|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
+| SwinTransformer_tiny_patch4_window7_224 | 0.8069 | 0.9534 | 0.812 | 0.955 | 4.5 | 28 |
+| SwinTransformer_small_patch4_window7_224 | 0.8275 | 0.9613 | 0.832 | 0.962 | 8.7 | 50 |
+| SwinTransformer_base_patch4_window7_224 | 0.8300 | 0.9626 | 0.835 | 0.965 | 15.4 | 88 |
+| SwinTransformer_base_patch4_window12_384 | 0.8439 | 0.9693 | 0.845 | 0.970 | 47.1 | 88 |
+| SwinTransformer_base_patch4_window7_224[1] | 0.8487 | 0.9746 | 0.852 | 0.975 | 15.4 | 88 |
+| SwinTransformer_base_patch4_window12_384[1] | 0.8642 | 0.9807 | 0.864 | 0.980 | 47.1 | 88 |
+| SwinTransformer_large_patch4_window7_224[1] | 0.8596 | 0.9783 | 0.863 | 0.979 | 34.5 | 197 |
+| SwinTransformer_large_patch4_window12_384[1] | 0.8719 | 0.9823 | 0.873 | 0.982 | 103.9 | 197 |
+
+[1]: Pre-trained on the ImageNet-22k dataset and then fine-tuned on the ImageNet-1k dataset.
+
+**Note**: The gap between the accuracy above and the reference accuracy comes from differences in data preprocessing.
diff --git a/docs/en/models/Transformer_en.md b/docs/en/models/ViT_and_DeiT_en.md
similarity index 83%
rename from docs/en/models/Transformer_en.md
rename to docs/en/models/ViT_and_DeiT_en.md
index 13d00fca7e9414a88009e0e324993bbcf1c9d908..ac275d9b4a5e9c653d0bd30c1a322505440f441c 100644
--- a/docs/en/models/Transformer_en.md
+++ b/docs/en/models/ViT_and_DeiT_en.md
@@ -9,27 +9,27 @@ DeiT(Data-efficient Image Transformers) series models were proposed by Facebook
## Accuracy, FLOPS and Parameters
-| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) |
-|:--:|:--:|:--:|:--:|:--:|:--:|
-| ViT_small_patch16_224 | 0.7769 | 0.9342 | 0.7785 | 0.9342 | |
-| ViT_base_patch16_224 | 0.8195 | 0.9617 | 0.8178 | 0.9613 | |
-| ViT_base_patch16_384 | 0.8414 | 0.9717 | 0.8420 | 0.9722 | |
-| ViT_base_patch32_384 | 0.8176 | 0.9613 | 0.8166 | 0.9613 | |
-| ViT_large_patch16_224 | 0.8323 | 0.9650 | 0.8306 | 0.9644 | |
-| ViT_large_patch16_384 | 0.8513 | 0.9736 | 0.8517 | 0.9736 | |
-| ViT_large_patch32_384 | 0.8153 | 0.9608 | 0.815 | - | |
-
-
-| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) |
-|:--:|:--:|:--:|:--:|:--:|:--:|
-| DeiT_tiny_patch16_224 | 0.718 | 0.910 | 0.722 | 0.911 | |
-| DeiT_small_patch16_224 | 0.796 | 0.949 | 0.799 | 0.950 | |
-| DeiT_base_patch16_224 | 0.817 | 0.957 | 0.818 | 0.956 | |
-| DeiT_base_patch16_384 | 0.830 | 0.962 | 0.829 | 0.972 | |
-| DeiT_tiny_distilled_patch16_224 | 0.741 | 0.918 | 0.745 | 0.919 | |
-| DeiT_small_distilled_patch16_224 | 0.809 | 0.953 | 0.812 | 0.954 | |
-| DeiT_base_distilled_patch16_224 | 0.831 | 0.964 | 0.834 | 0.965 | |
-| DeiT_base_distilled_patch16_384 | 0.851 | 0.973 | 0.852 | 0.972 | |
+| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Params<br>(M) |
+|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
+| ViT_small_patch16_224 | 0.7769 | 0.9342 | 0.7785 | 0.9342 | | |
+| ViT_base_patch16_224 | 0.8195 | 0.9617 | 0.8178 | 0.9613 | | |
+| ViT_base_patch16_384 | 0.8414 | 0.9717 | 0.8420 | 0.9722 | | |
+| ViT_base_patch32_384 | 0.8176 | 0.9613 | 0.8166 | 0.9613 | | |
+| ViT_large_patch16_224 | 0.8323 | 0.9650 | 0.8306 | 0.9644 | | |
+| ViT_large_patch16_384 | 0.8513 | 0.9736 | 0.8517 | 0.9736 | | |
+| ViT_large_patch32_384 | 0.8153 | 0.9608 | 0.815 | - | | |
+
+
+| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Params<br>(M) |
+|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
+| DeiT_tiny_patch16_224 | 0.718 | 0.910 | 0.722 | 0.911 | | |
+| DeiT_small_patch16_224 | 0.796 | 0.949 | 0.799 | 0.950 | | |
+| DeiT_base_patch16_224 | 0.817 | 0.957 | 0.818 | 0.956 | | |
+| DeiT_base_patch16_384 | 0.830 | 0.962 | 0.829 | 0.972 | | |
+| DeiT_tiny_distilled_patch16_224 | 0.741 | 0.918 | 0.745 | 0.919 | | |
+| DeiT_small_distilled_patch16_224 | 0.809 | 0.953 | 0.812 | 0.954 | | |
+| DeiT_base_distilled_patch16_224 | 0.831 | 0.964 | 0.834 | 0.965 | | |
+| DeiT_base_distilled_patch16_384 | 0.851 | 0.973 | 0.852 | 0.972 | | |
Params, FLOPs, Inference speed and other information are coming soon.
diff --git a/docs/zh_CN/models/SwinTransformer.md b/docs/zh_CN/models/SwinTransformer.md
new file mode 100644
index 0000000000000000000000000000000000000000..1d9561ba8e5ad95b75470e4d9af305b519780368
--- /dev/null
+++ b/docs/zh_CN/models/SwinTransformer.md
@@ -0,0 +1,22 @@
+# SwinTransformer
+
+## 概述
+Swin Transformer 是一种新的视觉Transformer网络,可以用作计算机视觉领域的通用骨干网络。SwinTransformer由移动窗口(shifted windows)表示的层次Transformer结构组成。移动窗口将自注意力计算限制在非重叠的局部窗口上,同时允许跨窗口连接,从而提高了网络性能。[论文地址](https://arxiv.org/abs/2103.14030)。
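+
+下面给出一个加载预训练权重的最小示例(仅为示意,假设使用默认构造参数、类别数为1000,且已从下表中的链接下载好对应的 `.pdparams` 权重文件):
+
+```python
+import paddle
+
+from ppcls.modeling.architectures import SwinTransformer_large_patch4_window12_384
+
+# 构建 Swin-L 384 模型(假设:默认构造参数,class_dim=1000)
+model = SwinTransformer_large_patch4_window12_384()
+# 加载预训练权重(假设权重文件已下载到当前目录)
+model.set_state_dict(paddle.load("SwinTransformer_large_patch4_window12_384_22kto1k_pretrained.pdparams"))
+model.eval()
+```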
+
+
+## 精度、FLOPS和参数量
+
+| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Params<br>(M) |
+|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
+| SwinTransformer_tiny_patch4_window7_224 | 0.8069 | 0.9534 | 0.812 | 0.955 | 4.5 | 28 |
+| SwinTransformer_small_patch4_window7_224 | 0.8275 | 0.9613 | 0.832 | 0.962 | 8.7 | 50 |
+| SwinTransformer_base_patch4_window7_224 | 0.8300 | 0.9626 | 0.835 | 0.965 | 15.4 | 88 |
+| SwinTransformer_base_patch4_window12_384 | 0.8439 | 0.9693 | 0.845 | 0.970 | 47.1 | 88 |
+| SwinTransformer_base_patch4_window7_224[1] | 0.8487 | 0.9746 | 0.852 | 0.975 | 15.4 | 88 |
+| SwinTransformer_base_patch4_window12_384[1] | 0.8642 | 0.9807 | 0.864 | 0.980 | 47.1 | 88 |
+| SwinTransformer_large_patch4_window7_224[1] | 0.8596 | 0.9783 | 0.863 | 0.979 | 34.5 | 197 |
+| SwinTransformer_large_patch4_window12_384[1] | 0.8719 | 0.9823 | 0.873 | 0.982 | 103.9 | 197 |
+
+[1]:基于ImageNet22k数据集预训练,然后在ImageNet1k数据集迁移学习得到。
+
+**注**:与Reference的精度差异源于数据预处理不同。
diff --git a/docs/zh_CN/models/Transformer.md b/docs/zh_CN/models/ViT_and_DeiT.md
similarity index 83%
rename from docs/zh_CN/models/Transformer.md
rename to docs/zh_CN/models/ViT_and_DeiT.md
index de037177acf3987cfb60c60e68a64ef15b82ec15..d14491d2a2b44d16d5fba4c316afb757543f21ac 100644
--- a/docs/zh_CN/models/Transformer.md
+++ b/docs/zh_CN/models/ViT_and_DeiT.md
@@ -11,26 +11,26 @@ DeiT(Data-efficient Image Transformers)系列模型是由FaceBook在2020年
## 精度、FLOPS和参数量
-| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) |
-|:--:|:--:|:--:|:--:|:--:|:--:|
-| ViT_small_patch16_224 | 0.7769 | 0.9342 | 0.7785 | 0.9342 | |
-| ViT_base_patch16_224 | 0.8195 | 0.9617 | 0.8178 | 0.9613 | |
-| ViT_base_patch16_384 | 0.8414 | 0.9717 | 0.8420 | 0.9722 | |
-| ViT_base_patch32_384 | 0.8176 | 0.9613 | 0.8166 | 0.9613 | |
-| ViT_large_patch16_224 | 0.8323 | 0.9650 | 0.8306 | 0.9644 | |
-| ViT_large_patch16_384 | 0.8513 | 0.9736 | 0.8517 | 0.9736 | |
-| ViT_large_patch32_384 | 0.8153 | 0.9608 | 0.815 | - | |
-
-
-| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) |
-|:--:|:--:|:--:|:--:|:--:|:--:|
-| DeiT_tiny_patch16_224 | 0.718 | 0.910 | 0.722 | 0.911 | |
-| DeiT_small_patch16_224 | 0.796 | 0.949 | 0.799 | 0.950 | |
-| DeiT_base_patch16_224 | 0.817 | 0.957 | 0.818 | 0.956 | |
-| DeiT_base_patch16_384 | 0.830 | 0.962 | 0.829 | 0.972 | |
-| DeiT_tiny_distilled_patch16_224 | 0.741 | 0.918 | 0.745 | 0.919 | |
-| DeiT_small_distilled_patch16_224 | 0.809 | 0.953 | 0.812 | 0.954 | |
-| DeiT_base_distilled_patch16_224 | 0.831 | 0.964 | 0.834 | 0.965 | |
-| DeiT_base_distilled_patch16_384 | 0.851 | 0.973 | 0.852 | 0.972 | |
+| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Params<br>(M) |
+|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
+| ViT_small_patch16_224 | 0.7769 | 0.9342 | 0.7785 | 0.9342 | | |
+| ViT_base_patch16_224 | 0.8195 | 0.9617 | 0.8178 | 0.9613 | | |
+| ViT_base_patch16_384 | 0.8414 | 0.9717 | 0.8420 | 0.9722 | | |
+| ViT_base_patch32_384 | 0.8176 | 0.9613 | 0.8166 | 0.9613 | | |
+| ViT_large_patch16_224 | 0.8323 | 0.9650 | 0.8306 | 0.9644 | | |
+| ViT_large_patch16_384 | 0.8513 | 0.9736 | 0.8517 | 0.9736 | | |
+| ViT_large_patch32_384 | 0.8153 | 0.9608 | 0.815 | - | | |
+
+
+| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Params<br>(M) |
+|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
+| DeiT_tiny_patch16_224 | 0.718 | 0.910 | 0.722 | 0.911 | | |
+| DeiT_small_patch16_224 | 0.796 | 0.949 | 0.799 | 0.950 | | |
+| DeiT_base_patch16_224 | 0.817 | 0.957 | 0.818 | 0.956 | | |
+| DeiT_base_patch16_384 | 0.830 | 0.962 | 0.829 | 0.972 | | |
+| DeiT_tiny_distilled_patch16_224 | 0.741 | 0.918 | 0.745 | 0.919 | | |
+| DeiT_small_distilled_patch16_224 | 0.809 | 0.953 | 0.812 | 0.954 | | |
+| DeiT_base_distilled_patch16_224 | 0.831 | 0.964 | 0.834 | 0.965 | | |
+| DeiT_base_distilled_patch16_384 | 0.851 | 0.973 | 0.852 | 0.972 | | |
关于Params、FLOPs、Inference speed等信息,敬请期待。
diff --git a/paddleclas.py b/paddleclas.py
index c4314ec6db2b0b56f84f00c48f0191dac635a15c..d8fcf9368473f11b5379a80a7d20bc4290677c0f 100644
--- a/paddleclas.py
+++ b/paddleclas.py
@@ -32,7 +32,10 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ''))
import argparse
import shutil
+import textwrap
+from difflib import SequenceMatcher
+from prettytable import PrettyTable
import cv2
import numpy as np
import tarfile
@@ -45,57 +48,148 @@ __all__ = ['PaddleClas']
BASE_DIR = os.path.expanduser("~/.paddleclas/")
BASE_INFERENCE_MODEL_DIR = os.path.join(BASE_DIR, 'inference_model')
BASE_IMAGES_DIR = os.path.join(BASE_DIR, 'images')
-
-model_names = {
- 'Xception71', 'SE_ResNeXt101_32x4d', 'ShuffleNetV2_x0_5', 'ResNet34',
- 'ShuffleNetV2_x2_0', 'ResNeXt101_32x4d', 'HRNet_W48_C_ssld',
- 'ResNeSt50_fast_1s1x64d', 'MobileNetV2_x2_0', 'MobileNetV3_large_x1_0',
- 'Fix_ResNeXt101_32x48d_wsl', 'MobileNetV2_ssld', 'ResNeXt101_vd_64x4d',
- 'ResNet34_vd_ssld', 'MobileNetV3_small_x1_0', 'VGG11',
- 'ResNeXt50_vd_32x4d', 'MobileNetV3_large_x1_25',
- 'MobileNetV3_large_x1_0_ssld', 'MobileNetV2_x0_75',
- 'MobileNetV3_small_x0_35', 'MobileNetV1_x0_75', 'MobileNetV1_ssld',
- 'ResNeXt50_32x4d', 'GhostNet_x1_3_ssld', 'Res2Net101_vd_26w_4s',
- 'ResNet152', 'Xception65', 'EfficientNetB0', 'ResNet152_vd', 'HRNet_W18_C',
- 'Res2Net50_14w_8s', 'ShuffleNetV2_x0_25', 'HRNet_W64_C',
- 'Res2Net50_vd_26w_4s_ssld', 'HRNet_W18_C_ssld', 'ResNet18_vd',
- 'ResNeXt101_32x16d_wsl', 'SE_ResNeXt50_32x4d', 'SqueezeNet1_1',
- 'SENet154_vd', 'SqueezeNet1_0', 'GhostNet_x1_0', 'ResNet50_vc', 'DPN98',
- 'HRNet_W48_C', 'DenseNet264', 'SE_ResNet34_vd', 'HRNet_W44_C',
- 'MobileNetV3_small_x1_25', 'MobileNetV1_x0_5', 'ResNet200_vd', 'VGG13',
- 'EfficientNetB3', 'EfficientNetB2', 'ShuffleNetV2_x0_33',
- 'MobileNetV3_small_x0_75', 'ResNeXt152_vd_32x4d', 'ResNeXt101_32x32d_wsl',
- 'ResNet18', 'MobileNetV3_large_x0_35', 'Res2Net50_26w_4s',
- 'MobileNetV2_x0_5', 'EfficientNetB0_small', 'ResNet101_vd_ssld',
- 'EfficientNetB6', 'EfficientNetB1', 'EfficientNetB7', 'ResNeSt50',
- 'ShuffleNetV2_x1_0', 'MobileNetV3_small_x1_0_ssld', 'InceptionV4',
- 'GhostNet_x0_5', 'SE_HRNet_W64_C_ssld', 'ResNet50_ACNet_deploy',
- 'Xception41', 'ResNet50', 'Res2Net200_vd_26w_4s_ssld',
- 'Xception41_deeplab', 'SE_ResNet18_vd', 'SE_ResNeXt50_vd_32x4d',
- 'HRNet_W30_C', 'HRNet_W40_C', 'VGG19', 'Res2Net200_vd_26w_4s',
- 'ResNeXt101_32x8d_wsl', 'ResNet50_vd', 'ResNeXt152_64x4d', 'DarkNet53',
- 'ResNet50_vd_ssld', 'ResNeXt101_64x4d', 'MobileNetV1_x0_25',
- 'Xception65_deeplab', 'AlexNet', 'ResNet101', 'DenseNet121',
- 'ResNet50_vd_v2', 'Res2Net50_vd_26w_4s', 'ResNeXt101_32x48d_wsl',
- 'MobileNetV3_large_x0_5', 'MobileNetV2_x0_25', 'DPN92', 'ResNet101_vd',
- 'MobileNetV2_x1_5', 'DPN131', 'ResNeXt50_vd_64x4d', 'ShuffleNetV2_x1_5',
- 'ResNet34_vd', 'MobileNetV1', 'ResNeXt152_vd_64x4d', 'DPN107', 'VGG16',
- 'ResNeXt50_64x4d', 'RegNetX_4GF', 'DenseNet161', 'GhostNet_x1_3',
- 'HRNet_W32_C', 'Fix_ResNet50_vd_ssld_v2', 'Res2Net101_vd_26w_4s_ssld',
- 'DenseNet201', 'DPN68', 'EfficientNetB4', 'ResNeXt152_32x4d',
- 'InceptionV3', 'ShuffleNetV2_swish', 'GoogLeNet', 'ResNet50_vd_ssld_v2',
- 'SE_ResNet50_vd', 'MobileNetV2', 'ResNeXt101_vd_32x4d',
- 'MobileNetV3_large_x0_75', 'MobileNetV3_small_x0_5', 'DenseNet169',
- 'EfficientNetB5', 'DeiT_base_distilled_patch16_224',
- 'DeiT_base_distilled_patch16_384', 'DeiT_base_patch16_224',
- 'DeiT_base_patch16_384', 'DeiT_small_distilled_patch16_224',
- 'DeiT_small_patch16_224', 'DeiT_tiny_distilled_patch16_224',
- 'DeiT_tiny_patch16_224', 'ViT_base_patch16_224', 'ViT_base_patch16_384',
- 'ViT_base_patch32_384', 'ViT_large_patch16_224', 'ViT_large_patch16_384',
- 'ViT_large_patch32_384', 'ViT_small_patch16_224'
+model_series = {
+ "AlexNet": ["AlexNet"],
+ "DarkNet": ["DarkNet53"],
+ "DeiT": [
+ 'DeiT_base_distilled_patch16_224', 'DeiT_base_distilled_patch16_384',
+ 'DeiT_base_patch16_224', 'DeiT_base_patch16_384',
+ 'DeiT_small_distilled_patch16_224', 'DeiT_small_patch16_224',
+ 'DeiT_tiny_distilled_patch16_224', 'DeiT_tiny_patch16_224'
+ ],
+ "DenseNet": [
+ "DenseNet121", "DenseNet161", "DenseNet169", "DenseNet201",
+ "DenseNet264"
+ ],
+ "DPN": ["DPN68", "DPN92", "DPN98", "DPN107", "DPN131"],
+ "EfficientNet": [
+ "EfficientNetB0", "EfficientNetB0_small", "EfficientNetB1",
+ "EfficientNetB2", "EfficientNetB3", "EfficientNetB4", "EfficientNetB5",
+ "EfficientNetB6", "EfficientNetB7"
+ ],
+ "GhostNet":
+ ["GhostNet_x0_5", "GhostNet_x1_0", "GhostNet_x1_3", "GhostNet_x1_3_ssld"],
+ "HRNet": [
+ "HRNet_W18_C", "HRNet_W30_C", "HRNet_W32_C", "HRNet_W40_C",
+ "HRNet_W44_C", "HRNet_W48_C", "HRNet_W64_C", "HRNet_W18_C_ssld",
+ "HRNet_W48_C_ssld"
+ ],
+ "Inception": ["GoogLeNet", "InceptionV3", "InceptionV4"],
+ "MobileNetV1": [
+ "MobileNetV1_x0_25", "MobileNetV1_x0_5", "MobileNetV1_x0_75",
+ "MobileNetV1", "MobileNetV1_ssld"
+ ],
+ "MobileNetV2": [
+ "MobileNetV2_x0_25", "MobileNetV2_x0_5", "MobileNetV2_x0_75",
+ "MobileNetV2", "MobileNetV2_x1_5", "MobileNetV2_x2_0",
+ "MobileNetV2_ssld"
+ ],
+ "MobileNetV3": [
+ "MobileNetV3_small_x0_35", "MobileNetV3_small_x0_5",
+ "MobileNetV3_small_x0_75", "MobileNetV3_small_x1_0",
+ "MobileNetV3_small_x1_25", "MobileNetV3_large_x0_35",
+ "MobileNetV3_large_x0_5", "MobileNetV3_large_x0_75",
+ "MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25",
+ "MobileNetV3_small_x1_0_ssld", "MobileNetV3_large_x1_0_ssld"
+ ],
+ "RegNet": ["RegNetX_4GF"],
+ "Res2Net": [
+ "Res2Net50_14w_8s", "Res2Net50_26w_4s", "Res2Net50_vd_26w_4s",
+ "Res2Net200_vd_26w_4s", "Res2Net101_vd_26w_4s",
+ "Res2Net50_vd_26w_4s_ssld", "Res2Net101_vd_26w_4s_ssld",
+ "Res2Net200_vd_26w_4s_ssld"
+ ],
+ "ResNeSt": ["ResNeSt50", "ResNeSt50_fast_1s1x64d"],
+ "ResNet": [
+ "ResNet18", "ResNet18_vd", "ResNet34", "ResNet34_vd", "ResNet50",
+ "ResNet50_vc", "ResNet50_vd", "ResNet50_vd_v2", "ResNet101",
+ "ResNet101_vd", "ResNet152", "ResNet152_vd", "ResNet200_vd",
+ "ResNet34_vd_ssld", "ResNet50_vd_ssld", "ResNet50_vd_ssld_v2",
+ "ResNet101_vd_ssld", "Fix_ResNet50_vd_ssld_v2", "ResNet50_ACNet_deploy"
+ ],
+ "ResNeXt": [
+ "ResNeXt50_32x4d", "ResNeXt50_vd_32x4d", "ResNeXt50_64x4d",
+ "ResNeXt50_vd_64x4d", "ResNeXt101_32x4d", "ResNeXt101_vd_32x4d",
+ "ResNeXt101_32x8d_wsl", "ResNeXt101_32x16d_wsl",
+ "ResNeXt101_32x32d_wsl", "ResNeXt101_32x48d_wsl",
+ "Fix_ResNeXt101_32x48d_wsl", "ResNeXt101_64x4d", "ResNeXt101_vd_64x4d",
+ "ResNeXt152_32x4d", "ResNeXt152_vd_32x4d", "ResNeXt152_64x4d",
+ "ResNeXt152_vd_64x4d"
+ ],
+ "SENet": [
+ "SENet154_vd", "SE_HRNet_W64_C_ssld", "SE_ResNet18_vd",
+ "SE_ResNet34_vd", "SE_ResNet50_vd", "SE_ResNeXt50_32x4d",
+ "SE_ResNeXt50_vd_32x4d", "SE_ResNeXt101_32x4d"
+ ],
+ "ShuffleNetV2": [
+ "ShuffleNetV2_swish", "ShuffleNetV2_x0_25", "ShuffleNetV2_x0_33",
+ "ShuffleNetV2_x0_5", "ShuffleNetV2_x1_0", "ShuffleNetV2_x1_5",
+ "ShuffleNetV2_x2_0"
+ ],
+ "SqueezeNet": ["SqueezeNet1_0", "SqueezeNet1_1"],
+ "SwinTransformer": [
+ "SwinTransformer_large_patch4_window7_224_22kto1k",
+ "SwinTransformer_large_patch4_window12_384_22kto1k",
+ "SwinTransformer_base_patch4_window7_224_22kto1k",
+ "SwinTransformer_base_patch4_window12_384_22kto1k",
+ "SwinTransformer_base_patch4_window12_384",
+ "SwinTransformer_base_patch4_window7_224",
+ "SwinTransformer_small_patch4_window7_224",
+ "SwinTransformer_tiny_patch4_window7_224"
+ ],
+ "VGG": ["VGG11", "VGG13", "VGG16", "VGG19"],
+ "VisionTransformer": [
+ "ViT_base_patch16_224", "ViT_base_patch16_384", "ViT_base_patch32_384",
+ "ViT_large_patch16_224", "ViT_large_patch16_384",
+ "ViT_large_patch32_384", "ViT_small_patch16_224"
+ ],
+ "Xception": [
+ "Xception41", "Xception41_deeplab", "Xception65", "Xception65_deeplab",
+ "Xception71"
+ ]
}
+
+
+class ModelNameError(Exception):
+    """Raised when the requested model name is missing or not recognized."""
+
+ def __init__(self, message=''):
+ super().__init__(message)
+
+
+def print_info():
+ table = PrettyTable(['Series', 'Name'])
+ for series in model_series:
+ names = textwrap.fill(" ".join(model_series[series]), width=100)
+ table.add_row([series, names])
+ print('Inference models that Paddle provides are listed as follows:')
+ print(table)
+
+
+def get_model_names():
+ model_names = []
+ for series in model_series:
+ model_names += model_series[series]
+ return model_names
+
+
+def similar_architectures(name='', names=[], thresh=0.1, topk=10):
+    """
+    Return up to `topk` names from `names` whose similarity to `name`
+    (difflib quick_ratio) is above `thresh`, sorted from most to least similar.
+    """
+ scores = []
+ for idx, n in enumerate(names):
+ if n.startswith('__'):
+ continue
+ score = SequenceMatcher(None, n.lower(), name.lower()).quick_ratio()
+ if score > thresh:
+ scores.append((idx, score))
+ scores.sort(key=lambda x: x[1], reverse=True)
+ similar_names = [names[s[0]] for s in scores[:min(topk, len(scores))]]
+ return similar_names
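+
+# Illustrative example (assumption, not executed here): for a mistyped name such as
+# "SwinTranformer_tiny", similar_architectures("SwinTranformer_tiny", get_model_names())
+# returns the closest registered model names, which PaddleClas.__init__ below uses to
+# build the "Maybe you want: [...]" error message.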
+
+
def download_with_progressbar(url, save_path):
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
@@ -227,17 +321,26 @@ def parse_args(mMain=True, add_help=True):
class PaddleClas(object):
- print('Inference models that Paddle provides are listed as follows:\n\n{}'.
- format(model_names), '\n')
+ print_info()
def __init__(self, **kwargs):
+ model_names = get_model_names()
process_params = parse_args(mMain=False, add_help=False)
process_params.__dict__.update(**kwargs)
if not os.path.exists(process_params.model_file):
if process_params.model_name is None:
- raise Exception(
+ raise ModelNameError(
'Please input model name that you want to use!')
+
+ similar_names = similar_architectures(process_params.model_name,
+ model_names)
+ model_list = ', '.join(similar_names)
+ if process_params.model_name not in similar_names:
+                err = "{} does not exist! Maybe you want: [{}]" \
+ "".format(process_params.model_name, model_list)
+ raise ModelNameError(err)
+
if process_params.model_name in model_names:
url = 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/{}_infer.tar'.format(
process_params.model_name)
diff --git a/ppcls/modeling/architectures/__init__.py b/ppcls/modeling/architectures/__init__.py
index 9865e47fbaa8d077b440f393da1115571a0bebd7..a0869259725cc28e97ce6a5af47bd0dd3620dc77 100644
--- a/ppcls/modeling/architectures/__init__.py
+++ b/ppcls/modeling/architectures/__init__.py
@@ -47,6 +47,6 @@ from .vision_transformer import ViT_small_patch16_224, ViT_base_patch16_224, ViT
from .distilled_vision_transformer import DeiT_tiny_patch16_224, DeiT_small_patch16_224, DeiT_base_patch16_224, DeiT_tiny_distilled_patch16_224, DeiT_small_distilled_patch16_224, DeiT_base_distilled_patch16_224, DeiT_base_patch16_384, DeiT_base_distilled_patch16_384
from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0, ResNeXt101_32x16d_wsl_distill_ResNet50_vd
from .repvgg import RepVGG_A0, RepVGG_A1, RepVGG_A2, RepVGG_B0, RepVGG_B1, RepVGG_B2, RepVGG_B3, RepVGG_B1g2, RepVGG_B1g4, RepVGG_B2g2, RepVGG_B2g4, RepVGG_B3g2, RepVGG_B3g4
-from .swin_transformer import SwinTransformer_tiny_patch4_window7_224, SwinTransformer_small_patch4_window7_224, SwinTransformer_base_patch4_window7_224, SwinTransformer_base_patch4_window12_384
+from .swin_transformer import SwinTransformer_tiny_patch4_window7_224, SwinTransformer_small_patch4_window7_224, SwinTransformer_base_patch4_window7_224, SwinTransformer_base_patch4_window12_384, SwinTransformer_large_patch4_window7_224, SwinTransformer_large_patch4_window12_384
from .mixnet import MixNet_S, MixNet_M, MixNet_L
from .rexnet import ReXNet_1_0, ReXNet_1_3, ReXNet_1_5, ReXNet_2_0, ReXNet_3_0
diff --git a/ppcls/modeling/architectures/swin_transformer.py b/ppcls/modeling/architectures/swin_transformer.py
index 4b65ab5528c9eaa13b605902e37ca9d7fafa336b..97efbd1f341d69fd601bfb86697d4f1cb06ca022 100644
--- a/ppcls/modeling/architectures/swin_transformer.py
+++ b/ppcls/modeling/architectures/swin_transformer.py
@@ -759,3 +759,24 @@ def SwinTransformer_base_patch4_window12_384(**args):
drop_path_rate=0.5, # NOTE: do not appear in offical code
**args)
return model
+
+
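+# Swin-L variants (per the arguments below): embed_dim=192, depths=[2, 2, 18, 2],
+# num_heads=[6, 12, 24, 48]; window size 7 at 224x224 input and window size 12 at
+# 384x384 input.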
+def SwinTransformer_large_patch4_window7_224(**args):
+ model = SwinTransformer(
+ embed_dim=192,
+ depths=[2, 2, 18, 2],
+ num_heads=[6, 12, 24, 48],
+ window_size=7,
+ **args)
+ return model
+
+
+def SwinTransformer_large_patch4_window12_384(**args):
+ model = SwinTransformer(
+ img_size=384,
+ embed_dim=192,
+ depths=[2, 2, 18, 2],
+ num_heads=[6, 12, 24, 48],
+ window_size=12,
+ **args)
+ return model
diff --git a/requirements.txt b/requirements.txt
index 73d5e78bceb7b8712d2db6c01cbfd935c8857cb4..ec8806def5a58b45507076e820b0fad02025a11c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+prettytable
ujson
opencv-python==4.1.2.30
pillow