Commit 44b9963b authored by: C cuicheng01

add mobilevit code

Parent 91b5337f
......@@ -29,7 +29,8 @@
- [22. TNT series](#22)
- [23. CSwinTransformer series](#23)
- [24. PVTV2 series](#24)
- [25. Other models](#25)
- [25. MobileViT series](#25)
- [26. Other models](#26)
- [Reference](#reference)
<a name="1"></a>
......@@ -532,10 +533,21 @@ The accuracy and speed indicators of PVTV2 series models are shown in the follow
| PVT_V2_B4 | 0.836 | 0.967 | - | - | - | 9.8 | 62.6 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/PVT_V2_B4_pretrained.pdparams) | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PVT_V2_B4_infer.tar) |
| PVT_V2_B5 | 0.837 | 0.966 | - | - | - | 11.4 | 82.0 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/PVT_V2_B5_pretrained.pdparams) | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PVT_V2_B5_infer.tar) |
<a name="25"></a>
## 25. Other models
## 25. MobileViT series <sup>[[42](#ref42)]</sup>
The accuracy and speed indicators of the MobileViT series models are shown in the following table. For more details, please refer to the [MobileViT series model documentation](../models/MobileViT_en.md).
| Model | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | time(ms)<br/>bs=8 | FLOPs(M) | Params(M) | Pretrained Model Download Address | Inference Model Download Address |
| ---------- | --------- | --------- | ---------------- | ---------------- | -------- | --------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| MobileViT_XXS | 0.6867 | 0.8878 | - | - | - | 337.24 | 1.28 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViT_XXS_pretrained.pdparams) | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileViT_XXS_infer.tar) |
| MobileViT_XS | 0.7454 | 0.9227 | - | - | - | 930.75 | 2.33 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViT_XS_pretrained.pdparams) | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileViT_XS_infer.tar) |
| MobileViT_S | 0.7814 | 0.9413 | - | - | - | 1849.35 | 5.59 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViT_S_pretrained.pdparams) | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileViT_S_infer.tar) |
<a name="26"></a>
## 26. Other models
The accuracy and speed indicators of AlexNet <sup>[[18](#ref18)]</sup>, SqueezeNet series <sup>[[19](#ref19)]</sup>, VGG series <sup>[[20](#ref20)]</sup>, DarkNet53 <sup>[[21](#ref21)]</sup> and other models are shown in the following table. For more information, please refer to: [Other model documents](../models/Others_en.md).
......@@ -637,3 +649,5 @@ TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE.
<a name="ref40">[40]</a>Xiaoyi Dong, Jianmin Bao, Dongdong Chen, Weiming Zhang, Nenghai Yu, Lu Yuan, Dong Chen, Baining Guo. CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows.
<a name="ref41">[41]</a>Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan, Kaitao Song, Ding Liang, Tong Lu, Ping Luo, Ling Shao. PVTv2: Improved Baselines with Pyramid Vision Transformer.
<a name="ref42">[42]</a>Sachin Mehta, Mohammad Rastegari. MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer.
# MobileViT
---
## Catalogue
* [1. Overview](#1)
* [2. Accuracy, FLOPs and Parameters](#2)
<a name='1'></a>
## 1. Overview
MobileViT is a lightweight vision Transformer that can serve as a general-purpose backbone for computer vision tasks. It combines the strengths of CNNs and Transformers, handling both local and global features and mitigating the Transformer's lack of inductive bias. As a result, with a comparable number of parameters, it achieves significant improvements over other SOTA models on image classification, object detection, and semantic segmentation. [Paper](https://arxiv.org/pdf/2110.02178.pdf)
<a name='2'></a>
## 2. Accuracy, FLOPs and Parameters
| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPs<br>(M) | Params<br>(M) |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| MobileViT_XXS | 0.6867 | 0.8878 | 0.690 | - | 337.24 | 1.28 |
| MobileViT_XS | 0.7454 | 0.9227 | 0.747 | - | 930.75 | 2.33 |
| MobileViT_S | 0.7814 | 0.9413 | 0.783 | - | 1849.35 | 5.59 |
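
The following is a minimal usage sketch. It assumes PaddleClas is installed and importable as `ppcls`, builds `MobileViT_S` with the entry point added in this commit, and runs a forward pass at the 256x256 resolution used by the MobileViT training configs; setting `pretrained=True` would download the weights listed above.

```python
import paddle

from ppcls.arch.backbone.model_zoo.mobilevit import MobileViT_S

# Build the model; pretrained=True would download the released weights.
model = MobileViT_S(pretrained=False, class_num=1000)
model.eval()

# The MobileViT configs in this commit use 256x256 inputs.
x = paddle.rand([1, 3, 256, 256])
with paddle.no_grad():
    logits = model(x)
print(logits.shape)  # [1, 1000]
```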
......@@ -32,7 +32,8 @@
- [22. TNT series](#22)
- [23. CSwinTransformer series](#23)
- [24. PVTV2 series](#24)
- [25. Other models](#25)
- [25. MobileViT series](#25)
- [26. Other models](#26)
- [Reference](#reference)
<a name="1"></a>
......@@ -533,10 +534,21 @@ ViT (Vision Transformer) and DeiT (Data-efficient Image Transformers) series models
| PVT_V2_B5 | 0.837 | 0.966 | - | - | - | 11.4 | 82.0 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/PVT_V2_B5_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PVT_V2_B5_infer.tar) |
<a name="25"></a>
## 25. Other models
## 25. MobileViT series <sup>[[42](#ref42)]</sup>
The accuracy and speed indicators of the MobileViT series models are shown in the following table. For more details, please refer to the [MobileViT series model documentation](../models/MobileViT.md).
| Model | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | time(ms)<br/>bs=8 | FLOPs(M) | Params(M) | Pretrained Model Download Address | Inference Model Download Address |
| ---------- | --------- | --------- | ---------------- | ---------------- | -------- | --------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| MobileViT_XXS | 0.6867 | 0.8878 | - | - | - | 337.24 | 1.28 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViT_XXS_pretrained.pdparams) | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileViT_XXS_infer.tar) |
| MobileViT_XS | 0.7454 | 0.9227 | - | - | - | 930.75 | 2.33 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViT_XS_pretrained.pdparams) | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileViT_XS_infer.tar) |
| MobileViT_S | 0.7814 | 0.9413 | - | - | - | 1849.35 | 5.59 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViT_S_pretrained.pdparams) | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileViT_S_infer.tar) |
<a name="26"></a>
## 26. Other models
The accuracy and speed indicators of AlexNet <sup>[[18](#ref18)]</sup>, the SqueezeNet series <sup>[[19](#ref19)]</sup>, the VGG series <sup>[[20](#ref20)]</sup>, DarkNet53 <sup>[[21](#ref21)]</sup> and other models are shown in the following table. For more details, please refer to the [other model documentation](../models/Others.md).
......@@ -637,3 +649,5 @@ TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE.
<a name="ref40">[40]</a>Xiaoyi Dong, Jianmin Bao, Dongdong Chen, Weiming Zhang, Nenghai Yu, Lu Yuan, Dong Chen, Baining Guo. CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows.
<a name="ref41">[41]</a>Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan, Kaitao Song, Ding Liang, Tong Lu, Ping Luo, Ling Shao. PVTv2: Improved Baselines with Pyramid Vision Transformer.
<a name="ref42">[42]</a>Sachin Mehta, Mohammad Rastegari. MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer.
# MobileViT
---
## Catalogue
* [1. Overview](#1)
* [2. Accuracy, FLOPs and Parameters](#2)
<a name='1'></a>
## 1. Overview
MobileViT is a lightweight vision Transformer that can serve as a general-purpose backbone for computer vision tasks. It combines the strengths of CNNs and Transformers, handling both local and global features and mitigating the Transformer's lack of inductive bias. As a result, with a comparable number of parameters, it achieves significant improvements over other SOTA models on image classification, object detection, and semantic segmentation. [Paper](https://arxiv.org/pdf/2110.02178.pdf)
<a name='2'></a>
## 2. Accuracy, FLOPs and Parameters
| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPs<br>(M) | Params<br>(M) |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| MobileViT_XXS | 0.6867 | 0.8878 | 0.690 | - | 337.24 | 1.28 |
| MobileViT_XS | 0.7454 | 0.9227 | 0.747 | - | 930.75 | 2.33 |
| MobileViT_S | 0.7814 | 0.9413 | 0.783 | - | 1849.35 | 5.59 |
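
A minimal sketch (assuming PaddleClas is importable as `ppcls`) to reproduce the FLOPs and Params columns above with `paddle.flops` at the 256x256 input resolution used by the configs in this commit:

```python
import paddle

from ppcls.arch.backbone.model_zoo.mobilevit import (MobileViT_XXS,
                                                     MobileViT_XS, MobileViT_S)

# Count FLOPs and parameters at the 256x256 training/eval resolution.
for build_fn in (MobileViT_XXS, MobileViT_XS, MobileViT_S):
    model = build_fn(pretrained=False)
    print(build_fn.__name__)
    paddle.flops(model, [1, 3, 256, 256], print_detail=False)
```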
......@@ -62,6 +62,7 @@ from ppcls.arch.backbone.model_zoo.tnt import TNT_small
from ppcls.arch.backbone.model_zoo.hardnet import HarDNet68, HarDNet85, HarDNet39_ds, HarDNet68_ds
from ppcls.arch.backbone.model_zoo.cspnet import CSPDarkNet53
from ppcls.arch.backbone.model_zoo.pvt_v2 import PVT_V2_B0, PVT_V2_B1, PVT_V2_B2_Linear, PVT_V2_B2, PVT_V2_B3, PVT_V2_B4, PVT_V2_B5
from ppcls.arch.backbone.model_zoo.mobilevit import MobileViT_XXS, MobileViT_XS, MobileViT_S
from ppcls.arch.backbone.model_zoo.repvgg import RepVGG_A0, RepVGG_A1, RepVGG_A2, RepVGG_B0, RepVGG_B1, RepVGG_B2, RepVGG_B1g2, RepVGG_B1g4, RepVGG_B2g4, RepVGG_B3g4
from ppcls.arch.backbone.variant_models.resnet_variant import ResNet50_last_stage_stride1
from ppcls.arch.backbone.variant_models.vgg_variant import VGG19Sigmoid
......
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import KaimingUniform, TruncatedNormal, Constant
import math
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = {
"MobileViT_XXS":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViT_XXS_pretrained.pdparams",
"MobileViT_XS":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViT_XS_pretrained.pdparams",
"MobileViT_S":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViT_S_pretrained.pdparams",
}
def _init_weights_linear():
weight_attr = ParamAttr(initializer=TruncatedNormal(std=.02))
bias_attr = ParamAttr(initializer=Constant(0.0))
return weight_attr, bias_attr
def _init_weights_layernorm():
weight_attr = ParamAttr(initializer=Constant(1.0))
bias_attr = ParamAttr(initializer=Constant(0.0))
return weight_attr, bias_attr
class ConvBnAct(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=0,
bias_attr=False,
groups=1):
super().__init__()
self.in_channels = in_channels
self.conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
weight_attr=ParamAttr(initializer=KaimingUniform()),
bias_attr=bias_attr)
self.norm = nn.BatchNorm2D(out_channels)
self.act = nn.Silu()
def forward(self, inputs):
out = self.conv(inputs)
out = self.norm(out)
out = self.act(out)
return out
class Identity(nn.Layer):
""" Identity layer"""
def __init__(self):
super().__init__()
def forward(self, inputs):
return inputs
class Mlp(nn.Layer):
def __init__(self, embed_dim, mlp_ratio, dropout=0.1):
super().__init__()
w_attr_1, b_attr_1 = _init_weights_linear()
self.fc1 = nn.Linear(
embed_dim,
int(embed_dim * mlp_ratio),
weight_attr=w_attr_1,
bias_attr=b_attr_1)
w_attr_2, b_attr_2 = _init_weights_linear()
self.fc2 = nn.Linear(
int(embed_dim * mlp_ratio),
embed_dim,
weight_attr=w_attr_2,
bias_attr=b_attr_2)
self.act = nn.Silu()
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.dropout1(x)
x = self.fc2(x)
x = self.dropout2(x)
return x
class Attention(nn.Layer):
def __init__(self,
embed_dim,
num_heads,
qkv_bias=True,
dropout=0.1,
attention_dropout=0.):
super().__init__()
self.num_heads = num_heads
self.attn_head_dim = int(embed_dim / self.num_heads)
self.all_head_dim = self.attn_head_dim * self.num_heads
w_attr_1, b_attr_1 = _init_weights_linear()
self.qkv = nn.Linear(
embed_dim,
self.all_head_dim * 3,
weight_attr=w_attr_1,
bias_attr=b_attr_1 if qkv_bias else False)
self.scales = self.attn_head_dim**-0.5
w_attr_2, b_attr_2 = _init_weights_linear()
self.proj = nn.Linear(
embed_dim, embed_dim, weight_attr=w_attr_2, bias_attr=b_attr_2)
self.attn_dropout = nn.Dropout(attention_dropout)
self.proj_dropout = nn.Dropout(dropout)
self.softmax = nn.Softmax(axis=-1)
def transpose_multihead(self, x):
B, P, N, d = x.shape
x = x.reshape([B, P, N, self.num_heads, d // self.num_heads])
x = x.transpose([0, 1, 3, 2, 4])
return x
def forward(self, x):
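# Illustrative shape flow: x is [B*P, N, D], where P = patch_h*patch_w pixel
# positions per patch, N = number of patches and D = embed_dim.
# qkv: [B*P, N, 3*D] -> reshape/transpose -> [B*P, num_heads, 3, N, head_dim]
# attn = softmax((Q * scale) @ K^T): [B*P, num_heads, N, N]
# out = attn @ V, heads merged back to [B*P, N, D] before the output projection.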
b_sz, n_patches, in_channels = x.shape
qkv = self.qkv(x)
qkv = qkv.reshape([
b_sz, n_patches, 3, self.num_heads,
qkv.shape[-1] // self.num_heads // 3
])
qkv = qkv.transpose([0, 3, 2, 1, 4])
query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
query = query * self.scales
key = key.transpose([0, 1, 3, 2])
# QK^T
attn = paddle.matmul(query, key)
attn = self.softmax(attn)
attn = self.attn_dropout(attn)
# weighted sum
out = paddle.matmul(attn, value)
out = out.transpose([0, 2, 1, 3]).reshape(
[b_sz, n_patches, out.shape[1] * out.shape[3]])
out = self.proj(out)
out = self.proj_dropout(out)
return out
class EncoderLayer(nn.Layer):
def __init__(self,
embed_dim,
num_heads=4,
qkv_bias=True,
mlp_ratio=2.0,
dropout=0.1,
attention_dropout=0.,
droppath=0.):
super().__init__()
w_attr_1, b_attr_1 = _init_weights_layernorm()
w_attr_2, b_attr_2 = _init_weights_layernorm()
self.attn_norm = nn.LayerNorm(
embed_dim, weight_attr=w_attr_1, bias_attr=b_attr_1)
self.attn = Attention(embed_dim, num_heads, qkv_bias, dropout,
attention_dropout)
self.drop_path = DropPath(droppath) if droppath > 0. else Identity()
self.mlp_norm = nn.LayerNorm(
embed_dim, weight_attr=w_attr_2, bias_attr=b_attr_2)
self.mlp = Mlp(embed_dim, mlp_ratio, dropout)
def forward(self, x):
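# Pre-norm transformer encoder layer: LayerNorm -> multi-head attention -> residual add,
# then LayerNorm -> MLP -> residual add. DropPath (stochastic depth) is applied only
# when droppath > 0; the configs in this commit use droppath=0, so Identity is used.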
h = x
x = self.attn_norm(x)
x = self.attn(x)
x = self.drop_path(x)
x = h + x
h = x
x = self.mlp_norm(x)
x = self.mlp(x)
x = self.drop_path(x)
x = x + h
return x
class Transformer(nn.Layer):
"""Transformer block for MobileViTBlock"""
def __init__(self,
embed_dim,
num_heads,
depth,
qkv_bias=True,
mlp_ratio=2.0,
dropout=0.1,
attention_dropout=0.,
droppath=0.):
super().__init__()
# Linearly scale the per-layer stochastic-depth (drop path) rate from 0 to droppath.
depth_decay = [x.item() for x in paddle.linspace(0, droppath, depth)]
layer_list = []
for i in range(depth):
layer_list.append(
EncoderLayer(embed_dim, num_heads, qkv_bias, mlp_ratio,
dropout, attention_dropout, depth_decay[i]))
self.layers = nn.LayerList(layer_list)
w_attr_1, b_attr_1 = _init_weights_layernorm()
self.norm = nn.LayerNorm(
embed_dim, weight_attr=w_attr_1, bias_attr=b_attr_1, epsilon=1e-6)
def forward(self, x):
for layer in self.layers:
x = layer(x)
out = self.norm(x)
return out
class MobileV2Block(nn.Layer):
"""Mobilenet v2 InvertedResidual block"""
def __init__(self, inp, oup, stride=1, expansion=4):
super().__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(round(inp * expansion))
self.use_res_connect = self.stride == 1 and inp == oup
layers = []
if expansion != 1:
layers.append(ConvBnAct(inp, hidden_dim, kernel_size=1))
layers.extend([
# dw
ConvBnAct(
hidden_dim,
hidden_dim,
stride=stride,
groups=hidden_dim,
padding=1),
# pw-linear
nn.Conv2D(
hidden_dim, oup, 1, 1, 0, bias_attr=False),
nn.BatchNorm2D(oup),
])
self.conv = nn.Sequential(*layers)
self.out_channels = oup
def forward(self, x):
if self.use_res_connect:
return x + self.conv(x)
return self.conv(x)
class MobileViTBlock(nn.Layer):
""" MobileViTBlock for MobileViT"""
def __init__(self,
dim,
hidden_dim,
depth,
num_heads=4,
qkv_bias=True,
mlp_ratio=2.0,
dropout=0.1,
attention_dropout=0.,
droppath=0.0,
patch_size=(2, 2)):
super().__init__()
self.patch_h, self.patch_w = patch_size
# local representations
self.conv1 = ConvBnAct(dim, dim, padding=1)
self.conv2 = nn.Conv2D(
dim, hidden_dim, kernel_size=1, stride=1, bias_attr=False)
# global representations
self.transformer = Transformer(
embed_dim=hidden_dim,
num_heads=num_heads,
depth=depth,
qkv_bias=qkv_bias,
mlp_ratio=mlp_ratio,
dropout=dropout,
attention_dropout=attention_dropout,
droppath=droppath)
# fusion
self.conv3 = ConvBnAct(hidden_dim, dim, kernel_size=1)
self.conv4 = ConvBnAct(2 * dim, dim, padding=1)
def forward(self, x):
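# MobileViT block, illustrative flow:
# 1) local representation: 3x3 ConvBnAct + 1x1 conv, [B, C, H, W] -> [B, D, H, W]
# 2) unfold into non-overlapping patch_h x patch_w patches and run the transformer
#    across patches: [B, D, H, W] -> [B*patch_area, num_patches, D]
# 3) fold the tokens back into a feature map, project with a 1x1 conv and fuse
#    with the block input via concat + 3x3 conv: back to [B, C, H, W]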
h = x
x = self.conv1(x)
x = self.conv2(x)
patch_h = self.patch_h
patch_w = self.patch_w
patch_area = int(patch_w * patch_h)
_, in_channels, orig_h, orig_w = x.shape
new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)
new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)
interpolate = False
if new_w != orig_w or new_h != orig_h:
x = F.interpolate(x, size=[new_h, new_w], mode="bilinear")
interpolate = True
num_patch_w, num_patch_h = new_w // patch_w, new_h // patch_h
num_patches = num_patch_h * num_patch_w
reshaped_x = x.reshape([-1, patch_h, num_patch_w, patch_w])
transposed_x = reshaped_x.transpose([0, 2, 1, 3])
reshaped_x = transposed_x.reshape(
[-1, in_channels, num_patches, patch_area])
transposed_x = reshaped_x.transpose([0, 3, 2, 1])
x = transposed_x.reshape([-1, num_patches, in_channels])
x = self.transformer(x)
x = x.reshape([-1, patch_h * patch_w, num_patches, in_channels])
_, pixels, num_patches, channels = x.shape
x = x.transpose([0, 3, 2, 1])
x = x.reshape([-1, num_patch_w, patch_h, patch_w])
x = x.transpose([0, 2, 1, 3])
x = x.reshape(
[-1, channels, num_patch_h * patch_h, num_patch_w * patch_w])
if interpolate:
x = F.interpolate(x, size=[orig_h, orig_w])
x = self.conv3(x)
x = paddle.concat((h, x), axis=1)
x = self.conv4(x)
return x
class MobileViT(nn.Layer):
""" MobileViT
A PaddlePaddle impl of : `MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` -
https://arxiv.org/abs/2110.02178
"""
def __init__(self,
in_channels=3,
dims=[16, 32, 48, 48, 48, 64, 80, 96, 384],
hidden_dims=[96, 120, 144],
mv2_expansion=4,
class_num=1000):
super().__init__()
self.conv3x3 = ConvBnAct(
in_channels, dims[0], kernel_size=3, stride=2, padding=1)
self.mv2_block_1 = MobileV2Block(
dims[0], dims[1], expansion=mv2_expansion)
self.mv2_block_2 = MobileV2Block(
dims[1], dims[2], stride=2, expansion=mv2_expansion)
self.mv2_block_3 = MobileV2Block(
dims[2], dims[3], expansion=mv2_expansion)
self.mv2_block_4 = MobileV2Block(
dims[3], dims[4], expansion=mv2_expansion)
self.mv2_block_5 = MobileV2Block(
dims[4], dims[5], stride=2, expansion=mv2_expansion)
self.mvit_block_1 = MobileViTBlock(dims[5], hidden_dims[0], depth=2)
self.mv2_block_6 = MobileV2Block(
dims[5], dims[6], stride=2, expansion=mv2_expansion)
self.mvit_block_2 = MobileViTBlock(dims[6], hidden_dims[1], depth=4)
self.mv2_block_7 = MobileV2Block(
dims[6], dims[7], stride=2, expansion=mv2_expansion)
self.mvit_block_3 = MobileViTBlock(dims[7], hidden_dims[2], depth=3)
self.conv1x1 = ConvBnAct(dims[7], dims[8], kernel_size=1)
self.pool = nn.AdaptiveAvgPool2D(1)
self.dropout = nn.Dropout(0.1)
self.linear = nn.Linear(dims[8], class_num)
def forward(self, x):
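# Stride-32 backbone: a stride-2 3x3 stem, MobileNetV2 inverted residual blocks
# in the early stages, and three MobileViT blocks at the 1/8, 1/16 and 1/32
# resolutions, followed by a 1x1 conv, global average pooling, dropout and the
# classification head.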
x = self.conv3x3(x)
x = self.mv2_block_1(x)
x = self.mv2_block_2(x)
x = self.mv2_block_3(x)
x = self.mv2_block_4(x)
x = self.mv2_block_5(x)
x = self.mvit_block_1(x)
x = self.mv2_block_6(x)
x = self.mvit_block_2(x)
x = self.mv2_block_7(x)
x = self.mvit_block_3(x)
x = self.conv1x1(x)
x = self.pool(x)
x = x.reshape(x.shape[:2])
x = self.dropout(x)
x = self.linear(x)
return x
def _load_pretrained(pretrained, model, model_url, use_ssld=False):
if pretrained is False:
pass
elif pretrained is True:
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
elif isinstance(pretrained, str):
load_dygraph_pretrain(model, pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type."
)
def MobileViT_XXS(pretrained=False, use_ssld=False, **kwargs):
model = MobileViT(
in_channels=3,
dims=[16, 16, 24, 24, 24, 48, 64, 80, 320],
hidden_dims=[64, 80, 96],
mv2_expansion=2,
**kwargs)
_load_pretrained(
pretrained, model, MODEL_URLS["MobileViT_XXS"], use_ssld=use_ssld)
return model
def MobileViT_XS(pretrained=False, use_ssld=False, **kwargs):
model = MobileViT(
in_channels=3,
dims=[16, 32, 48, 48, 48, 64, 80, 96, 384],
hidden_dims=[96, 120, 144],
mv2_expansion=4,
**kwargs)
_load_pretrained(
pretrained, model, MODEL_URLS["MobileViT_XS"], use_ssld=use_ssld)
return model
def MobileViT_S(pretrained=False, use_ssld=False, **kwargs):
model = MobileViT(
in_channels=3,
dims=[16, 32, 64, 64, 64, 96, 128, 160, 640],
hidden_dims=[144, 192, 240],
mv2_expansion=4,
**kwargs)
_load_pretrained(
pretrained, model, MODEL_URLS["MobileViT_S"], use_ssld=use_ssld)
return model
if __name__ == "__main__":
model = MobileViT_XXS()
paddle.flops(model, [1, 3, 256, 256])
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 300
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 256, 256]
save_inference_dir: ./inference
use_dali: False
# model architecture
Arch:
name: MobileViT_S
class_num: 1000
# loss function config for training/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
epsilon: 1e-8
weight_decay: 0.01
no_weight_decay_name: .bias norm
one_dim_param_no_weight_decay: True
lr:
# for 8 cards
name: Cosine
learning_rate: 0.002
eta_min: 0.0002
warmup_epoch: 5
warmup_start_lr: 0.0002
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 256
interpolation: bilinear
backend: pil
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.0, 0.0, 0.0]
std: [1.0, 1.0, 1.0]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: False
channel_first: False
- ResizeImage:
resize_short: 292
interpolation: bilinear
backend: pil
- CropImage:
size: 256
- NormalizeImage:
scale: 1.0/255.0
mean: [0.0, 0.0, 0.0]
std: [1.0, 1.0, 1.0]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 292
- CropImage:
size: 256
- NormalizeImage:
scale: 1.0/255.0
mean: [0.0, 0.0, 0.0]
std: [1.0, 1.0, 1.0]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
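
For reference, the warmup-then-cosine schedule configured above (5 warmup epochs from 0.0002 to 0.002, then cosine decay to `eta_min` 0.0002 over 300 epochs) can be sketched per epoch as follows. This is an illustrative reimplementation of the configured values, not the PaddleClas `Cosine` scheduler itself, which operates per step:

```python
import math

def mobilevit_lr(epoch, epochs=300, base_lr=0.002, eta_min=0.0002,
                 warmup_epoch=5, warmup_start_lr=0.0002):
    """Illustrative per-epoch warmup + cosine schedule matching the config values."""
    if epoch < warmup_epoch:
        # Linear warmup from warmup_start_lr up to the base learning rate.
        return warmup_start_lr + (base_lr - warmup_start_lr) * epoch / warmup_epoch
    # Cosine decay from base_lr down to eta_min over the remaining epochs.
    progress = (epoch - warmup_epoch) / (epochs - warmup_epoch)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress))

for e in (0, 5, 150, 299):
    print(e, round(mobilevit_lr(e), 6))
```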
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 300
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 256, 256]
save_inference_dir: ./inference
use_dali: False
# model architecture
Arch:
name: MobileViT_XS
class_num: 1000
# loss function config for training/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
epsilon: 1e-8
weight_decay: 0.01
no_weight_decay_name: .bias norm
one_dim_param_no_weight_decay: True
lr:
# for 8 cards
name: Cosine
learning_rate: 0.002
eta_min: 0.0002
warmup_epoch: 5
warmup_start_lr: 0.0002
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 256
interpolation: bilinear
backend: pil
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.0, 0.0, 0.0]
std: [1.0, 1.0, 1.0]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: False
channel_first: False
- ResizeImage:
resize_short: 292
interpolation: bilinear
backend: pil
- CropImage:
size: 256
- NormalizeImage:
scale: 1.0/255.0
mean: [0.0, 0.0, 0.0]
std: [1.0, 1.0, 1.0]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 292
- CropImage:
size: 256
- NormalizeImage:
scale: 1.0/255.0
mean: [0.0, 0.0, 0.0]
std: [1.0, 1.0, 1.0]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
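
The Infer preprocessing configured above (resize the short side to 292, center-crop to 256, scale by 1.0/255.0 with zero mean and unit std, then HWC to CHW) can be approximated outside PaddleClas as follows. This is a hedged PIL/NumPy sketch, not the DataLoader operators themselves:

```python
import numpy as np
from PIL import Image

def preprocess(path, resize_short=292, crop=256):
    """Approximate ResizeImage / CropImage / NormalizeImage / ToCHWImage."""
    img = Image.open(path).convert("RGB")
    w, h = img.size
    scale = resize_short / min(w, h)
    img = img.resize((round(w * scale), round(h * scale)), Image.BILINEAR)
    w, h = img.size
    left, top = (w - crop) // 2, (h - crop) // 2
    img = img.crop((left, top, left + crop, top + crop))
    x = np.asarray(img, dtype="float32") / 255.0   # scale: 1.0/255.0, mean 0, std 1
    return x.transpose(2, 0, 1)[None]              # HWC -> CHW, add batch dimension

# x = preprocess("docs/images/inference_deployment/whl_demo.jpg")  # -> (1, 3, 256, 256)
```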
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 300
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 256, 256]
save_inference_dir: ./inference
use_dali: False
# model architecture
Arch:
name: MobileViT_XXS
class_num: 1000
# loss function config for training/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
epsilon: 1e-8
weight_decay: 0.01
no_weight_decay_name: .bias norm
one_dim_param_no_weight_decay: True
lr:
# for 8 cards
name: Cosine
learning_rate: 0.002
eta_min: 0.0002
warmup_epoch: 5
warmup_start_lr: 0.0002
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 256
interpolation: bilinear
backend: pil
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.0, 0.0, 0.0]
std: [1.0, 1.0, 1.0]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: False
channel_first: False
- ResizeImage:
resize_short: 292
interpolation: bilinear
backend: pil
- CropImage:
size: 256
- NormalizeImage:
scale: 1.0/255.0
mean: [0.0, 0.0, 0.0]
std: [1.0, 1.0, 1.0]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 292
- CropImage:
size: 256
- NormalizeImage:
scale: 1.0/255.0
mean: [0.0, 0.0, 0.0]
std: [1.0, 1.0, 1.0]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
===========================train_params===========================
model_name:MobileViT_S
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/MobileViT/MobileViT_S.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/MobileViT/MobileViT_S.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/MobileViT/MobileViT_S.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViT_S_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:False
inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=292 -o PreProcess.transform_ops.1.CropImage.size=256 -o PreProcess.transform_ops.2.NormalizeImage.mean=[0.,0.,0.] -o PreProcess.transform_ops.2.NormalizeImage.std=[1.,1.,1.]
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:True|False
-o Global.cpu_num_threads:1|6
-o Global.batch_size:1|16
-o Global.use_tensorrt:True|False
-o Global.use_fp16:True|False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val
-o Global.save_log_path:null
-o Global.benchmark:True
null:null
null:null
===========================train_params===========================
model_name:MobileViT_XS
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/MobileViT/MobileViT_XS.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/MobileViT/MobileViT_XS.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/MobileViT/MobileViT_XS.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViT_XS_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:False
inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=292 -o PreProcess.transform_ops.1.CropImage.size=256 -o PreProcess.transform_ops.2.NormalizeImage.mean=[0.,0.,0.] -o PreProcess.transform_ops.2.NormalizeImage.std=[1.,1.,1.]
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:True|False
-o Global.cpu_num_threads:1|6
-o Global.batch_size:1|16
-o Global.use_tensorrt:True|False
-o Global.use_fp16:True|False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val
-o Global.save_log_path:null
-o Global.benchmark:True
null:null
null:null
===========================train_params===========================
model_name:MobileViT_XXS
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/MobileViT/MobileViT_XXS.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/MobileViT/MobileViT_XXS.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/MobileViT/MobileViT_XXS.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViT_XXS_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:False
inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=292 -o PreProcess.transform_ops.1.CropImage.size=256 -o PreProcess.transform_ops.2.NormalizeImage.mean=[0.,0.,0.] -o PreProcess.transform_ops.2.NormalizeImage.std=[1.,1.,1.]
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:True|False
-o Global.cpu_num_threads:1|6
-o Global.batch_size:1|16
-o Global.use_tensorrt:True|False
-o Global.use_fp16:True|False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val
-o Global.save_log_path:null
-o Global.benchmark:True
null:null
null:null