Unverified commit e5f75295, authored by Tingquan Gao, committed by GitHub

Add SwinTransformer (#721)

* Add SwinTransformer_large

* Add SwinTransformer

* Optimize user experience

* Add the doc of SwinTransformer

* Modify the format of the corner mark
Parent fb3e0233
......@@ -7,6 +7,8 @@
PaddleClas is a toolset for image classification tasks prepared for the industry and academia. It helps users train better computer vision models and apply them in real scenarios.
**Recent updates**
- 2021.05.14 Add `SwinTransformer` series pretrained models, whose Top-1 Acc on the ImageNet-1k dataset reaches 87.19%.
- 2021.04.15 Add `MixNet` and `ReXNet` pretrained models: `MixNet_L` reaches 78.6% Top-1 Acc on ImageNet-1k and `ReXNet_3_0` reaches 82.09%.
- 2021.03.02 Add support for model quantization.
- 2021.02.01 Add `RepVGG` pretrained models, whose Top-1 Acc on the ImageNet-1k dataset reaches 79.65%.
......@@ -63,10 +65,11 @@ PaddleClas is a toolset for image classification tasks prepared for the industry
- [Inception series](#Inception_series)
- [EfficientNet and ResNeXt101_wsl series](#EfficientNet_and_ResNeXt101_wsl_series)
- [ResNeSt and RegNet series](#ResNeSt_and_RegNet_series)
- [ViT and DeiT series](#ViT_and_DeiT)
- [RepVGG series](#RepVGG)
- [MixNet series](#MixNet)
- [ReXNet series](#ReXNet)
- [SwinTransformer series](#SwinTransformer)
- [Others](#Others)
- HS-ResNet: arxiv link: [https://arxiv.org/pdf/2010.07621.pdf](https://arxiv.org/pdf/2010.07621.pdf). Code and models are coming soon!
- Model training/evaluation
......@@ -353,10 +356,10 @@ Accuracy and inference time metrics of ResNeSt and RegNet series models are show
| RegNetX_4GF | 0.785 | 0.9416 | 6.46478 | 11.19862 | 8 | 22.1 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RegNetX_4GF_pretrained.pdparams) |
<a name="Transformer"></a>
### Transformer series
<a name="ViT_and_DeiT"></a>
### ViT and DeiT series
Accuracy and inference time metrics of ViT and DeiT series models are shown as follows. More detailed information can be found in the [ViT and DeiT series tutorial](./docs/en/models/ViT_and_DeiT_en.md).
| Model | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | Flops(G) | Params(M) | Download Address |
......@@ -430,6 +433,25 @@ Accuracy and inference time metrics of ReXNet series models are shown as follows
| ReXNet_2_0 | 0.8122 | 0.9536 | | | 1.561 | 16.449 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_2_0_pretrained.pdparams) |
| ReXNet_3_0 | 0.8209 | 0.9612 | | | 3.445 | 34.833 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_3_0_pretrained.pdparams) |
<a name="SwinTransformer"></a>
### SwinTransformer
Accuracy and inference time metrics of SwinTransformer series models are shown as follows. More detailed information can be found in the [SwinTransformer series tutorial](./docs/en/models/SwinTransformer_en.md).
| Model | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | Flops(G) | Params(M) | Download Address |
| ---------- | --------- | --------- | ---------------- | ---------------- | -------- | --------- | ------------------------------------------------------------ |
| SwinTransformer_tiny_patch4_window7_224 | 0.8069 | 0.9534 | | | 4.5 | 28 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_tiny_patch4_window7_224_pretrained.pdparams) |
| SwinTransformer_small_patch4_window7_224 | 0.8275 | 0.9613 | | | 8.7 | 50 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_small_patch4_window7_224_pretrained.pdparams) |
| SwinTransformer_base_patch4_window7_224 | 0.8300 | 0.9626 | | | 15.4 | 88 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window7_224_pretrained.pdparams) |
| SwinTransformer_base_patch4_window12_384 | 0.8439 | 0.9693 | | | 47.1 | 88 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window12_384_pretrained.pdparams) |
| SwinTransformer_base_patch4_window7_224<sup>[1]</sup> | 0.8487 | 0.9746 | | | 15.4 | 88 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window7_224_22kto1k_pretrained.pdparams) |
| SwinTransformer_base_patch4_window12_384<sup>[1]</sup> | 0.8642 | 0.9807 | | | 47.1 | 88 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window12_384_22kto1k_pretrained.pdparams) |
| SwinTransformer_large_patch4_window7_224<sup>[1]</sup> | 0.8596 | 0.9783 | | | 34.5 | 197 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_large_patch4_window7_224_22kto1k_pretrained.pdparams) |
| SwinTransformer_large_patch4_window12_384<sup>[1]</sup> | 0.8719 | 0.9823 | | | 103.9 | 197 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_large_patch4_window12_384_22kto1k_pretrained.pdparams) |
[1]: Pretrained on the ImageNet-22k dataset, then fine-tuned on the ImageNet-1k dataset.
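To load one of the pretrained weights files from the table into the corresponding dygraph model, a minimal sketch looks like the following. It assumes the PaddleClas repository root is on `PYTHONPATH`, that the `.pdparams` file has already been downloaded to the local path shown, and that the `ppcls.modeling.architectures` import path matches this repository's layout:

```python
import paddle
# Import path assumed from ppcls/modeling/architectures/__init__.py in this repo.
from ppcls.modeling.architectures import SwinTransformer_tiny_patch4_window7_224

model = SwinTransformer_tiny_patch4_window7_224()
# The local filename is illustrative; download the weights from the table above first.
state_dict = paddle.load("SwinTransformer_tiny_patch4_window7_224_pretrained.pdparams")
model.set_state_dict(state_dict)
model.eval()

x = paddle.randn([1, 3, 224, 224])
logits = model(x)  # expected shape: [1, 1000]
```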
<a name="Others"></a>
### Others
......
......@@ -8,7 +8,8 @@
**Recent updates**
- 2021.05.14 Add `SwinTransformer` series models; Top-1 Acc on ImageNet-1k reaches 87.19%.
- 2021.04.15 Add `MixNet` and `ReXNet` series models: on ImageNet-1k, `MixNet_L` reaches 78.6% Top-1 Acc and `ReXNet_3_0` reaches 82.09%.
- 2021.03.02 Add quantization methods and a tutorial for classification models.
- 2021.02.01 Add `RepVGG` series models; Top-1 Acc on ImageNet-1k reaches 79.65%.
- 2021.01.27 Add `ViT` and `DeiT` models; on ImageNet-1k, `ViT` reaches 85.13% Top-1 Acc and `DeiT` reaches 85.1%.
......@@ -65,10 +66,11 @@
- [Inception series](#Inception系列)
- [EfficientNet and ResNeXt101_wsl series](#EfficientNet与ResNeXt101_wsl系列)
- [ResNeSt and RegNet series](#ResNeSt与RegNet系列)
- [ViT and DeiT series](#ViT_and_DeiT系列)
- [RepVGG series](#RepVGG系列)
- [MixNet series](#MixNet系列)
- [ReXNet series](#ReXNet系列)
- [SwinTransformer series](#SwinTransformer系列)
- [Others](#其他模型)
- HS-ResNet: arXiv link: [https://arxiv.org/pdf/2010.07621.pdf](https://arxiv.org/pdf/2010.07621.pdf). Code and pretrained models are coming soon.
- Model training/evaluation
......@@ -358,10 +360,10 @@ Accuracy and speed metrics of ResNeSt and RegNet series models are shown in the table below. For more
| RegNetX_4GF | 0.785 | 0.9416 | 6.46478 | 11.19862 | 8 | 22.1 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RegNetX_4GF_pretrained.pdparams) |
<a name="Transformer系列"></a>
### Transformer系列
<a name="ViT_and_DeiT系列"></a>
### ViT_and_DeiT系列
ViT(Vision Transformer)与DeiT(Data-efficient Image Transformers)系列模型的精度、速度指标如下表所示. 更多关于该系列模型的介绍可以参考: [Transformer系列模型文档](./docs/zh_CN/models/Transformer.md)
ViT(Vision Transformer)与DeiT(Data-efficient Image Transformers)系列模型的精度、速度指标如下表所示. 更多关于该系列模型的介绍可以参考: [ViT_and_DeiT系列模型文档](./docs/zh_CN/models/ViT_and_DeiT.md)
| Model | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | Flops(G) | Params(M) | Download Address |
......@@ -434,6 +436,25 @@ Accuracy and speed metrics of the ViT (Vision Transformer) and DeiT (Data-efficient Image Transformers) series
| ReXNet_2_0 | 0.8122 | 0.9536 | | | 1.561 | 16.449 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_2_0_pretrained.pdparams) |
| ReXNet_3_0 | 0.8209 | 0.9612 | | | 3.445 | 34.833 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_3_0_pretrained.pdparams) |
<a name="SwinTransformer系列"></a>
### SwinTransformer系列
关于SwinTransformer系列模型的精度、速度指标如下表所示,更多介绍可以参考:[SwinTransformer系列模型文档](./docs/zh_CN/models/SwinTransformer.md)
| 模型 | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | Flops(G) | Params(M) | 下载地址 |
| ---------- | --------- | --------- | ---------------- | ---------------- | -------- | --------- | ------------------------------------------------------------ |
| SwinTransformer_tiny_patch4_window7_224 | 0.8069 | 0.9534 | | | 4.5 | 28 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_tiny_patch4_window7_224_pretrained.pdparams) |
| SwinTransformer_small_patch4_window7_224 | 0.8275 | 0.9613 | | | 8.7 | 50 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_small_patch4_window7_224_pretrained.pdparams) |
| SwinTransformer_base_patch4_window7_224 | 0.8300 | 0.9626 | | | 15.4 | 88 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window7_224_pretrained.pdparams) |
| SwinTransformer_base_patch4_window12_384 | 0.8439 | 0.9693 | | | 47.1 | 88 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window12_384_pretrained.pdparams) |
| SwinTransformer_base_patch4_window7_224<sup>[1]</sup> | 0.8487 | 0.9746 | | | 15.4 | 88 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window7_224_22kto1k_pretrained.pdparams) |
| SwinTransformer_base_patch4_window12_384<sup>[1]</sup> | 0.8642 | 0.9807 | | | 47.1 | 88 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window12_384_22kto1k_pretrained.pdparams) |
| SwinTransformer_large_patch4_window7_224<sup>[1]</sup> | 0.8596 | 0.9783 | | | 34.5 | 197 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_large_patch4_window7_224_22kto1k_pretrained.pdparams) |
| SwinTransformer_large_patch4_window12_384<sup>[1]</sup> | 0.8719 | 0.9823 | | | 103.9 | 197 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_large_patch4_window12_384_22kto1k_pretrained.pdparams) |
[1]:基于ImageNet22k数据集预训练,然后在ImageNet1k数据集迁移学习得到。
<a name="其他模型"></a>
### Others
......
mode: 'train'
ARCHITECTURE:
name: 'SwinTransformer_large_patch4_window12_384'
pretrained_model: ""
model_save_dir: "./output/"
classes_num: 1000
total_images: 1281167
save_interval: 1
validate: True
valid_interval: 1
epochs: 120
topk: 5
image_shape: [3, 384, 384]
use_mix: False
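# Note: ls_epsilon is the label-smoothing epsilon; a negative value is assumed to disable label smoothing (PaddleClas convention).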
ls_epsilon: -1
LEARNING_RATE:
function: 'Piecewise'
params:
lr: 0.1
decay_epochs: [30, 60, 90]
gamma: 0.1
OPTIMIZER:
function: 'Momentum'
params:
momentum: 0.9
regularizer:
function: 'L2'
factor: 0.000100
TRAIN:
batch_size: 256
num_workers: 4
file_list: "./dataset/ILSVRC2012/train_list.txt"
data_dir: "./dataset/ILSVRC2012/"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 384 # crop size must match the 384x384 input of SwinTransformer_large_patch4_window12_384
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
VALID:
batch_size: 128
num_workers: 4
file_list: "./dataset/ILSVRC2012/val_list.txt"
data_dir: "./dataset/ILSVRC2012/"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: [384, 384]
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
mode: 'train'
ARCHITECTURE:
name: 'SwinTransformer_large_patch4_window7_224'
pretrained_model: ""
model_save_dir: "./output/"
classes_num: 1000
total_images: 1281167
save_interval: 1
validate: True
valid_interval: 1
epochs: 120
topk: 5
image_shape: [3, 224, 224]
use_mix: False
ls_epsilon: -1
LEARNING_RATE:
function: 'Piecewise'
params:
lr: 0.1
decay_epochs: [30, 60, 90]
gamma: 0.1
OPTIMIZER:
function: 'Momentum'
params:
momentum: 0.9
regularizer:
function: 'L2'
factor: 0.000100
TRAIN:
batch_size: 256
num_workers: 4
file_list: "./dataset/ILSVRC2012/train_list.txt"
data_dir: "./dataset/ILSVRC2012/"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
VALID:
batch_size: 64
num_workers: 4
file_list: "./dataset/ILSVRC2012/val_list.txt"
data_dir: "./dataset/ILSVRC2012/"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
# SwinTransformer
## Overview
Swin Transformer is a new vision Transformer that can serve as a general-purpose backbone for computer vision. It is a hierarchical Transformer whose representation is computed with shifted windows. The shifted-window scheme brings greater efficiency by limiting self-attention computation to non-overlapping local windows while still allowing cross-window connections. [Paper](https://arxiv.org/abs/2103.14030)
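To make the windowing scheme concrete, the following NumPy sketch (illustrative only, not PaddleClas code) partitions a feature map into non-overlapping windows and shows how the shifted variant offsets the grid by half a window between consecutive blocks:

```python
import numpy as np

def window_partition(x, window_size):
    # x: (H, W, C) feature map; H and W are assumed divisible by window_size.
    H, W, C = x.shape
    x = x.reshape(H // window_size, window_size, W // window_size, window_size, C)
    # -> (num_windows, window_size, window_size, C); attention is computed per window.
    return x.transpose(0, 2, 1, 3, 4).reshape(-1, window_size, window_size, C)

def shifted_window_partition(x, window_size):
    # Cyclically shift by half a window so successive blocks mix information
    # across the window boundaries of the previous block.
    shift = window_size // 2
    shifted = np.roll(x, shift=(-shift, -shift), axis=(0, 1))
    return window_partition(shifted, window_size)

feat = np.random.rand(56, 56, 96)               # e.g. a stage-1 Swin-T feature map
print(window_partition(feat, 7).shape)          # (64, 7, 7, 96)
print(shifted_window_partition(feat, 7).shape)  # (64, 7, 7, 96)
```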
## Accuracy, FLOPS and Parameters
| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Params<br>(M) |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| SwinTransformer_tiny_patch4_window7_224 | 0.8069 | 0.9534 | 0.812 | 0.955 | 4.5 | 28 |
| SwinTransformer_small_patch4_window7_224 | 0.8275 | 0.9613 | 0.832 | 0.962 | 8.7 | 50 |
| SwinTransformer_base_patch4_window7_224 | 0.8300 | 0.9626 | 0.835 | 0.965 | 15.4 | 88 |
| SwinTransformer_base_patch4_window12_384 | 0.8439 | 0.9693 | 0.845 | 0.970 | 47.1 | 88 |
| SwinTransformer_base_patch4_window7_224<sup>[1]</sup> | 0.8487 | 0.9746 | 0.852 | 0.975 | 15.4 | 88 |
| SwinTransformer_base_patch4_window12_384<sup>[1]</sup> | 0.8642 | 0.9807 | 0.864 | 0.980 | 47.1 | 88 |
| SwinTransformer_large_patch4_window7_224<sup>[1]</sup> | 0.8596 | 0.9783 | 0.863 | 0.979 | 34.5 | 197 |
| SwinTransformer_large_patch4_window12_384<sup>[1]</sup> | 0.8719 | 0.9823 | 0.873 | 0.982 | 103.9 | 197 |
[1]: Pretrained on the ImageNet-22k dataset, then fine-tuned on the ImageNet-1k dataset.
**Note**: The accuracy gap relative to the reference results comes from differences in data preprocessing.
......@@ -9,27 +9,27 @@ DeiT(Data-efficient Image Transformers) series models were proposed by Facebook
## Accuracy, FLOPS and Parameters
| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Params<br>(M) |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| ViT_small_patch16_224 | 0.7769 | 0.9342 | 0.7785 | 0.9342 | | |
| ViT_base_patch16_224 | 0.8195 | 0.9617 | 0.8178 | 0.9613 | | |
| ViT_base_patch16_384 | 0.8414 | 0.9717 | 0.8420 | 0.9722 | | |
| ViT_base_patch32_384 | 0.8176 | 0.9613 | 0.8166 | 0.9613 | | |
| ViT_large_patch16_224 | 0.8323 | 0.9650 | 0.8306 | 0.9644 | | |
| ViT_large_patch16_384 | 0.8513 | 0.9736 | 0.8517 | 0.9736 | | |
| ViT_large_patch32_384 | 0.8153 | 0.9608 | 0.815 | - | | |
| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Params<br>(M) |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| DeiT_tiny_patch16_224 | 0.718 | 0.910 | 0.722 | 0.911 | | |
| DeiT_small_patch16_224 | 0.796 | 0.949 | 0.799 | 0.950 | | |
| DeiT_base_patch16_224 | 0.817 | 0.957 | 0.818 | 0.956 | | |
| DeiT_base_patch16_384 | 0.830 | 0.962 | 0.829 | 0.972 | | |
| DeiT_tiny_distilled_patch16_224 | 0.741 | 0.918 | 0.745 | 0.919 | | |
| DeiT_small_distilled_patch16_224 | 0.809 | 0.953 | 0.812 | 0.954 | | |
| DeiT_base_distilled_patch16_224 | 0.831 | 0.964 | 0.834 | 0.965 | | |
| DeiT_base_distilled_patch16_384 | 0.851 | 0.973 | 0.852 | 0.972 | | |
Params, FLOPs, inference speed, and other information are coming soon.
# SwinTransformer
## Overview
Swin Transformer is a new vision Transformer that can serve as a general-purpose backbone for computer vision. It is a hierarchical Transformer whose representation is computed with shifted windows: restricting self-attention to non-overlapping local windows while still allowing cross-window connections improves both efficiency and accuracy. [Paper link](https://arxiv.org/abs/2103.14030)
## Accuracy, FLOPS and Parameters
| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Params<br>(M) |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| SwinTransformer_tiny_patch4_window7_224 | 0.8069 | 0.9534 | 0.812 | 0.955 | 4.5 | 28 |
| SwinTransformer_small_patch4_window7_224 | 0.8275 | 0.9613 | 0.832 | 0.962 | 8.7 | 50 |
| SwinTransformer_base_patch4_window7_224 | 0.8300 | 0.9626 | 0.835 | 0.965 | 15.4 | 88 |
| SwinTransformer_base_patch4_window12_384 | 0.8439 | 0.9693 | 0.845 | 0.970 | 47.1 | 88 |
| SwinTransformer_base_patch4_window7_224<sup>[1]</sup> | 0.8487 | 0.9746 | 0.852 | 0.975 | 15.4 | 88 |
| SwinTransformer_base_patch4_window12_384<sup>[1]</sup> | 0.8642 | 0.9807 | 0.864 | 0.980 | 47.1 | 88 |
| SwinTransformer_large_patch4_window7_224<sup>[1]</sup> | 0.8596 | 0.9783 | 0.863 | 0.979 | 34.5 | 197 |
| SwinTransformer_large_patch4_window12_384<sup>[1]</sup> | 0.8719 | 0.9823 | 0.873 | 0.982 | 103.9 | 197 |
[1]: Pretrained on the ImageNet-22k dataset, then fine-tuned on the ImageNet-1k dataset.
**Note**: The accuracy gap relative to the reference results comes from differences in data preprocessing.
......@@ -11,26 +11,26 @@ DeiT (Data-efficient Image Transformers) series models were proposed by Facebook in 2020
## Accuracy, FLOPS and Parameters
| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Params<br>(M) |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| ViT_small_patch16_224 | 0.7769 | 0.9342 | 0.7785 | 0.9342 | | |
| ViT_base_patch16_224 | 0.8195 | 0.9617 | 0.8178 | 0.9613 | | |
| ViT_base_patch16_384 | 0.8414 | 0.9717 | 0.8420 | 0.9722 | | |
| ViT_base_patch32_384 | 0.8176 | 0.9613 | 0.8166 | 0.9613 | | |
| ViT_large_patch16_224 | 0.8323 | 0.9650 | 0.8306 | 0.9644 | | |
| ViT_large_patch16_384 | 0.8513 | 0.9736 | 0.8517 | 0.9736 | | |
| ViT_large_patch32_384 | 0.8153 | 0.9608 | 0.815 | - | | |
| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Params<br>(M) |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| DeiT_tiny_patch16_224 | 0.718 | 0.910 | 0.722 | 0.911 | | |
| DeiT_small_patch16_224 | 0.796 | 0.949 | 0.799 | 0.950 | | |
| DeiT_base_patch16_224 | 0.817 | 0.957 | 0.818 | 0.956 | | |
| DeiT_base_patch16_384 | 0.830 | 0.962 | 0.829 | 0.972 | | |
| DeiT_tiny_distilled_patch16_224 | 0.741 | 0.918 | 0.745 | 0.919 | | |
| DeiT_small_distilled_patch16_224 | 0.809 | 0.953 | 0.812 | 0.954 | | |
| DeiT_base_distilled_patch16_224 | 0.831 | 0.964 | 0.834 | 0.965 | | |
| DeiT_base_distilled_patch16_384 | 0.851 | 0.973 | 0.852 | 0.972 | | |
Information on Params, FLOPs, inference speed, etc. is coming soon.
......@@ -32,7 +32,10 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ''))
import argparse
import shutil
import textwrap
from difflib import SequenceMatcher
from prettytable import PrettyTable
import cv2
import numpy as np
import tarfile
......@@ -45,57 +48,148 @@ __all__ = ['PaddleClas']
BASE_DIR = os.path.expanduser("~/.paddleclas/")
BASE_INFERENCE_MODEL_DIR = os.path.join(BASE_DIR, 'inference_model')
BASE_IMAGES_DIR = os.path.join(BASE_DIR, 'images')
model_series = {
"AlexNet": ["AlexNet"],
"DarkNet": ["DarkNet53"],
"DeiT": [
'DeiT_base_distilled_patch16_224', 'DeiT_base_distilled_patch16_384',
'DeiT_base_patch16_224', 'DeiT_base_patch16_384',
'DeiT_small_distilled_patch16_224', 'DeiT_small_patch16_224',
'DeiT_tiny_distilled_patch16_224', 'DeiT_tiny_patch16_224'
],
"DenseNet": [
"DenseNet121", "DenseNet161", "DenseNet169", "DenseNet201",
"DenseNet264"
],
"DPN": ["DPN68", "DPN92", "DPN98", "DPN107", "DPN131"],
"EfficientNet": [
"EfficientNetB0", "EfficientNetB0_small", "EfficientNetB1",
"EfficientNetB2", "EfficientNetB3", "EfficientNetB4", "EfficientNetB5",
"EfficientNetB6", "EfficientNetB7"
],
"GhostNet":
["GhostNet_x0_5", "GhostNet_x1_0", "GhostNet_x1_3", "GhostNet_x1_3_ssld"],
"HRNet": [
"HRNet_W18_C", "HRNet_W30_C", "HRNet_W32_C", "HRNet_W40_C",
"HRNet_W44_C", "HRNet_W48_C", "HRNet_W64_C", "HRNet_W18_C_ssld",
"HRNet_W48_C_ssld"
],
"Inception": ["GoogLeNet", "InceptionV3", "InceptionV4"],
"MobileNetV1": [
"MobileNetV1_x0_25", "MobileNetV1_x0_5", "MobileNetV1_x0_75",
"MobileNetV1", "MobileNetV1_ssld"
],
"MobileNetV2": [
"MobileNetV2_x0_25", "MobileNetV2_x0_5", "MobileNetV2_x0_75",
"MobileNetV2", "MobileNetV2_x1_5", "MobileNetV2_x2_0",
"MobileNetV2_ssld"
],
"MobileNetV3": [
"MobileNetV3_small_x0_35", "MobileNetV3_small_x0_5",
"MobileNetV3_small_x0_75", "MobileNetV3_small_x1_0",
"MobileNetV3_small_x1_25", "MobileNetV3_large_x0_35",
"MobileNetV3_large_x0_5", "MobileNetV3_large_x0_75",
"MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25",
"MobileNetV3_small_x1_0_ssld", "MobileNetV3_large_x1_0_ssld"
],
"RegNet": ["RegNetX_4GF"],
"Res2Net": [
"Res2Net50_14w_8s", "Res2Net50_26w_4s", "Res2Net50_vd_26w_4s",
"Res2Net200_vd_26w_4s", "Res2Net101_vd_26w_4s",
"Res2Net50_vd_26w_4s_ssld", "Res2Net101_vd_26w_4s_ssld",
"Res2Net200_vd_26w_4s_ssld"
],
"ResNeSt": ["ResNeSt50", "ResNeSt50_fast_1s1x64d"],
"ResNet": [
"ResNet18", "ResNet18_vd", "ResNet34", "ResNet34_vd", "ResNet50",
"ResNet50_vc", "ResNet50_vd", "ResNet50_vd_v2", "ResNet101",
"ResNet101_vd", "ResNet152", "ResNet152_vd", "ResNet200_vd",
"ResNet34_vd_ssld", "ResNet50_vd_ssld", "ResNet50_vd_ssld_v2",
"ResNet101_vd_ssld", "Fix_ResNet50_vd_ssld_v2", "ResNet50_ACNet_deploy"
],
"ResNeXt": [
"ResNeXt50_32x4d", "ResNeXt50_vd_32x4d", "ResNeXt50_64x4d",
"ResNeXt50_vd_64x4d", "ResNeXt101_32x4d", "ResNeXt101_vd_32x4d",
"ResNeXt101_32x8d_wsl", "ResNeXt101_32x16d_wsl",
"ResNeXt101_32x32d_wsl", "ResNeXt101_32x48d_wsl",
"Fix_ResNeXt101_32x48d_wsl", "ResNeXt101_64x4d", "ResNeXt101_vd_64x4d",
"ResNeXt152_32x4d", "ResNeXt152_vd_32x4d", "ResNeXt152_64x4d",
"ResNeXt152_vd_64x4d"
],
"SENet": [
"SENet154_vd", "SE_HRNet_W64_C_ssld", "SE_ResNet18_vd",
"SE_ResNet34_vd", "SE_ResNet50_vd", "SE_ResNeXt50_32x4d",
"SE_ResNeXt50_vd_32x4d", "SE_ResNeXt101_32x4d"
],
"ShuffleNetV2": [
"ShuffleNetV2_swish", "ShuffleNetV2_x0_25", "ShuffleNetV2_x0_33",
"ShuffleNetV2_x0_5", "ShuffleNetV2_x1_0", "ShuffleNetV2_x1_5",
"ShuffleNetV2_x2_0"
],
"SqueezeNet": ["SqueezeNet1_0", "SqueezeNet1_1"],
"SwinTransformer": [
"SwinTransformer_large_patch4_window7_224_22kto1k",
"SwinTransformer_large_patch4_window12_384_22kto1k",
"SwinTransformer_base_patch4_window7_224_22kto1k",
"SwinTransformer_base_patch4_window12_384_22kto1k",
"SwinTransformer_base_patch4_window12_384",
"SwinTransformer_base_patch4_window7_224",
"SwinTransformer_small_patch4_window7_224",
"SwinTransformer_tiny_patch4_window7_224"
],
"VGG": ["VGG11", "VGG13", "VGG16", "VGG19"],
"VisionTransformer": [
"ViT_base_patch16_224", "ViT_base_patch16_384", "ViT_base_patch32_384",
"ViT_large_patch16_224", "ViT_large_patch16_384",
"ViT_large_patch32_384", "ViT_small_patch16_224"
],
"Xception": [
"Xception41", "Xception41_deeplab", "Xception65", "Xception65_deeplab",
"Xception71"
]
}
class ModelNameError(Exception):
""" ModelNameError
"""
def __init__(self, message=''):
super().__init__(message)
def print_info():
table = PrettyTable(['Series', 'Name'])
for series in model_series:
names = textwrap.fill(" ".join(model_series[series]), width=100)
table.add_row([series, names])
print('Inference models that Paddle provides are listed as follows:')
print(table)
def get_model_names():
model_names = []
for series in model_series:
model_names += model_series[series]
return model_names
def similar_architectures(name='', names=[], thresh=0.1, topk=10):
"""
inferred similar architectures
"""
scores = []
for idx, n in enumerate(names):
if n.startswith('__'):
continue
score = SequenceMatcher(None, n.lower(), name.lower()).quick_ratio()
if score > thresh:
scores.append((idx, score))
scores.sort(key=lambda x: x[1], reverse=True)
similar_names = [names[s[0]] for s in scores[:min(topk, len(scores))]]
return similar_names
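# Usage sketch (illustrative): given a partial or misspelled model name, the
# helpers above suggest the closest known architectures, e.g.
#   similar_architectures('SwinTransformer_tiny', get_model_names(), topk=3)
# would include 'SwinTransformer_tiny_patch4_window7_224' among its results.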
def download_with_progressbar(url, save_path):
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
......@@ -227,17 +321,26 @@ def parse_args(mMain=True, add_help=True):
class PaddleClas(object):
print_info()
def __init__(self, **kwargs):
model_names = get_model_names()
process_params = parse_args(mMain=False, add_help=False)
process_params.__dict__.update(**kwargs)
if not os.path.exists(process_params.model_file):
if process_params.model_name is None:
raise ModelNameError(
'Please input model name that you want to use!')
similar_names = similar_architectures(process_params.model_name,
model_names)
model_list = ', '.join(similar_names)
if process_params.model_name not in similar_names:
err = "{} is not exist! Maybe you want: [{}]" \
"".format(process_params.model_name, model_list)
raise ModelNameError(err)
if process_params.model_name in model_names:
url = 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/{}_infer.tar'.format(
process_params.model_name)
......
......@@ -47,6 +47,6 @@ from .vision_transformer import ViT_small_patch16_224, ViT_base_patch16_224, ViT
from .distilled_vision_transformer import DeiT_tiny_patch16_224, DeiT_small_patch16_224, DeiT_base_patch16_224, DeiT_tiny_distilled_patch16_224, DeiT_small_distilled_patch16_224, DeiT_base_distilled_patch16_224, DeiT_base_patch16_384, DeiT_base_distilled_patch16_384
from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0, ResNeXt101_32x16d_wsl_distill_ResNet50_vd
from .repvgg import RepVGG_A0, RepVGG_A1, RepVGG_A2, RepVGG_B0, RepVGG_B1, RepVGG_B2, RepVGG_B3, RepVGG_B1g2, RepVGG_B1g4, RepVGG_B2g2, RepVGG_B2g4, RepVGG_B3g2, RepVGG_B3g4
from .swin_transformer import SwinTransformer_tiny_patch4_window7_224, SwinTransformer_small_patch4_window7_224, SwinTransformer_base_patch4_window7_224, SwinTransformer_base_patch4_window12_384, SwinTransformer_large_patch4_window7_224, SwinTransformer_large_patch4_window12_384
from .mixnet import MixNet_S, MixNet_M, MixNet_L
from .rexnet import ReXNet_1_0, ReXNet_1_3, ReXNet_1_5, ReXNet_2_0, ReXNet_3_0
......@@ -759,3 +759,24 @@ def SwinTransformer_base_patch4_window12_384(**args):
drop_path_rate=0.5, # NOTE: does not appear in the official code
**args)
return model
def SwinTransformer_large_patch4_window7_224(**args):
model = SwinTransformer(
embed_dim=192,
depths=[2, 2, 18, 2],
num_heads=[6, 12, 24, 48],
window_size=7,
**args)
return model
def SwinTransformer_large_patch4_window12_384(**args):
model = SwinTransformer(
img_size=384,
embed_dim=192,
depths=[2, 2, 18, 2],
num_heads=[6, 12, 24, 48],
window_size=12,
**args)
return model
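# Usage sketch (assumes a working PaddlePaddle install; the classification head
# defaults to the 1000-class ImageNet-1k setting used elsewhere in this repo):
#   import paddle
#   model = SwinTransformer_large_patch4_window12_384()
#   y = model(paddle.randn([1, 3, 384, 384]))  # expected shape: [1, 1000]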